13. Encoder Structure




Compiled from the video series "From Getting Started with AI to Hand-Writing a Transformer" by 老袁不说话.


TransformerEncoder: input [b, n]

  • Embedding: -> [b, n, d]
  • PositionalEncoding: -> [b, n, d]
  • Dropout: -> [b, n, d]
  • EncoderBlock: [b, n, d] -> [b, n, d], repeated N times
    • MultiheadAttention: 3 × [b, n, d] -> [b, n, d]
    • Dropout: [b, n, d] -> [b, n, d]
    • AddNorm: 2 × [b, n, d] (Dropout output, MultiheadAttention input) -> [b, n, d] (see the sketch below)
    • FFN: [b, n, d] -> [b, n, d]
    • Dropout: [b, n, d] -> [b, n, d]
    • AddNorm: 2 × [b, n, d] (Dropout output, FFN input) -> [b, n, d]
[Figure] Encoder structure
[Figure] Dropout is applied at multiple points
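AddNorm is not a built-in PyTorch layer: it adds its two [b, n, d] inputs (the residual connection) and then applies layer normalization over the feature dimension d. A minimal sketch, assuming Dropout has already been applied to the sub-layer output and using arbitrary example sizes b = 2, n = 5, d = 16:

import torch
import torch.nn as nn

class AddNorm(nn.Module):
    """Residual connection followed by layer normalization: LayerNorm(x + y)."""
    def __init__(self, d):
        super().__init__()
        self.ln = nn.LayerNorm(d)  # normalizes over the last (feature) dimension

    def forward(self, y, x):
        # y: sub-layer output after Dropout, x: sub-layer input, both [b, n, d]
        return self.ln(x + y)

b, n, d = 2, 5, 16                                            # arbitrary example sizes
out = AddNorm(d)(torch.randn(b, n, d), torch.randn(b, n, d))
print(out.shape)                                              # torch.Size([2, 5, 16])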

Code

import torch.nn as nn

# Placeholder modules: each forward() only prints the class name,
# so running the encoder traces the order in which the sub-modules are called.

class Embedding(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
    def forward(self):
        print(self.__class__.__name__)

class PositionalEncoding(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
    def forward(self):
        print(self.__class__.__name__)

class MultiheadAttention(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
    def forward(self):
        print(self.__class__.__name__)

class Dropout(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
    def forward(self):
        print(self.__class__.__name__)

class AddNorm(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
    def forward(self):
        print(self.__class__.__name__)

class FFN(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
    def forward(self):
        print(self.__class__.__name__)

class EncoderBlock(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mha = MultiheadAttention()
        self.dropout1 = Dropout()
        self.addnorm1 = AddNorm()
        self.ffn = FFN()
        self.dropout2 = Dropout()
        self.addnorm2 = AddNorm()
    def forward(self):
        # Attention sub-layer: MultiheadAttention -> Dropout -> AddNorm
        self.mha()
        self.dropout1()
        self.addnorm1()
        # Feed-forward sub-layer: FFN -> Dropout -> AddNorm
        self.ffn()
        self.dropout2()
        self.addnorm2()

class TransformerEncoder(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embedding = Embedding()  # turn token indices into embeddings that carry semantic information
        self.posenc = PositionalEncoding()
        self.dropout = Dropout()
        self.encblocks = nn.Sequential()
        for i in range(3):  # stack N = 3 encoder blocks
            self.encblocks.add_module(str(i), EncoderBlock())
    def forward(self):
        self.embedding()
        self.posenc()
        self.dropout()
        for i, blk in enumerate(self.encblocks):
            print(i)  # block index
            blk()

te = TransformerEncoder()
te()

Output

Embedding
PositionalEncoding
Dropout
0
MultiheadAttention
Dropout
AddNorm
FFN
Dropout
AddNorm
1
MultiheadAttention
Dropout
AddNorm
FFN
Dropout
AddNorm
2
MultiheadAttention
Dropout
AddNorm
FFN
Dropout
AddNorm
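
The trace above only shows the call order. To check that the shapes in the structure list actually hold, here is a minimal sketch of a single encoder block pass using standard PyTorch layers (nn.Embedding, nn.MultiheadAttention with batch_first=True, nn.LayerNorm). The sizes b = 2, n = 5, d = 16, the vocabulary size 100 and the 4 attention heads are arbitrary, and the positional encoding is omitted for brevity.

import torch
import torch.nn as nn

b, n, d, vocab, heads = 2, 5, 16, 100, 4      # arbitrary example sizes

tokens = torch.randint(0, vocab, (b, n))      # [b, n] token indices
x = nn.Embedding(vocab, d)(tokens)            # Embedding: [b, n] -> [b, n, d]
x = nn.Dropout(0.1)(x)                        # Dropout:   [b, n, d] -> [b, n, d]

mha = nn.MultiheadAttention(d, heads, batch_first=True)
attn_out, _ = mha(x, x, x)                    # self-attention: 3 × [b, n, d] -> [b, n, d]
x = nn.LayerNorm(d)(x + nn.Dropout(0.1)(attn_out))    # AddNorm -> [b, n, d]

ffn = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
x = nn.LayerNorm(d)(x + nn.Dropout(0.1)(ffn(x)))      # AddNorm -> [b, n, d]

print(x.shape)                                # torch.Size([2, 5, 16])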

