NLP: Building the Transformer Model

Published: 2025-08-18


Preface: The previous articles covered each component of the Transformer individually. This article puts those pieces together and builds the complete Transformer model.

In short, the standard Transformer consists of a stack of 6 encoder layers and a stack of 6 decoder layers, together with an input part (word embedding + positional encoding for both the source and the target side) and an output part (the generator).
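Both stacks are built by deep-copying a single layer N times, so the layers share the same structure but have independent parameters. A minimal sketch of that cloning pattern follows; the helper name `clones` is an assumption made for illustration, and the Encoder/Decoder classes from the earlier articles may organize this differently:

import copy
import torch.nn as nn

def clones(module, N):
    # Return N independent deep copies of the same module so that the
    # stacked layers do not share parameters (illustrative helper only).
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])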

1. Code Implementation of the Encoder-Decoder

# Note: the component classes used below (Encoder, Decoder, EncoderLayer, DecoderLayer,
# MutiHeadAttention, FeedForward, Embeddings, PositionEncoding, Generator) were defined
# in the previous articles of this series.
import copy

import torch
import torch.nn as nn


# Define the EncoderDecoder class
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, source_embed, target_embed, generator):
        super().__init__()
        # encoder: the encoder object
        self.encoder = encoder
        # decoder: the decoder object
        self.decoder = decoder
        # source_embed: the source-language input part: wordEmbedding + PositionEncoding
        self.source_embed = source_embed
        # target_embed: the target-language input part: wordEmbedding + PositionEncoding
        self.target_embed = target_embed
        # generator: the output-layer object
        self.generator = generator

    def forward(self, source, target, source_mask1, source_mask2, target_mask):
        # source: source-language input, shape [batch_size, seq_len] --> [2, 4]
        # target: target-language input, shape [batch_size, seq_len] --> [2, 6]
        # source_mask1: padding mask for the encoder's multi-head self-attention --> [head, source_seq_len, source_seq_len] --> [8, 4, 4]
        # source_mask2: padding mask for the decoder's multi-head cross-attention --> [head, target_seq_len, source_seq_len] --> [8, 6, 4]
        # target_mask: look-ahead (subsequent) mask for the decoder's multi-head self-attention --> [head, target_seq_len, target_seq_len] --> [8, 6, 6]
        # 1. Pass the raw source input, shape [batch_size, seq_len] --> [2, 4], through the encoder input part to get [2, 4, 512]
        # encode_word_embed: wordEmbedding + PositionEncoding
        encode_word_embed = self.source_embed(source)
        # 2. Feed encode_word_embed and source_mask1 into the encoder to get the encoded result: encoder_output --> [2, 4, 512]
        encoder_output = self.encoder(encode_word_embed, source_mask1)
        # 3. Pass the target input, shape [batch_size, seq_len] --> [2, 6], through the decoder input part to get [2, 6, 512]
        decode_word_embed = self.target_embed(target)
        # 4. Feed decode_word_embed, encoder_output, source_mask2 and target_mask into the decoder
        decoder_output = self.decoder(decode_word_embed, encoder_output, source_mask2, target_mask)
        # 5. Feed decoder_output into the output layer
        output = self.generator(decoder_output)
        return output
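The Generator used as the output layer above comes from an earlier article in this series. For reference, here is a minimal sketch of what such an output layer typically looks like, assuming a single linear projection from d_model to the vocabulary size followed by log-softmax; the class body below is an assumption for illustration, not the article's exact code:

import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        # Project decoder outputs from d_model to the target vocabulary size.
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x: [batch_size, target_seq_len, d_model] --> [batch_size, target_seq_len, vocab_size]
        return F.log_softmax(self.linear(x), dim=-1)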

2. Instantiating the Full Encoder-Decoder Model

def dm_transformer():
    # 1. Instantiate the encoder
    # Instantiate the multi-head attention object
    mha = MutiHeadAttention(embed_dim=512, head=8, dropout_p=0.1)
    # Instantiate the position-wise feed-forward object
    ff = FeedForward(d_model=512, d_ff=1024)
    encoder_layer = EncoderLayer(size=512, self_atten=mha, ff=ff, dropout_p=0.1)
    encoder = Encoder(layer=encoder_layer, N=6)

    # 2. Instantiate the decoder
    self_attn = copy.deepcopy(mha)
    src_attn = copy.deepcopy(mha)
    feed_forward = copy.deepcopy(ff)
    decoder_layer = DecoderLayer(size=512, self_attn=self_attn, src_attn=src_attn, feed_forward=feed_forward, dropout_p=0.1)
    decoder = Decoder(layer=decoder_layer, N=6)

    # 3. Source-language input part: wordEmbedding + PositionEncoding
    # Word embedding layer
    vocab_size = 1000
    d_model = 512
    encoder_embed = Embeddings(vocab_size=vocab_size, d_model=d_model)
    # Positional encoding layer (inside the positional encoder, embed_x has already been added to the positional encodings)
    dropout_p = 0.1
    encoder_pe = PositionEncoding(d_model=d_model, dropout_p=dropout_p)
    source_embed = nn.Sequential(encoder_embed, encoder_pe)

    # 4. Target-language input part: wordEmbedding + PositionEncoding
    # Word embedding layer
    decoder_embed = copy.deepcopy(encoder_embed)
    # Positional encoding layer (inside the positional encoder, embed_x has already been added to the positional encodings)
    decoder_pe = copy.deepcopy(encoder_pe)
    target_embed = nn.Sequential(decoder_embed, decoder_pe)

    # 5. Instantiate the output layer
    generator = Generator(d_model=512, vocab_size=2000)

    # 6. Instantiate the EncoderDecoder object
    transformer = EncoderDecoder(encoder, decoder, source_embed, target_embed, generator)
    print(transformer)

    # 7. Prepare dummy data (the all-zero masks here only exercise the shapes)
    source = torch.tensor([[1, 2, 3, 4],
                           [2, 5, 6, 10]])
    target = torch.tensor([[1, 20, 3, 4, 19, 30],
                           [21, 5, 6, 10, 80, 38]])
    source_mask1 = torch.zeros(8, 4, 4)
    source_mask2 = torch.zeros(8, 6, 4)
    target_mask = torch.zeros(8, 6, 6)
    result = transformer(source, target, source_mask1, source_mask2, target_mask)
    print(f'Final Transformer output --> {result}')
    # Expected shape: [2, 6, 2000], i.e. [batch_size, target_seq_len, vocab_size]
    print(f'Final Transformer output shape --> {result.shape}')
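In the demo above the three masks are plain zero tensors that only serve to check tensor shapes. Below is a sketch of how masks with the same shapes could be built for real use, assuming a pad token id of 0 and a lower-triangular look-ahead mask; whether 1 means "attend" or "block" depends on the attention implementation from the earlier articles, so treat this as illustrative only:

import torch

head, source_seq_len, target_seq_len = 8, 4, 6

# Look-ahead (subsequent) mask for the decoder self-attention:
# position i may only attend to positions <= i --> [8, 6, 6]
target_mask = torch.tril(torch.ones(target_seq_len, target_seq_len))
target_mask = target_mask.unsqueeze(0).expand(head, -1, -1)

# Padding mask: 1 where the source key position holds a real token, 0 where it holds
# the (assumed) pad token id 0. The demo sentences contain no padding, so this is all ones.
source = torch.tensor([[1, 2, 3, 4],
                       [2, 5, 6, 10]])
key_is_real = (source[0] != 0).float()  # first example only, since the demo masks carry no batch dimension
source_mask1 = key_is_real.expand(source_seq_len, -1).unsqueeze(0).expand(head, -1, -1)  # [8, 4, 4]
source_mask2 = key_is_real.expand(target_seq_len, -1).unsqueeze(0).expand(head, -1, -1)  # [8, 6, 4]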

3. Output of Running the Code

# The final model structure, built according to the Transformer architecture diagram (printout abbreviated)
EncoderDecoder(
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm(
            )
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm(
            )
            (dropout): Dropout(p=0.1)
          )
        )
      )
      (1): EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm(
            )
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm(
            )
            (dropout): Dropout(p=0.1)
          )
        )
      )
    )
    (norm): LayerNorm(
    )
  )
  (decoder): Decoder(
    (layers): ModuleList(
      (0): DecoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (src_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm(
            )
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm(
            )
            (dropout): Dropout(p=0.1)
          )
          (2): SublayerConnection(
            (norm): LayerNorm(
            )
            (dropout): Dropout(p=0.1)
          )
        )
      )
      (1): DecoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (src_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm(
            )
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm(
            )
            (dropout): Dropout(p=0.1)
          )
          (2): SublayerConnection(
            (norm): LayerNorm(
            )
            (dropout): Dropout(p=0.1)
          )
        )
      )
    )
    (norm): LayerNorm(
    )
  )
  (src_embed): Sequential(
    (0): Embeddings(
      (lut): Embedding(11, 512)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1)
    )
  )
  (tgt_embed): Sequential(
    (0): Embeddings(
      (lut): Embedding(11, 512)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1)
    )
  )
  (generator): Generator(
    (proj): Linear(in_features=512, out_features=11)
  )
)

If anything in the code is unclear, please refer to the earlier articles in this series. Thanks for reading; that's all for today.

