1 Preface
In this installment we look at Flow-GRPO, which brings the RL-based GRPO algorithm to flow matching and achieves major improvements on multiple evaluation metrics. Let's dive in.
Video: Flow-GRPO: Training Flow Matching Models via Online RL
Reference paper: Flow-GRPO: Training Flow Matching Models via Online RL
Reference code: Flow-GRPO: Training Flow Matching Models via Online RL
2 Background
Flow matching already achieves impressive results, and models built on it, such as SD3.5, generate images of high quality. Compared with the strongest models, however, SD3.5 still has room to improve on key metrics; it clearly lags behind GPT-4o, for example.
In NLP, introducing RL (reinforcement learning)-based methods has proven effective: approaches such as DPO and GRPO steer model outputs toward human preferences.
Beyond NLP, reinforcement learning is increasingly successful in computer vision as well, and Flow-GRPO is exactly that: GRPO applied to flow matching.
3 Flow matching
First, a quick review of flow matching. Suppose $x_0\sim X_0$ is a real data sample and $x_1\sim X_1$ is a noise sample. Taking Rectified Flow as an example, the state at any time $t$ can be written as
$$x_t = (1-t)x_0+tx_1$$
where $t\in [0,1]$. We can train an approximate vector field $v_\theta(x_t,t)$ by minimizing
$$\mathcal{L}(\theta)=\mathbb{E}_{t,x_0,x_1}\left[ \Vert v - v_\theta(x_t,t) \Vert^2 \right]$$
where the target vector field is $v=x_1-x_0$.
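To make this concrete, here is a minimal PyTorch sketch of this training loss; the network `v_theta` and its call signature `v_theta(x_t, t)` are assumptions for illustration, not the actual SD3.5 interface.

```python
import torch

def rectified_flow_loss(v_theta, x0, x1):
    """Flow matching loss: regress v_theta(x_t, t) onto v = x1 - x0.

    v_theta : callable, v_theta(x_t, t) -> predicted vector field (assumed interface)
    x0      : real data samples,      shape (B, ...)
    x1      : Gaussian noise samples, shape (B, ...)
    """
    b = x0.shape[0]
    t = torch.rand(b, device=x0.device)              # t ~ U[0, 1]
    t_b = t.view(b, *([1] * (x0.dim() - 1)))         # reshape for broadcasting
    x_t = (1 - t_b) * x0 + t_b * x1                  # linear interpolation path
    v_target = x1 - x0                               # target vector field
    return ((v_theta(x_t, t) - v_target) ** 2).mean()
```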
4 Method
The paper takes SD3.5 as its running example and applies Flow-GRPO to text-to-image (T2I) generation. As readers familiar with GRPO will know, before GRPO can be used to train a flow model, one problem with the ODE formulation must be solved first:
- An ODE sampler is deterministic and cannot produce multiple distinct samples under the same condition, so the ODE must be converted into an SDE.
4.1 GRPO
The general RL objective is
$$\max_\theta \mathbb{E}_{(s_0,a_0,\cdots,s_T,a_T)\sim \pi_\theta}\left[ \sum_{t=0}^T\left( R(s_t,a_t)-\beta D_{KL}(\pi_\theta(\cdot \mid s_t)\,\|\,\pi_{ref}(\cdot\mid s_t)) \right) \right]$$
The denoising process can be cast as an MDP. Given a prompt $c$, the flow model produces a group of images $\{ x_0^i \}_{i=1}^G$ along with their sampling trajectories $\{ (x_T^i,x_{T-1}^i,\cdots,x_0^i) \}_{i=1}^G$. The advantage of the $i$-th image is obtained by group normalization:
$$\hat A_t^i=\frac{R(x_0^i,c)-\text{mean}(\{R(x_0^i,c)\}_{i=1}^G)}{\text{std}(\{R(x_0^i,c)\}_{i=1}^G)}$$
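In code, the group normalization is essentially a one-liner; a minimal sketch (the `eps` guard against a zero std is a practical addition, not from the paper):

```python
import torch

def group_advantages(rewards, eps=1e-8):
    """Group-normalized advantages for one prompt: (R - mean) / std.

    rewards : tensor of shape (G,), one reward per image in the group
    eps     : guards against a zero std when all rewards are equal
    """
    return (rewards - rewards.mean()) / (rewards.std() + eps)
```

Note that the same advantage $\hat A^i$ is reused at every timestep $t$ of trajectory $i$, since the reward is computed only on the final image $x_0^i$.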
GRPO then maximizes the objective
$$\mathcal{J}_{\text{Flow-GRPO}}(\theta)=\mathbb{E}_{c\sim \mathcal{C},\,\{x^i\}_{i=1}^G\sim \pi_{\theta_{\text{old}}}(\cdot\mid c)}\,f(r,\hat A,\theta,\varepsilon,\beta)$$
where
$$f(r,\hat A,\theta,\varepsilon,\beta) = \frac{1}{G}\sum_{i=1}^G\frac{1}{T}\sum_{t=0}^{T-1}\left( \min\left( r_t^i(\theta)\hat A_t^i,\ \text{clip}\left(r_t^i(\theta),1-\varepsilon,1+\varepsilon\right)\hat A_t^i \right) - \beta D_{KL}(\pi_\theta\,\|\,\pi_{ref})\right),\\ \text{and}\quad r_t^{i}(\theta)=\frac{p_\theta(x_{t-1}^i\mid x_t^i,c)}{p_{\theta_{old}}(x_{t-1}^i\mid x_t^i,c)}$$
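Below is a hedged sketch of this clipped surrogate loss; the tensor shapes, the broadcasting of per-image advantages over timesteps, and the default `clip_eps`/`beta` values are illustrative assumptions rather than the paper's exact hyperparameters.

```python
import torch

def flow_grpo_loss(logp_new, logp_old, advantages, kl, clip_eps=0.2, beta=0.01):
    """Negated clipped surrogate objective of Flow-GRPO (for gradient descent).

    logp_new   : log pi_theta(x_{t-1} | x_t, c),     shape (G, T)
    logp_old   : same log-prob under the old policy, shape (G, T)
    advantages : per-image advantages A^i,           shape (G,)
    kl         : per-step KL to the reference model, shape (G, T)
    """
    ratio = torch.exp(logp_new - logp_old)                 # r_t^i(theta)
    adv = advantages[:, None]                              # shared across timesteps
    surrogate = torch.min(ratio * adv,
                          torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * adv)
    # mean over G and T implements the (1/G)(1/T) double sum
    return -(surrogate - beta * kl).mean()
```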
4.2 From ODE to SDE
As the equations above show, both the advantage computation and the objective rely on stochastic sampling to obtain different trajectories. A deterministic ODE-based denoising process clearly cannot provide this, so we need to convert the denoising process from an ODE into an SDE, which introduces the required randomness.
So how do we turn the ODE into an SDE? We can derive the following equation (proved in Section 7):
$$dx_t = \left[ v_t(x_t) + \frac{\sigma_t^2}{2t}\left(x_t+(1-t)v_t(x_t)\right) \right]dt + \sigma_t\,d\bar w\tag{1}$$
Here $d\bar w$ denotes a Wiener-process increment, and $\sigma_t$ controls the level of stochasticity.
Note that Eq. (1) depends only on the vector field $v$, so we can simply plug in the learned approximation $v_\theta$. Any numerical solver, such as the Euler–Maruyama method, can then generate sampling trajectories. The denoising update is
$$x_{t+\Delta t} = x_t + \left[ v_{\theta}(x_t,t) + \frac{\sigma_t^2}{2t}\left(x_t + (1 - t)v_\theta(x_t,t)\right) \right]\Delta t + \sigma_t\sqrt{\Delta t}\,\varepsilon\tag{2}$$
where $\varepsilon \sim \mathcal{N}(0,I)$, $\sigma_t = a\sqrt{\frac{t}{1 - t}}$, and $a$ is a hyperparameter that controls the noise level.
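Below is a minimal Euler–Maruyama sampler implementing Eq. (2), assuming denoising integrates backward from $t\approx 1$ (noise) to $t\approx 0$ (data) with a negative $\Delta t$. The value of `a` and the endpoint clipping (which avoids the singularities of $\sigma_t$ at $t=1$ and of $1/t$ at $t=0$) are illustrative choices of this sketch, not settings from the paper.

```python
import torch

def sigma_t(t, a=0.7):
    """Noise schedule sigma_t = a * sqrt(t / (1 - t)); the value of `a` is illustrative."""
    return a * (t / (1.0 - t)) ** 0.5

@torch.no_grad()
def sde_step(v_theta, x, t, dt):
    """One Euler-Maruyama update of Eq. (2); dt < 0 when denoising (t runs 1 -> 0)."""
    tv = torch.full((x.shape[0],), t, device=x.device)
    v = v_theta(x, tv)
    s = sigma_t(t)
    drift = v + (s ** 2) / (2.0 * t) * (x + (1.0 - t) * v)
    return x + drift * dt + s * abs(dt) ** 0.5 * torch.randn_like(x)

@torch.no_grad()
def sde_denoise(v_theta, x, num_steps=10, t_eps=1e-3):
    """Integrate from t ~ 1 (noise) down to t ~ 0 (data), endpoints clipped."""
    ts = torch.linspace(1.0 - t_eps, t_eps, num_steps + 1)
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        x = sde_step(v_theta, x, t_cur.item(), (t_next - t_cur).item())
    return x
```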
By the properties of the Gaussian distribution, Eq. (2) means the policy $\pi_\theta(x_{t-1}\mid x_t,c)$ is Gaussian, so the KL divergence has a closed form:
$$D_{KL}(\pi_\theta\,\|\,\pi_{ref})=\frac{\Vert \bar x_{t+\Delta t,\theta} - \bar x_{t + \Delta t,ref} \Vert^2}{2\sigma_t^2\Delta t} = \frac{\Delta t}{2}\left( \frac{\sigma_t(1-t)}{2t} +\frac{1}{\sigma_t} \right)^2\Vert v_\theta(x_t,t) - v_{ref}(x_t,t) \Vert^2\tag{3}$$
This follows directly from substituting the two Gaussian means into the KL-divergence formula for Gaussians with equal variance.
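A sketch of this per-step KL term; tensor shapes and the reduction over feature dimensions are assumptions:

```python
def kl_per_step_term(v_theta_out, v_ref_out, t, s, dt):
    """Closed-form per-step KL of Eq. (3) between current and reference policies.

    Both policies are Gaussians with identical variance sigma_t^2 * dt, so the KL
    reduces to a scaled squared distance between the two predicted velocities.

    v_theta_out, v_ref_out : velocity predictions, shape (B, ...)
    t, s, dt               : scalar time, sigma_t, and step size |dt|
    """
    coef = 0.5 * dt * (s * (1.0 - t) / (2.0 * t) + 1.0 / s) ** 2
    diff = (v_theta_out - v_ref_out) ** 2
    return coef * diff.flatten(1).sum(dim=1)   # one KL value per sample
```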
5 Denoising Reduction
Generating high-quality images with flow matching usually requires many denoising steps, which makes data collection for RL training very expensive.
The paper finds that RL training does not actually need many sampling steps, while keeping the original step count at inference still yields high-quality samples.
Concretely, for SD3.5 the sampling steps are set to T=10 during RL training, while inference keeps the SD3.5 default of T=40.
6 Model Diagram
The training pipeline is illustrated below:
First, sample $G$ Gaussian white-noise latents $s_0$ (five in the figure). Conditioning on the prompt "A photo of four cups", run the SDE numerical solver for $T=10$ steps to obtain $s_T$. Each $s_T$ is fed to the reward model, yielding rewards $R^1,R^2,R^3,R^4,\cdots,R^G$. These rewards are turned into advantages $\hat A^1,\hat A^2,\hat A^3,\hat A^4,\cdots,\hat A^G$ via the advantage formula above, and finally everything goes into the Flow-GRPO loss.
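Putting the pieces together, a hypothetical one-iteration sketch of this pipeline might look as follows. Every `*_fn` argument is an assumed interface standing in for a real component (they are not the actual repo API), while `group_advantages` and `flow_grpo_loss` are the sketches defined earlier.

```python
def flow_grpo_step(rollout_fn, reward_fn, logp_fn, kl_fn, optimizer, prompt, G=5, T=10):
    """One hypothetical Flow-GRPO iteration following the figure above.

    rollout_fn(prompt, G, T)    -> (images, trajs): G SDE rollouts (old policy)
    reward_fn(images, prompt)   -> rewards, shape (G,)
    logp_fn(trajs, prompt, old) -> per-step log pi(x_{t-1}|x_t, c), shape (G, T)
    kl_fn(trajs, prompt)        -> per-step KL to the reference, shape (G, T)
    """
    # 1) Sample G noise latents and roll out T=10 SDE denoising steps
    images, trajs = rollout_fn(prompt, G, T)
    # 2) Reward the final images and group-normalize into advantages
    rewards = reward_fn(images, prompt)
    adv = group_advantages(rewards)
    # 3) Ratios need log-probs under the current and old policies, plus the KL term
    logp_new = logp_fn(trajs, prompt, old=False)
    logp_old = logp_fn(trajs, prompt, old=True)
    kl = kl_fn(trajs, prompt)
    # 4) Clipped surrogate loss (defined earlier), then a gradient step
    loss = flow_grpo_loss(logp_new, logp_old, adv, kl)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```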
7 Mathematical Proof
7.1 Proof of Eq. (1)
To convert the ODE into a corresponding SDE, we start from the ODE itself:
$$dx_t = v_t\,dt\tag{4}$$
Following the SDE framework discussed previously, the corresponding stochastic equation is
$$dx_t = f_{\text{SDE}}(x_t,t)\,dt +\sigma_t\,dw\tag{5}$$
We need to find the relationship between $f_{\text{SDE}}$ and $v_t$.
As mentioned in the Flow Matching post, both Eq. (4) and Eq. (5) have a corresponding continuity equation describing the probability density path $p_t$. For Eq. (5) it is the Fokker–Planck (FP) equation (for a proof, see the earlier post on the Fokker–Planck equation):
$$\partial_t p_t(x) = -\nabla \cdot [f_{\text{SDE}}(x_t,t)p_t(x)] + \frac{1}{2}\nabla^2[\sigma_t^2 p_t(x)]\tag{6}$$
while the continuity equation for Eq. (4) is
$$\partial_t p_t(x) = -\nabla \cdot [v_t(x_t,t)p_t(x)]\tag{7}$$
When $p_t$ and $v_t$ satisfy Eq. (7), we say the vector field $v$ generates the density path $p_t$; likewise for Eq. (6).
The rest is straightforward. Setting the right-hand sides of Eq. (6) and Eq. (7) equal:
$$-\nabla \cdot [f_{\text{SDE}}(x_t,t)p_t(x)]+\frac{1}{2}\nabla^2[\sigma_t^2p_t(x)] =-\nabla \cdot [v_t(x_t,t)p_t(x)]\tag{8}$$
Since
$$\nabla \log p_t(x) = \frac{\nabla p_t(x)}{p_t(x)} \quad\Longrightarrow\quad \nabla p_t(x) = p_t(x)\cdot\nabla \log p_t(x),$$
we can rewrite the second term on the left-hand side of Eq. (8):
$$\begin{aligned} \nabla^2[\sigma_t^2p_t(x)] &= \sigma_t^2\nabla^2p_t(x) \\ &= \sigma_t^2\nabla\cdot (\nabla p_t(x)) \\ &= \sigma_t^2\nabla \cdot (p_t(x)\nabla \log p_t(x)) \end{aligned}$$
So Eq. (8) becomes
$$\begin{aligned} -\nabla \cdot [f_{\text{SDE}}(x_t,t)p_t(x)]+\frac{1}{2}\sigma_t^2\nabla \cdot (p_t(x)\nabla \log p_t(x)) &=-\nabla \cdot [v_t(x_t,t)p_t(x)] \\ -f_{\text{SDE}}(x_t,t)p_t(x) + \frac{1}{2}\sigma_t^2p_t(x)\nabla \log p_t(x) &= - v_t(x_t,t)p_t(x) \\ f_{\text{SDE}}(x_t,t)p_t(x) &= v_t(x_t,t)p_t(x) + \frac{1}{2}\sigma_t^2p_t(x)\nabla \log p_t(x) \\ f_{\text{SDE}}(x_t,t) &= v_t(x_t,t) + \frac{1}{2}\sigma_t^2\nabla \log p_t(x) \end{aligned}\tag{9}$$
This gives us the relationship between $f_{\text{SDE}}$ and $v_t$.
Following Score-Based Generative Modeling through Stochastic Differential Equations, the forward process Eq. (5) has the corresponding reverse process
$$dx_t = [f(x_t,t)-g^2(t)\nabla\log p_t(x_t)]dt + g(t)\,d\bar w\tag{10}$$
In this post we set $g(t) = \sigma_t$. Substituting Eq. (9) into Eq. (10):
$$\begin{aligned} dx_t &= \left[v_t(x_t,t) + \frac{1}{2}\sigma_t^2\nabla \log p_t(x_t) - \sigma_t^2\nabla\log p_t(x_t)\right]dt + \sigma_t\,d\bar w \\ dx_t &= \left[v_t(x_t,t)-\frac{\sigma_t^2}{2}\nabla\log p_t(x_t)\right]dt + \sigma_t\,d\bar w \end{aligned}\tag{11}$$
In Eq. (11), $v_t$ is known; once $\nabla \log p_t(x_t)$ is known as well, no unknowns remain and a numerical solver can generate samples. So we still need to derive $\nabla \log p_t(x_t)$.
For the forward noising process we have $x_t = \alpha_t x_0 + \beta_t x_1$; for the flow considered here, $\alpha_t = 1 - t$ and $\beta_t = t$. Conditioned on $x_0$, $x_t$ follows the distribution (in the one-dimensional case)
$$p_{t|0}(x_t\mid x_0) = \mathcal{N}(x_t\mid \alpha_tx_0,\beta_t^2I) = \frac{1}{\beta_t\sqrt{2\pi}}\exp\left\{-\frac{(x_t-\alpha_tx_0)^2}{2\beta_t^2}\right\}$$
Taking the logarithm:
$$\begin{aligned} \log p_{t|0}(x_t\mid x_0) &= \log \left( \frac{1}{\beta_t\sqrt{2\pi}}\exp\left\{-\frac{(x_t-\alpha_tx_0)^2}{2\beta_t^2}\right\} \right) \\ &= \log \frac{1}{\beta_t\sqrt{2\pi}} -\frac{(x_t-\alpha_tx_0)^2}{2\beta_t^2} \end{aligned}$$
Hence, using $x_t - \alpha_t x_0 = \beta_t x_1$,
$$\nabla\log p_{t|0}(x_t\mid x_0) = -\frac{x_t - \alpha_tx_0}{\beta_t^2} = -\frac{\beta_tx_1}{\beta_t^2} = -\frac{x_1}{\beta_t}$$
Therefore
$$\begin{aligned} \nabla \log p_t(x_t) &= \frac{1}{p_t(x_t)}\nabla p_t(x_t) \\ &= \frac{1}{p_t(x_t)}\nabla\int p_{t,0}(x_t,x_0)\,dx_0 \\ &= \frac{1}{p_t(x_t)}\int \nabla p_{t,0}(x_t,x_0)\,dx_0 \\ &= \frac{1}{p_t(x_t)}\int \nabla \left[p_{t|0}(x_t\mid x_0)p_0(x_0)\right]dx_0 \\ &= \frac{1}{p_t(x_t)}\int p_0(x_0)\,\nabla p_{t|0}(x_t\mid x_0)\,dx_0 \\ &= \frac{1}{p_t(x_t)}\int p_0(x_0)\cdot p_{t|0}(x_t\mid x_0)\,\nabla\log p_{t|0}(x_t\mid x_0)\,dx_0 \\ &= \frac{1}{p_t(x_t)}\int p_{t,0}(x_t,x_0)\,\nabla\log p_{t|0}(x_t\mid x_0)\,dx_0 \\ &= \frac{1}{p_t(x_t)}\int p_{0|t}(x_0\mid x_t)\,p_t(x_t)\,\nabla\log p_{t|0}(x_t\mid x_0)\,dx_0 \\ &= \int p_{0|t}(x_0\mid x_t)\,\nabla\log p_{t|0}(x_t\mid x_0)\,dx_0 \\ &= \int_{x_0}\int_{x_1} p_{0,1|t}(x_0,x_1\mid x_t)\,dx_1\,\nabla\log p_{t|0}(x_t\mid x_0)\,dx_0 \\ &= \int_{x_0}\int_{x_1} p_{0,1|t}(x_0,x_1\mid x_t)\,\nabla\log p_{t|0}(x_t\mid x_0)\,dx_1\,dx_0 \\ &= \mathbb{E}\left[ \nabla \log p_{t|0}(x_t\mid x_0)\mid x_t\right] \\ &= \mathbb{E}\left[ -\frac{x_1}{\beta_t}\,\Big|\,x_t\right] \\ &= -\frac{1}{\beta_t}\mathbb{E}\left[ x_1\mid x_t\right] \end{aligned}\tag{12}$$
For the vector field $v$, our earlier expression was $v_t = x_1 - x_0$. However, because straight-line paths can cross, the learned $v_\theta$ is not $v_t$ itself but rather its conditional expectation, as we said before. We can show this as follows:
$$\begin{aligned} \mathcal{L} &= \int_0^1 \mathbb{E}_{x_0,x_1}\left[ \Vert x_1-x_0 - v_\theta(x_t,t)\Vert^2 \right]dt \\ &= \int_0^1 \mathbb{E}_{x_0,x_1}\left[ \Vert x_1-x_0 \Vert^2 + \Vert v_\theta(x_t,t)\Vert^2- 2(x_1 - x_0)^Tv_\theta(x_t,t) \right]dt \\ &= \int_0^1 \left\{\mathbb{E}_{x_0,x_1}\left[ \Vert v_\theta(x_t,t)\Vert^2\right] -2\mathbb{E}_{x_0,x_1}\left[(x_1 - x_0)^Tv_\theta(x_t,t) \right]\right\}dt + C \\ &= \int_0^1 \left\{\mathbb{E}_{x_t}\left[ \Vert v_\theta(x_t,t)\Vert^2\right] -2\mathbb{E}_{x_0,x_1}\left[(x_1 - x_0)^Tv_\theta(x_t,t) \right]\right\}dt + C \end{aligned}\tag{13}$$
The first term follows because, given $x_0$ and $x_1$, we have $x_t = tx_1 + (1-t)x_0$, so the expectation can be written directly over $x_t$.
The second term can be transformed further. By the law of total expectation, $\mathbb{E}[Y] = \mathbb{E}_X[\mathbb{E}(Y\mid X)]$, we get
$$\begin{aligned} \mathbb{E}_{x_0,x_1}\left[(x_1 - x_0)^Tv_\theta(x_t,t) \right] &= \mathbb{E}_{x_t}\left[\mathbb{E}_{x_0,x_1}\left[(x_1 - x_0)^Tv_\theta(x_t,t) \mid x_t\right]\right] \\ &= \mathbb{E}_{x_t}\left[\mathbb{E}_{x_0,x_1}\left[(x_1 - x_0) \mid x_t\right]^Tv_\theta(x_t,t)\right] \end{aligned}$$
So Eq. (13) becomes
$$\begin{aligned} \mathcal{L} &= \int_0^1 \left\{\mathbb{E}_{x_t}\left[ \Vert v_\theta(x_t,t)\Vert^2\right] -2\mathbb{E}_{x_t}\left[\mathbb{E}_{x_0,x_1}\left[(x_1 - x_0) \mid x_t\right]^Tv_\theta(x_t,t)\right]\right\}dt + C \\ &= \int_0^1 \mathbb{E}_{x_t}\left[ \Vert \mathbb{E}_{x_0,x_1}[x_1-x_0\mid x_t] - v_\theta(x_t,t) \Vert^2 \right]dt +C' \end{aligned}\tag{14}$$
From this it is clear that the learned field satisfies $v_\theta(x_t,t) = \mathbb{E}_{x_0,x_1}[x_1-x_0\mid x_t]$.
We keep transforming:
$$\begin{aligned} v_\theta(x_t,t) &= \mathbb{E}_{x_0,x_1}[x_1-x_0\mid x_t] \\ &= \mathbb{E}_{x_0,x_1}[x_1\mid x_t] - \mathbb{E}_{x_0,x_1}[x_0\mid x_t] \\ &= \mathbb{E}_{x_0,x_1}[x_1\mid x_t] - \mathbb{E}_{x_0,x_1}\left[\frac{x_t-tx_1}{1-t}\,\Big|\,x_t\right] \\ &= \mathbb{E}_{x_0,x_1}[x_1\mid x_t] - \mathbb{E}_{x_0,x_1}\left[\frac{x_t}{1-t}\,\Big|\,x_t\right] + \mathbb{E}_{x_0,x_1}\left[\frac{tx_1}{1-t}\,\Big|\,x_t\right] \\ &= \mathbb{E}_{x_0,x_1}[x_1\mid x_t] -\frac{x_t}{1-t}+ \frac{t}{1-t}\mathbb{E}_{x_0,x_1}\left[x_1\mid x_t\right] \\ &= -\frac{x_t}{1-t} + \frac{1}{1-t}\mathbb{E}_{x_0,x_1}\left[x_1\mid x_t\right] \\ &= -\frac{x_t}{1-t} + \frac{1}{1-t}\cdot \left(-\beta_t\nabla \log p_t(x_t)\right) \\ &= -\frac{x_t}{1-t} - \frac{t}{1-t}\cdot \nabla \log p_t(x_t) \end{aligned}$$
Isolating $\nabla \log p_t(x_t)$ on the left-hand side gives
$$\nabla\log p_t(x_t) = -\frac{x_t}{t}-\frac{1-t}{t}v_\theta(x_t,t)$$
Substituting this into Eq. (11) yields the final expression:
$$\begin{aligned} dx_t &= \left[v_t(x_t,t)-\frac{\sigma_t^2}{2}\left( -\frac{x_t}{t}-\frac{1-t}{t}v_\theta(x_t,t) \right)\right]dt + \sigma_t\,d\bar w\\ dx_t &= \left[v_t(x_t,t)+\frac{\sigma_t^2}{2t}\left( x_t+(1-t)v_\theta(x_t,t) \right)\right]dt + \sigma_t\,d\bar w \end{aligned}\tag{15}$$
This completes the proof.
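As a byproduct of the proof, the learned velocity field doubles as a score estimator; a small hypothetical helper illustrating the last identity:

```python
def score_from_velocity(x_t, v, t):
    """Score recovered from the learned velocity field:
    grad log p_t(x) = -x/t - ((1 - t)/t) * v_theta(x, t)."""
    return -x_t / t - (1.0 - t) / t * v
```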
8 Closing
That's all for this installment. If you spot any issues, please point them out. Arigatou!