将多个 TransformerEncoderLayer 层堆叠起来,形成一个完整的 Transformer 编码器

发布于:2023-09-16 ⋅ 阅读:(116) ⋅ 点赞:(0)

该函数作用是将多个 TransformerEncoderLayer 层堆叠起来,形成一个完整的 Transformer 编码器。以下是这个类的主要部分的解释:

  • encoder_layer: 这是一个 TransformerEncoderLayer 类的实例,表示编码器层的构建模块。编码器由多个这样的层叠加而成。

  • num_layers: 这是编码器中的子编码器层数。也就是说,编码器由多少个 encoder_layer 堆叠而成。

  • norm: 这是可选的层归一化组件,用于在编码器的输出上应用层归一化。

  • forward 函数:这个函数执行编码器的前向传播过程。它接受输入序列 src,以及可选的掩码 mask 和序列键掩码 src_key_padding_mask。然后,它迭代遍历每个子编码器层(由 self.layers 组成),并将输入序列 src 传递给每一层。最后,如果指定了层归一化组件 norm,则应用层归一化并返回输出。

这个类的主要作用是组装多个编码器层,使得它们可以一层一层地处理输入序列,并生成编码器的输出。这个输出通常用作后续任务的输入,例如序列到序列任务、文本分类等。

class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = src

        for mod in self.layers:
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

本文含有隐藏内容,请 开通VIP 后查看

网站公告

今日签到

点亮在社区的每一天
去签到