Fast Inference from Transformers via Speculative Decoding
论文地址:https://arxiv.org/pdf/2211.17192
speculative sampling
为了从分布 p ( x ) p(x) p(x) 中采样,我们实际上是从分布 q ( x ) q(x) q(x) 中采样 x x x,如果 q ( x ) ≤ p ( x ) q(x) \leq p(x) q(x)≤p(x),则保留该样本;如果 q ( x ) > p ( x ) q(x) > p(x) q(x)>p(x),则以概率 1 − p ( x ) q ( x ) 1 - \frac{p(x)}{q(x)} 1−q(x)p(x) 拒绝该样本,并重新从调整后的分布 p ′ ( x ) = norm ( max ( 0 , p ( x ) − q ( x ) ) ) p'(x) = \text{norm}(\max(0, p(x)-q(x))) p′(x)=norm(max(0,p(x)−q(x))) 中采样。对于任何分布 p ( x ) p(x) p(x) 和 q ( x ) q(x) q(x),以及以此方式采样的 x x x,确实有 x ∼ p ( x ) x \sim p(x) x∼p(x)。
给定通过在条件前缀上运行 M q M_q Mq 获得的分布 q ( x ) q(x) q(x),我们可以采样一个标记 x 1 ∼ q ( x ) x_1 \sim q(x) x1∼q(x)。然后,我们通过在前缀上运行 M p M_p Mp 来计算分布 p ( x ) p(x) p(x),同时并行地推测性地计算下一个标记 x 2 x_2 x2 的分布,即在前缀上追加 x 1 x_1 x1 后运行 M p M_p Mp。一旦两项计算都完成,我们就按上述方式处理:如果 x 1 x_1 x1 被拒绝,我们丢弃 x 2 x_2 x2 的计算,并从调整后的分布中重新采样 x 1 x_1 x1;如果 x 1 x_1 x1 被接受,我们就保留两个标记。算法 1 将这一想法推广为一次采样 1 到 γ + 1 \gamma + 1 γ+1 个标记。
分析
有几个证明需要注意一下:
单次算法期望能生成的token
单次算法期望能生成的token数量服从几何分布,但是求和项是有限制的,这里推导下
接受率β的定义
设目标模型分布为p(x)
,草稿模型分布为q(x)
。草稿模型生成的单个token被目标模型接受的概率为:
β = ∑ x min ( q ( x ) , p ( x ) ) \beta = \sum_x \min\left(q(x), p(x)\right) β=x∑min(q(x),p(x))
- 拒绝率α的定义
α = 1 − β = 1 − ∑ x min ( p ( x ) , q ( x ) ) x \alpha = 1 - \beta = 1 - \sum_x \min(p(x), q(x)) x α=1−β=1−x∑min(p(x),q(x))x
假设每个token的接受事件独立且同分布(i.i.d.),草稿模型一次生成
K
个token:首次拒绝发生在位置
r
的概率为:P ( r ) = ( 1 − β ) β r − 1 ( 1 ≤ r ≤ K ) P(r) = (1-\beta) \beta^{r-1} \quad (1 \leq r \leq K) P(r)=(1−β)βr−1(1≤r≤K)
所有token均被接受 的概率为: β K \beta^K βK
综上期望能生成的token数量为:
γ = ∑ r = 1 K r ⋅ P ( r ) ⏟ 拒绝前生成的token + K ⋅ β K ⏟ 全接受时生成K个token \gamma = \underbrace{\sum_{r=1}^K r \cdot P(r)}_{\text{拒绝前生成的token}} + \underbrace{K \cdot \beta^K}_{\text{全接受时生成K个token}} γ=拒绝前生成的token r=1∑Kr⋅P(r)+全接受时生成K个token K⋅βK
代入 P ( r ) P(r) P(r) 后展开:
γ = ∑ r = 1 K r ⋅ ( 1 − β ) β r − 1 + K β K \gamma = \sum_{r=1}^K r \cdot (1-\beta) \beta^{r-1} + K \beta^K γ=r=1∑Kr⋅(1−β)βr−1+KβK
- 几何级数求和
几何级数求和公式为:
对 ∑ r = 1 K r β r − 1 \sum_{r=1}^K r \beta^{r-1} ∑r=1Krβr−1 求和处理:
- 令 S = ∑ r = 1 K β r − 1 S = \sum_{r=1}^K \beta^{r-1} S=∑r=1Kβr−1:
S = 1 + β + β 2 + ⋯ + β K − 1 = 1 − β K 1 − β S = 1 + \beta + \beta^2 + \cdots + \beta^{K-1} = \frac{1-\beta^K}{1-\beta} S=1+β+β2+⋯+βK−1=1−β1−βK
- 对 S S S 求导:
∑ r = 1 K r β r − 1 = d d β ( ∑ r = 0 K β r ) = d d β ( 1 − β K + 1 1 − β ) = 1 − ( K + 1 ) β K + K β K + 1 ( 1 − β ) 2 \sum_{r=1}^K r \beta^{r-1} = \frac{d}{d\beta} \left( \sum_{r=0}^K \beta^r \right) = \frac{d}{d\beta} \left( \frac{1-\beta^{K+1}}{1-\beta} \right) = \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{(1-\beta)^2} ∑r=1Krβr−1=dβd(∑r=0Kβr)=dβd(1−β1−βK+1)=(1−β)21−(K+1)βK+KβK+1
- 代入γ表达式:
γ = ( 1 − β ) ⋅ 1 − ( K + 1 ) β K + K β K + 1 ( 1 − β ) 2 + K β K = 1 − ( K + 1 ) β K + K β K + 1 1 − β + K β K \gamma = (1-\beta) \cdot \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{(1-\beta)^2} + K\beta^K = \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{1-\beta} + K\beta^K γ=(1−β)⋅(1−β)21−(K+1)βK+KβK+1+KβK=1−β1−(K+1)βK+KβK+1+KβK
- 化简:
γ = 1 − β K 1 − β \gamma = \frac{1 - \beta^K}{1-\beta} γ=1−β1−βK
物理意义:
- 当 K → ∞ K \to \infty K→∞时, γ → 1 1 − β = 1 α \gamma \to \frac{1}{1-\beta} = \frac{1}{\alpha} γ→1−β1=α1(理想无限长草稿)。
- 例如 β \beta β = 0.8` 时, γ max = 5 \gamma_{\text{max}} = 5 γmax=5,即平均每次生成5个token。
得证
Walltime的时间优化
定理 3.8:算法 1 在总运行时间上的预期改进因子为
‘ 1 − α γ + 1 ( 1 − α ) ( γ c + 1 ) ‘ `\frac{1 - \alpha^{\gamma + 1}}{(1 - \alpha)(\gamma c + 1)}` ‘(1−α)(γc+1)1−αγ+1‘
证明:
记运行目标模型 M p M_p Mp 单步的成本为 T T T。
算法 1 的单次运行成本为 T c γ + T Tc\gamma + T Tcγ+T(其中 c γ T c\gamma T cγT用于运行近似模型 M q M_q Mq γ \gamma γ 次, T T T 用于运行 M p M_p Mp 一次)。
根据单次算法期望能生成的token算法推导,单次运行平均生成 token 数量为 1 − α γ + 1 1 − α \dfrac{1 - \alpha^{\gamma + 1}}{1 - \alpha} 1−α1−αγ+1。
因此,使用算法 1 生成单个 token 的总体预期成本为:
( c γ + 1 ) ( 1 − α ) 1 − α γ + 1 T ‘ \frac{(c\gamma + 1)(1 - \alpha)}{1 - \alpha^{\gamma + 1}}T` 1−αγ+1(cγ+1)(1−α)T‘
由于标准解码算法生成单个 token 的成本为 T
,
比较可得上述改进因子。∎
(注:符号 “∎” 表示证明结束)
关键术语说明:
英文术语 | 中文翻译 | 符号 | 含义 |
---|---|---|---|
walltime | 总运行时间 | - | 算法从启动到结束的时钟时间 |
expected improvement factor | 预期改进因子 | - | 优化后时间开销的缩减比例 |
cost per step | 单步成本 | T T T | 目标模型 M p M_p Mp 推理一个 token 的时间 |
approximation model | 近似模型 | M q M_q Mq | 快速但低精度的草稿模型 |
tokens | 标记(Token) | - | 模型生成的基本文本单位 |
rejection rate | 拒绝率 | α \alpha α | 草稿模型 M q M_q Mq 的 token 被目标模型 M p M_p Mp 拒绝的概率 |
γ \gamma γ | 生成长度 | γ \gamma γ | 草稿模型单次运行的 token 生成数 |
cost ratio | 成本比 | c c c | M q M_q Mq 与 M p M_p Mp 的单步时间比值( 0 < c < 1 0 < c < 1 0<c<1) |
公式解析:
- 改进因子
1 − α γ + 1 ( 1 − α ) ( γ c + 1 ) \frac{1 - \alpha^{\gamma + 1}}{(1 - \alpha)(\gamma c + 1)} (1−α)(γc+1)1−αγ+1
- 分子 1 − α γ + 1 1 - \alpha^{\gamma+1} 1−αγ+1:草稿模型连续生成
\gamma
个 token 均未被拒绝的概率补偿 - 分母 ( 1 − α ) (1-\alpha) (1−α):单 token 接受率, γ c + 1 \gamma c + 1 γc+1:草稿+验证的总时间成本
该值 >1 时表示加速,值越大加速效果越显著
- 单 token 成本公式
( c γ + 1 ) ( 1 − α ) 1 − α γ + 1 T \frac{(c\gamma+1)(1-\alpha)}{1-\alpha^{\gamma+1}}T 1−αγ+1(cγ+1)(1−α)T
- 分子 ( c γ + 1 ) ( 1 − α ) T (c\gamma+1)(1-\alpha)T (cγ+1)(1−α)T:草稿生成+验证的实际计算量
- 分母 1 − α γ + 1 1-\alpha^{\gamma+1} 1−αγ+1:有效 token 产出的概率加权
操作数计算
操作数的计算量也是类似的,直接贴结论了
( 1 − α ) ( γ c ^ + γ + 1 ) 1 − α γ + 1 \frac{(1-\alpha)(\gamma \hat{c}+\gamma+1)}{1-\alpha^{\gamma+1}} 1−αγ+1(1−α)(γc^+γ+1)
采样和原分布的等价性证明
参考https://arxiv.org/pdf/2302.01318
其中需要一步代换证明下面两个公式等价:
原始公式
第一个公式:
= 1 − ∑ x ′ min ( p ( x ′ ) , q ( x ′ ) ) =1-\sum_{x^{\prime}}\min\left(p\left(x^{\prime}\right),q\left(x^{\prime}\right)\right) =1−x′∑min(p(x′),q(x′))
第二个公式:
= ∑ x ′ max ( 0 , q ( x ′ ) − p ( x ′ ) ) =\sum_{x^{\prime}}\max\left(0,q\left(x^{\prime}\right)-p\left(x^{\prime}\right)\right) =x′∑max(0,q(x′)−p(x′))
推导步骤
步骤 1: 应用 min 函数的恒等式
对于任何两个实数 a a a 和 b b b,都存在以下恒等关系:
min ( a , b ) = a − max ( 0 , a − b ) \min(a,b) = a - \max(0, a - b) min(a,b)=a−max(0,a−b)
令 b = p ( x ′ ) b = p(x') b=p(x′), a = q ( x ′ ) a = q(x') a=q(x′),得到:
min ( p ( x ′ ) , q ( x ′ ) ) = q ( x ′ ) − max ( 0 , q ( x ′ ) − p ( x ′ ) ) \min(p(x'),q(x')) = q(x') - \max(0, q(x') - p(x')) min(p(x′),q(x′))=q(x′)−max(0,q(x′)−p(x′))
步骤 2: 代入第一个公式
将恒等式代入原始公式:
1 − ∑ x ′ min ( p ( x ′ ) , q ( x ′ ) ) = 1 − ∑ x ′ [ q ( x ′ ) − max ( 0 , q ( x ′ ) − p ( x ′ ) ) ] \begin{aligned} &1 - \sum_{x^{\prime}} \min(p(x'),q(x')) \\ &= 1 - \sum_{x^{\prime}} \left[ q(x') - \max(0, q(x') - p(x')) \right] \end{aligned} 1−x′∑min(p(x′),q(x′))=1−x′∑[q(x′)−max(0,q(x′)−p(x′))]
步骤 3: 拆分求和运算
将求和符号分配到表达式内部:
= 1 − [ ∑ x ′ p ( x ′ ) − ∑ x ′ max ( 0 , p ( x ′ ) − q ( x ′ ) ) ] = 1 - \left[ \sum_{x^{\prime}} p(x') - \sum_{x^{\prime}} \max(0, p(x') - q(x')) \right] =1−[x′∑p(x′)−x′∑max(0,p(x′)−q(x′))]
= 1 − ∑ x ′ q ( x ′ ) + ∑ x ′ max ( 0 , q ( x ′ ) − p ( x ′ ) ) = 1 - \sum_{x^{\prime}} q(x') + \sum_{x^{\prime}} \max(0, q(x') - p(x')) =1−x′∑q(x′)+x′∑max(0,q(x′)−p(x′))
步骤 4: 应用概率分布性质
因为 p p p 和 q q q 都是概率分布函数,满足:
∑ x ′ p ( x ′ ) = 1 和 ∑ x ′ q ( x ′ ) = 1 \sum_{x^{\prime}} p(x') = 1 \quad \text{和} \quad \sum_{x^{\prime}} q(x') = 1 x′∑p(x′)=1和x′∑q(x′)=1
代入表达式:
= 1 − 1 + ∑ x ′ max ( 0 , q ( x ′ ) − p ( x ′ ) ) = 1 - 1 + \sum_{x^{\prime}} \max(0, q(x') - p(x')) =1−1+x′∑max(0,q(x′)−p(x′))
= ∑ x ′ max ( 0 , q ( x ′ ) − p ( x ′ ) ) = \sum_{x^{\prime}} \max(0, q(x') - p(x')) =x′∑max(0,q(x′)−p(x′))
得证