VQVAE:Neural Discrete Representation Learning

发布于:2024-07-01 ⋅ 阅读:(13) ⋅ 点赞:(0)

论文名称:Neural Discrete Representation Learning
开源地址
发表时间:NIPS2017
作者及组织:Aaron van den Oord,Oriol Vinyals和Koray Kavukcuoglu, 来自DeepMind。

1、VAE

  简单回顾下VAE的损失函数,ELBO的下界为:
L o w e r B o u n d = E q φ ( z ∣ x ) [ l o g p θ ( x ∣ z ) ] − D K L ( q φ ( z ∣ x ) ∣ ∣ p ( z ) ) \begin{equation} Lower Bound =E_{q_\varphi(z|x)}[logp_\theta(x|z)] - D_{KL}(q_\varphi(z|x)||p(z)) \tag{0} \end{equation} LowerBound=Eqφ(zx)[logpθ(xz)]DKL(qφ(zx)∣∣p(z))(0)
 其中第一项为解码器的重构损失(regression loss) ;第二项为正则项,用KL散度来使Encoder----后验概率 q φ ( z ∣ x ) q_\varphi(z|x) qφ(zx) 和 先验 p ( z ) p(z) p(z) 分布近似,通常 p ( z ) p(z) p(z) 假设为多元标准正太分布,该项主要防止VAE坍塌到一个点,毕竟是生成模型。
 而VQVAE和VAE主要不同:Encoder输出是离散的,而不是连续的隐变量z。

1、方法

1.1.模型结构

在这里插入图片描述

 为了实现离散化编码,VQVAE引入了一个可学习的codebook,即上图中的EmbeddingSpace。大概说下流程:输入一张图像,经过CNN得到 z e ( x ) ∈ R H ∗ W ∗ D z_e(x) \in \mathbb{R}^{H*W*D} ze(x)RHWD ,然后计算 z e z_e ze 中每条特征向量跟codebook的最接近的向量的索引,得到 q ( z ∣ x ) ∈ R H ∗ W q(z|x) \in \mathbb{R}^{H*W} q(zx)RHW , 然后用codebook中向量 e i e_i ei 来替换 z e ( x ) z_e(x) ze(x) 得到 z q ( x ) z_q(x) zq(x) 。最后经过Decoder得到 x x x

1.2.训练

 先说下总体损失函数,其实跟VAE的损失函数类似:
L = l o g p ( x ∣ z q ( x ) ) + ∣ ∣ s g [ z e ( x ) ] − e ∣ ∣ 2 2 + β ∣ ∣ z e ( x ) − s g [ e ] ∣ ∣ 2 2 \begin{equation} L = logp(x|z_q(x)) + ||sg[z_e(x)]- e|| ^2_2 + \beta||z_e(x)-sg[e]||^2_2 \tag{1} \end{equation} L=logp(xzq(x))+∣∣sg[ze(x)]e22+β∣∣ze(x)sg[e]22(1)

 其中第一项就是VAE中的重构损失,但有个问题:在用L2 Loss计算重构损失后,反向传播时,由于在codebook中argmin这个操作是不可导的,这样就优化不了Encoder,于是本文直接将 z q ( x ) z_q(x) zq(x) 节点的梯度拷贝给了 z e ( x ) z_e(x) ze(x) ,使得反向传播得以继续。具体的表达式如下:
l o g p ( x ∣ z q ( x ) ) = ∣ ∣ x − d e c o d e r ( z e ( x ) + s g ( z q ( x ) − z e ( x ) ) ) ∣ ∣ 2 2 \begin{equation} logp(x|z_q(x)) = ||x-decoder(z_e(x)+sg(z_q(x)-z_e(x)))||_2^2 \tag{2} \end{equation} logp(xzq(x))=∣∣xdecoder(ze(x)+sg(zq(x)ze(x)))22(2)
 式中的 s g sg sg 表示 .detach() 操作,由于VQVAE多了一个可学习的codebook,而重构损失并没有梯度传过去。因此损失第二项就是让 e e e 逼近 z e ( x ) z_e(x) ze(x) ,这项仅更新codebook。

  由于训练过程中,Encoder相较于codebook,肯定易于优化,也就是Encoder收敛快,而codebook收敛慢 ,为了让Encoder别距离codebook太远,于是增加了第三项损失,让 z e ( x ) z_e(x) ze(x) 逼近 e e e

 在回过头来跟VAE的式子比较下 ,发现缺少了KL散度项:这是因为在VQVAE中,在根据 x x x 取得 e e e 的概率非0即1: q ( z = e ∣ x ) = 1 , q ( z = o t h e r ∣ x = 0 ) q(z=e|x)=1,q(z=other|x=0) q(z=ex)=1,q(z=otherx=0) ,相当于二项分布,同时假设 p ( z ) p(z) p(z) 是均匀分布,两个均匀分布的KL散度是常数,在损失中可忽略。

1.3.生成

 在训练集上训练完VQVAE后,VQVAE学习到的是一个有效的低维度的离散表示。然后将VQVAE置为推理阶段,用自回归模型PixCNN来拟合 q ( z ∣ x ) q(z|x) q(zx) ,训练完成后,PixCNN就能生成有意义的索引矩阵,然后去codebook中拿到对应的张量,送去VQVAE的Decoder中解码生成图像。

2、实验

  生成的小图还是可以的。
在这里插入图片描述

思考

  替换更强的自回归模型Transformer也就是后来VQGAN的工作了。