Paper

MARBLE:扩散模型强化学习中的多维奖励平衡新范式

5 min read ·

为什么扩散模型的 RL 对齐如此棘手?

近年来,扩散模型在图像和视频生成领域取得了令人瞩目的进展。然而,当我们试图通过强化学习(RL)来微调扩散模型,使其生成结果更好地符合人类偏好时,一个核心难题浮出水面:如何同时优化多个奖励维度?

举一个直观的例子。假设我们要训练一个文本到图像的生成模型,希望生成的图片同时满足以下要求:

传统做法通常是将这些奖励加权求和:

R_total = w_1 * R_aesthetic + w_2 * R_alignment + w_3 * R_quality

这种看似简洁的方法实际上暗藏危机。多个奖励维度之间往往存在复杂的耦合关系——追求更高的美学评分可能会牺牲图文一致性,而过度优化清晰度可能导致生成结果缺乏多样性。加权求和将这些高维信号压缩为单一标量,不可避免地造成信息丢失和优化冲突。

现有的替代方案包括为每个奖励训练独立的专家模型再进行蒸馏,或者手动设计阶段性训练调度。但这些方法要么计算成本高昂,要么需要大量人工调参,且难以泛化到新的奖励组合。

MARBLE 的核心思路:在梯度空间中保持多维信号

MARBLE(Multi-Aspect Reward Balance)的核心洞察是:与其在奖励空间中进行妥协,不如在梯度空间中保持各维度的独立性。

具体来说,MARBLE 不将多个奖励合并为一个标量,而是在每次训练迭代中分别计算每个奖励维度对应的梯度分量,然后通过一种平衡机制将这些梯度分量组合为最终的更新方向。

这个思路的关键优势在于:每个奖励维度的梯度信号在组合过程中得以保留,而不是被提前压缩。这使得模型能够在不同奖励维度之间找到更好的帕累托最优解,而不是被迫在加权求和的单一方向上优化。

梯度分解

给定 K 个奖励函数 R_1, R_2, ..., R_K,MARBLE 首先分别计算每个奖励对应的策略梯度:

g_k = ∇_θ E[R_k(x)]    (k = 1, 2, ..., K)

其中 θ 是扩散模型的参数,x 是生成的样本。这一步确保了每个奖励维度的梯度方向被独立保留。

动态平衡

接下来,MARBLE 引入一个平衡机制来组合这些梯度分量。核心思想是:给予当前优化不足的奖励维度更高的梯度贡献权重。

g_total = Σ (α_k * g_k)

其中 α_k 是动态调整的权重系数。MARBLE 通过追踪每个奖励维度在近期训练中的改进幅度来计算 α_k——如果某个奖励维度的改进停滞,其对应的 α_k 会被上调,从而在梯度更新中获得更多话语权。

伪代码实现

以下是 MARBLE 核心逻辑的简化伪代码,帮助理解其工作流程:

class MARBLE:
    def __init__(self, reward_functions, ema_decay=0.99):
        self.rewards = reward_functions
        self.K = len(reward_functions)
        self.ema_decay = ema_decay
        # 每个奖励维度的历史改进幅度(EMA)
        self.reward_improvement = [0.0] * self.K

    def compute_balanced_gradient(self, model, samples, rewards):
        # Step 1: 分别计算每个奖励的梯度
        gradients = []
        for k in range(self.K):
            g_k = compute_policy_gradient(model, samples, rewards[k])
            gradients.append(g_k)

        # Step 2: 计算动态平衡权重
        alphas = self._compute_alphas(rewards)

        # Step 3: 加权组合梯度
        total_gradient = sum(a * g for a, g in zip(alphas, gradients))

        return total_gradient

    def _compute_alphas(self, rewards):
        """改进不足的维度获得更高权重"""
        alphas = []
        for k in range(self.K):
            # 用 EMA 追踪奖励改进幅度
            improvement = rewards[k].mean() - self.reward_improvement[k]
            self.reward_improvement[k] = (
                self.ema_decay * self.reward_improvement[k]
                + (1 - self.ema_decay) * rewards[k].mean()
            )
            # 改进越少,权重越高
            alphas.append(1.0 / (improvement + 1e-6))

        # 归一化
        total = sum(alphas)
        return [a / total for a in alphas]

💡 提示:上述伪代码是对 MARBLE 核心机制的简化展示。实际论文中的实现涉及更精细的梯度投影和约束处理,以确保训练稳定性。

实验结果:多维奖励的帕累托改进

论文在文本到图像生成任务上进行了系统性实验,将 MARBLE 与以下基线方法进行对比:

方法策略计算成本
Weighted Sum奖励加权求和
Expert Iteration独立训练专家模型再蒸馏
Scheduled阶段性切换优化目标
MARBLE梯度空间多维平衡低-中

实验结果揭示了几个重要发现:

第一,MARBLE 有效避免了奖励坍缩。 在传统加权求和方法中,训练后期往往出现某个奖励维度主导优化方向的现象,导致其他维度的性能退化。MARBLE 通过梯度空间的独立保持,显著缓解了这一问题。

第二,MARBLE 在综合评分上取得显著提升。 在同时考虑美学质量、图文一致性和感知质量的综合评估中,MARBLE 相比加权求和基线提升了约 15-25%,且在每个单独维度上都保持了竞争力。

第三,训练稳定性大幅提升。 MARBLE 的训练曲线更加平滑,奖励方差更小,这意味着它能够更可靠地找到多维奖励之间的平衡点,而不是在不同维度之间剧烈震荡。

消融实验的关键发现

论文还进行了细致的消融实验,验证了 MARBLE 各组件的贡献:

工程意义与实践建议

MARBLE 对于从事扩散模型 RL 微调的工程师和研究者具有直接的实用价值。

何时考虑使用 MARBLE

如果你的场景涉及以下情况,MARBLE 值得优先考虑:

  1. 多维奖励存在冲突:当不同奖励维度之间存在明显的权衡关系时
  2. 加权求和调参困难:当你发现无论如何调整权重系数,总有一个维度表现不佳时
  3. 训练不稳定:当 RL 微调过程中奖励波动剧烈、难以收敛时

集成到现有流程

MARBLE 的设计遵循即插即用原则。集成到现有训练流程的主要步骤:

# 原有的加权求和方式
# reward = w1 * r1 + w2 * r2 + w3 * r3
# loss = -reward

# 替换为 MARBLE
marble = MARBLE(
    reward_functions=[aesthetic_scorer, alignment_scorer, quality_scorer],
    ema_decay=0.99
)

for batch in dataloader:
    samples = diffusion_model.sample(batch.prompts)
    rewards = [fn(samples) for fn in [r1_scorer, r2_scorer, r3_scorer]]

    balanced_grad = marble.compute_balanced_gradient(model, samples, rewards)
    model.backward(balanced_grad)
    optimizer.step()

潜在的局限性

尽管 MARBLE 展示了令人鼓舞的结果,仍有一些值得注意的局限:

总结

MARBLE 为扩散模型 RL 微调中的多维奖励优化提供了一个优雅且实用的解决方案。其核心贡献在于将问题从奖励空间转移到梯度空间,通过保持各维度的独立梯度信号来避免信息丢失和优化冲突。对于需要同时对齐多个目标的生成模型训练场景,MARBLE 是一个值得纳入工具箱的方法。

论文代码和预训练权重预计将在近期开源,届时社区可以进一步验证和拓展这一方法在不同任务和模态上的适用性。

Frequently asked questions

MARBLE 与传统的奖励加权求和方法有什么本质区别?
传统方法将多个奖励加权求和为单一标量,容易导致奖励维度间的相互干扰和坍缩。MARBLE 在梯度空间中保持每个奖励维度的独立监督信号,通过梯度分解和平衡机制实现多维奖励的同步改进,避免了信息丢失。
MARBLE 框架适用于哪些应用场景?
MARBLE 主要适用于需要多维对齐的扩散模型场景,包括高质量图像生成(同时优化美学评分、图文一致性、清晰度)、视频生成(运动流畅性、语义一致性、视觉质量)以及任何需要平衡多个奖励信号的生成任务。
MARBLE 的梯度平衡机制是如何实现的?
MARBLE 将总训练梯度分解为各奖励维度的独立分量,通过动态调整各分量的贡献权重,确保所有奖励维度在训练过程中都能获得足够的优化信号。核心思想是避免某个强奖励维度主导梯度方向,从而抑制其他维度的学习。
MARBLE 在实验中相比基线方法提升有多大?
论文在文本到图像生成任务上进行了广泛实验,MARBLE 在多维奖励综合评分上相比传统加权求和方法提升了约 15-25%,同时在单个奖励维度上也保持了竞争力,有效避免了奖励间的跷跷板效应。
MARBLE 的计算开销如何?能否直接集成到现有训练流程?
MARBLE 的额外计算开销主要来自梯度分解步骤,相比基础 RL 微调约增加 10-20% 的训练时间。框架设计为即插即用模块,可直接替换现有的奖励聚合策略,无需修改扩散模型架构或 RL 算法本身。