sh命令
cmd="fairseq-train data-bin/$data_dir
--save-dir $save_dir
--distributed-world-size $gpu_num -s $src_lang -t $tgt_lang
--arch $arch
--dropout $dropout
--criterion $criterion --label-smoothing 0.1
--task mmt_vqa
--optimizer adam --adam-betas '(0.9, 0.98)'
--lr $lr --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates $warmup
--max-tokens $max_tokens --update-freq $update_freq
--share-all-embeddings
--find-unused-parameters
--skip-invalid-size-inputs-valid-test
--patience $keep_last_epochs
--keep-last-epochs $keep_last_epochs
--image-name-dir data-bin/$data_dir
--ptm-name $vision_model_name
--vision-model $vision_model_name
--weight $weight
--source-sentence-dir data-bin"
主要的执行流程–task
fairseq/tasks/mmt_vqa.py
模型–arch
fairseq/models/transformer/transformer_mmt_vqa_legacy.py
@register_model_architecture("transformer_mmt_vqa", "transformer_mmt_vqa_2sa_2decoder")
def transformer_mmt_vqa_2sa_2decoder(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 128)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.encoder_layers = getattr(args, "encoder_layers", 4)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 128)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 256)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.decoder_layers = getattr(args, "decoder_layers", 4)
mmt_vqa_base_architecture(args)
DualSAEncoder
VQA和MMT融合图像特征时 提取特征的encoder不一样、融合时选择的注意力也不一样、融合的门控参数也不一样。
具体的特征融合:
tensor.eq()/tensor.ne()
对两个张量Tensor进行逐元素的比较,若相同位置的两个元素相同,则返回True;若不同,返回False。eq(等于)/ne(不等)
具体流程主要就是
对query得到query和mask_query的特征 同样的流程得到text和mask_text特征
然后使用预训练模型提取图像特征image_features
使用选择注意力机制将text和img特征进行融合"sum"得到融合后的text特征
同样得到融合后的query特征
DualLayersDecoder
vqa_decoder:获得之前的特征向量 得到attn_mask获得交叉注意力 得到特征
text和vqa一样 最后取多头注意力的均值 投影到固定维度(单词的字典长度)
vqa做vqa的 text做text的 没有交集 就是一般的transformer decoder
TransformerMMTVQAModel
就是encoder 然后把encoder的输出作为输入+prev_output放进decoder得到decoder的结果。
损失函数–criterion
fairseq/criterions/label_smoothed_cross_entropy_mmt_vqa.py
在计算损失的时候加上了vqa损失