1. 引述
PPO 网络由于有 Value 网络的存在,结构相对复杂,并且不稳定。因为 Value 网络和大语言模型共享参数。
GRPO 相对于 PPO 来说,最大的创新之处在于其不需要 Value 网络来对大模型生成的回复打分。但是,如果没有 Value 网络的打分,那么怎么知道大模型生成的回复是好是坏呢?
2. GRPO
2.1 组(Group)的概念
PPO 对一个问题输出采样一个回复输出,而 GRPO 则是对一个问题输入采样多个回复输出。随后,使用奖励模型对这多个回复依次打分。
这些得分肯定有高有低,就把这些得分分成高低两组。于是,得分低的那一组回复就需要抑制输出,得分高的那一组就需要鼓励输出。
而分组的方法是归一化:
公式中的 代表奖励。这里所有的奖励,都是对同一个问题的多个回复采样得到的。
很显然,最后会得到一个正态分布。而得分高的一组(绿色)就是好的回复;得分低的一组(红色)就是不好的回复。
通过组的概念,GRPO 成功绕过了 PPO 的 Value 网络的设计。
2.2 优化目标
回忆一下 PPO 的优化目标:(ch.9)
这个目标函数中, 是优势函数值,也就是当前的真实 Q 值减去预测 Q 值(或 V 值)。这里的预测值是由 Value 网络提供。
但是由于 GRPO 的 Value 网络被替换成了两组得分 ,于是就把目标中的优势函数值
替换成分组得分
:
当然,别忘了加上 KL 散度项,这一项的目的和公式和之前 PPO 时一样:
3. DAPO
3.1 介绍
在 GRPO 推出后仅仅几个月的时间,清华大学就推出了 GRPO 的 2.0 版本,叫做 DAPO。其优化函数如下:
3.2 Clip-Higher
在 GRPO 和 PPO 方法在微调模型的时候,由于有这一项:
这使得模型虽然能稳定训练,但同时也限制了模型的探索能力。因为限制了模型步伐不能太大,导致模型会陷入一种名为熵坍缩的情况,也就是模型过于自信,对同一个问题不会再输出其他略有不同的解释。
比如说,当某个 token(或动作)在旧策略中概率极低时,比如只有 1%,即使它在当前梯度中被评为正向(优势值 A>0),其最大增长也只能从 1% 提升到 ,增幅非常有限,难以显著提高探索概率。
因此,DAPO 通过调整剪切的参数来解决这个问题:
原本 GRPO 设置的超参数 是 0.2,DAPO 仅仅调整上界
到 0.28,下界
维持 0.2 不变。
3.3 动态采样
GRPO 对一个问题给出多个回答,而奖励模型对这些回答会打分,通常是回答正确打 1 分,回答错误打 0 分。然而,当模型的效果越来越好,会出现模型输出的多个回答全对的情况,此时 GRPO 的 值就为 0,使得梯度消失。
DAPO 通过动态采样,如果一个问题的所有回复都正确,那么就换一个问题,或者直到采样到错误的回复为止。
这个公式代表正确回答的数量必须大于 0 同时小于所有回复的数量(G)
3.4 Token 级损失
在 GRPO 中,所有生成的回复在做梯度下降时,权重相等(因为没有额外赋予权重)。这就使得高质量的长句子回答没有给足奖励,低质量的长句回答没有给足惩罚
举个例子,大模型生成了两条回复,一条是长句子,有 100 个 token;一条是短句子,有 10 个 token。这两个句子的质量都很高,于是模型就要从这两个句子中学习。
在 GRPO 的方式下,由于是 Sample-Level 的梯度,不考虑 token,也就是每个句子分配相同的权重。但是长句子应该给更多的奖励,短句子应该给更少的奖励。
同样的,如果 100 个 token 的句子是低质量的,那么给的惩罚也得比 10 个 token 的低质量句子大。
为了解决这个问题,DAPO 采用 Token-Level,也就是说是每个 token 一个梯度。也就是说,对于 100 个 token 的回复对比 10 个 token 的回复,在梯度下降的过程中,长句子就应该占有更多的梯度:
这里最左侧的这个符号:
就是 DAPO 和 GRPO 不同的一点。
3.5 长输出惩罚
在 LLM 生成回复的时候,有时候回复过长,此时做法一般是截断过长部分。但是这样可能会导致模型本来输出正确,一截断之后就错误了。于是,奖励模型就告诉 LLM 你的答案是错的,但是实际上 LLM 是对的,只是答案不完整,但是 LLM 不知道,所以就会使得 LLM 往错误的方向训练。
为了解决这个问题,DAPO 直接让生成过长的回复的不参与模型的微调。同时,为模型输出长短通过设置奖励来做出限制:
稍微解释一下这个公式:
是模型输出长度上限,超过这个长度就直接给到最大惩罚;
是理想长度,低于这个长度不予惩罚;
是模型输出长度的一个类似 “缓存空间” 的变量。意思就是说,当模型输出的长度从 理想长度
开始逼近最大长度
时,给予从 0 到 -1 的线性惩罚
比如 ,
,那么模型输出 90 个 token 以内视作合理,一旦超过 90 个就开始给惩罚,比如 91 个 token 就给 -0.1 的惩罚。