论文:https://arxiv.org/abs/2301.12597
代码:https://github.com/salesforce/LAVIS/tree/main/projects/blip2
Motivation:
1、端到端训练大规模视觉语言预训练模型成本比较高;
2、存在灾难性遗忘的问题。
解决方案:
1、利用现成的冻结的预训练图像编码器和冻结的预训练大语言模型进行视觉到文本的预训练;
2、使用轻量级的查询转换器Q-Former实现图像与文本的特征对齐。
模型组成:
1、预训练的图像编码器:从输入图片中提取视觉特征(不同分辨率的图像输出相同数量的特征),使用CLIP预训练的VIT结构;
2、预训练的大语言模型:用于文本生成,使用decoder-based LLM和encoder-decoder-based LLM;
3、Q-Former:图像和文本之间的桥梁。两阶段(表征学习和生成学习)预训练的Q-Former用于弥补模态差距,实现特征对齐。
Q-Former的两个预训练阶段:
1、vision-language representation learning stage(表征学习阶段):将Q-Former连接到冻结图像编码器,使用图像文本对进行预训练,令Q-Former学习到与文本信息最相关的视觉表征。
- Image Transformer(左):与冻结的图像编码器通过across-attention进行交互,提取视觉特征;
- Text Transformer(右):可以作为text encoder或text decoder;
- 可学习的Queries:作为Image Transformer的输入,通过Self Attention层实现不同query间的交互,通过共享的Self Attention层与输入文本进行交互。强制Queries提取有关文本的所有信息的视觉特征。
三个预训练任务,分别使用不同的Attention Mask和Loss,使用BERTbase权重初始化(Across Attention层随机初始化):
Loss计算同BLIP模型:
1、Image-Text Matching(图文匹配):二分类任务(匹配/不匹配),Bi-directional Self-Attention mask + ITM Loss
2、Image-Grounded Text Generation(图像引导的文本生成):Causal Self-Attention mask + IGT Loss (LM)
3、Image-Text Contrastive Learning(图文对比学习):Uni-modal Self-Attention mask + ITC Loss
2、vision-to-language generative learning stage(生成学习阶段):将Q-Former的输出连接到冻结的LLM来执行视觉到语言的生成学习,训练Q-Former,使其输出的视觉特征可以被LLM理解。
具体做法:Q-Former输出的query embeddings通过FC层投影到与LLM的text embedding相同的维度,然后将其添加到输入text embedding之前。将Q-Former提取到的视觉表征作为soft visual prompt。
由于Q-Former已被预先训练以提取语言信息的视觉表示,因此它有效地充当信息瓶颈,为LLM提供最有用的信息,同时删除无关的视觉信息。这减轻了LLM学习视觉语言特征对齐的负担,从而减轻了灾难性遗忘问题。
- decoder-based LLM:使用语言建模损失(language modeling loss)进行预训练,frozen LLM进行文本生成;
- encoder-decoder-based LLM:使用前缀语言建模损失(prefix language modeling loss)进行预训练,将text分为两部分,前半部分与视觉表征concat输入LLM编码器,后半部分作为LLM解码器的生成目标。
为什么不让LLM认识Query,而让Query变成LLM认识呢?
1、LLM模型的训练代价有点大;
2、从Prompt Learning的观点来看,目前多模态的数据量不足以保证LLM训练的更好,反而可能会让其丧失泛化性。如果不能让模型适应任务,那就让任务来适应模型。
代码分析:
1、Learned Queries初始化:
query_tokens = nn.Parameter(
torch.zeros(1, num_query_token, encoder_config.hidden_size) # [1, 32, 768]
)
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) # 高斯权重初始化
2、Q-Former结构:
Qformer = BertLMHeadModel.from_pretrained(
"bert-base-uncased", config=encoder_config
)
def forward(self, samples):
image = samples["image"]
text = samples["text_input"]
image_embeds = self.ln_vision(self.visual_encoder(image)) # [1, 197, 768]
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
image.device
) # [1, 197]
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) # [1, 32, 768]
query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds, # 与image_embeds做cross-attention
encoder_attention_mask=image_atts,
use_cache=True,
return_dict=True,
)
image_feats = F.normalize(
self.vision_proj(query_output.last_hidden_state), dim=-1 # [1, 32, 768] -> [1, 32, 256]
)
text_tokens = self.tokenizer(
text,
padding="max_length",
truncation=True,
max_length=self.max_txt_len,
return_tensors="pt",
).to(image.device) # [1, 32], 一个句子,32个token_id([cls] + tokens + [SEP], 不足补0)
text_output = self.Qformer.bert(
text_tokens.input_ids,
attention_mask=text_tokens.attention_mask,
return_dict=True,
) # [1, 32, 768]
text_feat = F.normalize(
self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1 # [1, 768] -> [1, 256]
)
具体任务实现:
1、 图文检索(Image-Text Retrieval)
图像文本检索不涉及语言生成,直接对第一阶段的预训练模型进行微调即可,不使用LLM。
具体而言,使用与预训练相同的目标(ITC、ITM和ITG)与Q-Former一起微调图像编码器。ITC和ITM损失对于图像文本检索至关重要,因为它们直接学习图像文本相似性。
2、图像字幕(Image Captioning)
微调BLIP-2模型,要求模型根据图像的视觉内容生成文本描述。具体做法是,使用提示“a photo of”作为LLM的初始输入,并在语言建模丢失的情况下训练模型生成标题。在微调过程中保持LLM冻结,并与图像编码器一起更新Q-Former的参数。
3、视觉问答(VQA)
给定带注释的VQA数据,微调Q-Former和图像编码器,同时保持LLM冻结。对开放式答案生成损失进行微调,其中LLM接收Q-Former的输出和问题作为输入,并被要求生成答案。
为了提取与问题更相关的图像特征,在问题上附加了Q-Former条件。具体做法是,将问题作为Q-Former的输入,通过自注意层与查询交互,引导Q-Former的交叉注意层关注信息更丰富的图像区域。
4、指示的零样本图像到文本生成(Instructed Zero-shot Image-to-Text Generation)
包括视觉对话、视觉知识推理、视觉共感推理、故事讲述、个性化图像到文本的生成等任务。