fairseq框架使用记录

发布于:2024-06-10 ⋅ 阅读:(149) ⋅ 点赞:(0)

sh命令

cmd="fairseq-train data-bin/$data_dir
  --save-dir $save_dir
  --distributed-world-size $gpu_num -s $src_lang -t $tgt_lang
  --arch $arch
  --dropout $dropout
  --criterion $criterion --label-smoothing 0.1
  --task mmt_vqa
  --optimizer adam --adam-betas '(0.9, 0.98)'
  --lr $lr --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates $warmup
  --max-tokens $max_tokens --update-freq $update_freq
  --share-all-embeddings
  --find-unused-parameters
  --skip-invalid-size-inputs-valid-test
  --patience $keep_last_epochs
  --keep-last-epochs $keep_last_epochs
  --image-name-dir data-bin/$data_dir
  --ptm-name $vision_model_name
  --vision-model $vision_model_name
  --weight $weight
  --source-sentence-dir data-bin"

主要的执行流程–task

fairseq/tasks/mmt_vqa.py

模型–arch

fairseq/models/transformer/transformer_mmt_vqa_legacy.py

@register_model_architecture("transformer_mmt_vqa", "transformer_mmt_vqa_2sa_2decoder")
def transformer_mmt_vqa_2sa_2decoder(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 128)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.encoder_layers = getattr(args, "encoder_layers", 4)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 128)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 256)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.decoder_layers = getattr(args, "decoder_layers", 4)
    mmt_vqa_base_architecture(args)

DualSAEncoder

在这里插入图片描述
VQA和MMT融合图像特征时 提取特征的encoder不一样、融合时选择的注意力也不一样、融合的门控参数也不一样。
具体的特征融合:
在这里插入图片描述

tensor.eq()/tensor.ne()

对两个张量Tensor进行逐元素的比较,若相同位置的两个元素相同,则返回True;若不同,返回False。eq(等于)/ne(不等)

具体流程主要就是
对query得到query和mask_query的特征 同样的流程得到text和mask_text特征
然后使用预训练模型提取图像特征image_features
使用选择注意力机制将text和img特征进行融合"sum"得到融合后的text特征
同样得到融合后的query特征

DualLayersDecoder

vqa_decoder:获得之前的特征向量 得到attn_mask获得交叉注意力 得到特征
text和vqa一样 最后取多头注意力的均值 投影到固定维度(单词的字典长度)
vqa做vqa的 text做text的 没有交集 就是一般的transformer decoder

TransformerMMTVQAModel

就是encoder 然后把encoder的输出作为输入+prev_output放进decoder得到decoder的结果。

损失函数–criterion

fairseq/criterions/label_smoothed_cross_entropy_mmt_vqa.py
在计算损失的时候加上了vqa损失


网站公告

今日签到

点亮在社区的每一天
去签到