为什么扩散模型的 RL 对齐如此棘手?
近年来,扩散模型在图像和视频生成领域取得了令人瞩目的进展。然而,当我们试图通过强化学习(RL)来微调扩散模型,使其生成结果更好地符合人类偏好时,一个核心难题浮出水面:如何同时优化多个奖励维度?
举一个直观的例子。假设我们要训练一个文本到图像的生成模型,希望生成的图片同时满足以下要求:
- 美学质量高(美学评分奖励)
- 与文本描述一致(图文匹配奖励)
- 细节清晰、无伪影(感知质量奖励)
传统做法通常是将这些奖励加权求和:
R_total = w_1 * R_aesthetic + w_2 * R_alignment + w_3 * R_quality
这种看似简洁的方法实际上暗藏危机。多个奖励维度之间往往存在复杂的耦合关系——追求更高的美学评分可能会牺牲图文一致性,而过度优化清晰度可能导致生成结果缺乏多样性。加权求和将这些高维信号压缩为单一标量,不可避免地造成信息丢失和优化冲突。
现有的替代方案包括为每个奖励训练独立的专家模型再进行蒸馏,或者手动设计阶段性训练调度。但这些方法要么计算成本高昂,要么需要大量人工调参,且难以泛化到新的奖励组合。
MARBLE 的核心思路:在梯度空间中保持多维信号
MARBLE(Multi-Aspect Reward Balance)的核心洞察是:与其在奖励空间中进行妥协,不如在梯度空间中保持各维度的独立性。
具体来说,MARBLE 不将多个奖励合并为一个标量,而是在每次训练迭代中分别计算每个奖励维度对应的梯度分量,然后通过一种平衡机制将这些梯度分量组合为最终的更新方向。
这个思路的关键优势在于:每个奖励维度的梯度信号在组合过程中得以保留,而不是被提前压缩。这使得模型能够在不同奖励维度之间找到更好的帕累托最优解,而不是被迫在加权求和的单一方向上优化。
梯度分解
给定 K 个奖励函数 R_1, R_2, ..., R_K,MARBLE 首先分别计算每个奖励对应的策略梯度:
g_k = ∇_θ E[R_k(x)] (k = 1, 2, ..., K)
其中 θ 是扩散模型的参数,x 是生成的样本。这一步确保了每个奖励维度的梯度方向被独立保留。
动态平衡
接下来,MARBLE 引入一个平衡机制来组合这些梯度分量。核心思想是:给予当前优化不足的奖励维度更高的梯度贡献权重。
g_total = Σ (α_k * g_k)
其中 α_k 是动态调整的权重系数。MARBLE 通过追踪每个奖励维度在近期训练中的改进幅度来计算 α_k——如果某个奖励维度的改进停滞,其对应的 α_k 会被上调,从而在梯度更新中获得更多话语权。
伪代码实现
以下是 MARBLE 核心逻辑的简化伪代码,帮助理解其工作流程:
class MARBLE:
def __init__(self, reward_functions, ema_decay=0.99):
self.rewards = reward_functions
self.K = len(reward_functions)
self.ema_decay = ema_decay
# 每个奖励维度的历史改进幅度(EMA)
self.reward_improvement = [0.0] * self.K
def compute_balanced_gradient(self, model, samples, rewards):
# Step 1: 分别计算每个奖励的梯度
gradients = []
for k in range(self.K):
g_k = compute_policy_gradient(model, samples, rewards[k])
gradients.append(g_k)
# Step 2: 计算动态平衡权重
alphas = self._compute_alphas(rewards)
# Step 3: 加权组合梯度
total_gradient = sum(a * g for a, g in zip(alphas, gradients))
return total_gradient
def _compute_alphas(self, rewards):
"""改进不足的维度获得更高权重"""
alphas = []
for k in range(self.K):
# 用 EMA 追踪奖励改进幅度
improvement = rewards[k].mean() - self.reward_improvement[k]
self.reward_improvement[k] = (
self.ema_decay * self.reward_improvement[k]
+ (1 - self.ema_decay) * rewards[k].mean()
)
# 改进越少,权重越高
alphas.append(1.0 / (improvement + 1e-6))
# 归一化
total = sum(alphas)
return [a / total for a in alphas]
💡 提示:上述伪代码是对 MARBLE 核心机制的简化展示。实际论文中的实现涉及更精细的梯度投影和约束处理,以确保训练稳定性。
实验结果:多维奖励的帕累托改进
论文在文本到图像生成任务上进行了系统性实验,将 MARBLE 与以下基线方法进行对比:
| 方法 | 策略 | 计算成本 |
|---|---|---|
| Weighted Sum | 奖励加权求和 | 低 |
| Expert Iteration | 独立训练专家模型再蒸馏 | 高 |
| Scheduled | 阶段性切换优化目标 | 中 |
| MARBLE | 梯度空间多维平衡 | 低-中 |
实验结果揭示了几个重要发现:
第一,MARBLE 有效避免了奖励坍缩。 在传统加权求和方法中,训练后期往往出现某个奖励维度主导优化方向的现象,导致其他维度的性能退化。MARBLE 通过梯度空间的独立保持,显著缓解了这一问题。
第二,MARBLE 在综合评分上取得显著提升。 在同时考虑美学质量、图文一致性和感知质量的综合评估中,MARBLE 相比加权求和基线提升了约 15-25%,且在每个单独维度上都保持了竞争力。
第三,训练稳定性大幅提升。 MARBLE 的训练曲线更加平滑,奖励方差更小,这意味着它能够更可靠地找到多维奖励之间的平衡点,而不是在不同维度之间剧烈震荡。
消融实验的关键发现
论文还进行了细致的消融实验,验证了 MARBLE 各组件的贡献:
- 去掉动态平衡权重(使用固定均等权重):性能下降约 8%,说明动态调整机制对于适应不同训练阶段的需求至关重要
- 去掉梯度分解(使用加权奖励的单一梯度):退化为传统方法,性能下降约 18%
- 调整 EMA 衰减系数:在 0.95-0.99 范围内表现稳定,说明方法对超参数不敏感
工程意义与实践建议
MARBLE 对于从事扩散模型 RL 微调的工程师和研究者具有直接的实用价值。
何时考虑使用 MARBLE
如果你的场景涉及以下情况,MARBLE 值得优先考虑:
- 多维奖励存在冲突:当不同奖励维度之间存在明显的权衡关系时
- 加权求和调参困难:当你发现无论如何调整权重系数,总有一个维度表现不佳时
- 训练不稳定:当 RL 微调过程中奖励波动剧烈、难以收敛时
集成到现有流程
MARBLE 的设计遵循即插即用原则。集成到现有训练流程的主要步骤:
# 原有的加权求和方式
# reward = w1 * r1 + w2 * r2 + w3 * r3
# loss = -reward
# 替换为 MARBLE
marble = MARBLE(
reward_functions=[aesthetic_scorer, alignment_scorer, quality_scorer],
ema_decay=0.99
)
for batch in dataloader:
samples = diffusion_model.sample(batch.prompts)
rewards = [fn(samples) for fn in [r1_scorer, r2_scorer, r3_scorer]]
balanced_grad = marble.compute_balanced_gradient(model, samples, rewards)
model.backward(balanced_grad)
optimizer.step()
潜在的局限性
尽管 MARBLE 展示了令人鼓舞的结果,仍有一些值得注意的局限:
- 奖励函数质量依赖:MARBLE 优化的是给定奖励函数的组合,如果某个奖励函数本身定义不佳,MARBLE 也无法弥补
- 梯度冲突处理:当多个奖励的梯度方向完全相反时,平衡机制的效果会打折扣
- 扩展到大规模奖励组合:论文主要验证了 2-4 个奖励维度的场景,更多维度的组合效果有待进一步验证
总结
MARBLE 为扩散模型 RL 微调中的多维奖励优化提供了一个优雅且实用的解决方案。其核心贡献在于将问题从奖励空间转移到梯度空间,通过保持各维度的独立梯度信号来避免信息丢失和优化冲突。对于需要同时对齐多个目标的生成模型训练场景,MARBLE 是一个值得纳入工具箱的方法。
论文代码和预训练权重预计将在近期开源,届时社区可以进一步验证和拓展这一方法在不同任务和模态上的适用性。