Self-Attention, Multi-Head Self-Attention, and Padding Masks in Python


Background

[Transformer Series (2)] Attention, Self-Attention, Multi-Head Attention, Channel Attention, and Spatial Attention Explained in Detail

Self-Attention
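
Both implementations in this post compute standard scaled dot-product attention: the query-key scores are scaled by the square root of the key dimension, softmax-normalized, and used to take a weighted sum of the values. The padding mask sets the scores of padded key positions to -inf before the softmax so they receive zero attention weight:

Attention(Q, K, V) = softmax(Q·Kᵀ / √d_k) · V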

import torch
import torch.nn as nn


# Self-attention (single head) with an optional padding mask
class SelfAttention(nn.Module):
    def __init__(self, input_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)        

    def forward(self, x, mask=None):
        # x: (batch_size, seq_len, input_dim); mask: (batch_size, seq_len), 1 = real token, 0 = padding
        batch_size, seq_len, input_dim = x.shape

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Scaled dot-product attention scores: (batch_size, seq_len, seq_len)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(input_dim, dtype=torch.float))

        # Broadcast the mask to (batch_size, 1, seq_len) so every query ignores padded key positions
        if mask is not None:
            mask = mask.unsqueeze(1)
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))

        attn_scores = torch.softmax(attn_weights, dim=-1)

        # Weighted sum of the values: (batch_size, seq_len, input_dim)
        attended_values = torch.matmul(attn_scores, v)

        return attended_values

# Pad variable-length sequences to a common length and build a padding mask
def pad_sequences(sequences, max_len=None):
    batch_size = len(sequences)
    input_dim = sequences[0].shape[-1]
    lengths = torch.tensor([seq.shape[0] for seq in sequences])
    max_len = max_len or lengths.max().item()
    
    padded = torch.zeros(batch_size, max_len, input_dim)
    for i, seq in enumerate(sequences):
        seq_len = seq.shape[0]
        padded[i, :seq_len, :] = seq

    # mask[i, j] is True for real tokens and False for padding positions
    mask = torch.arange(max_len).expand(batch_size, max_len) < lengths.unsqueeze(1)

    return padded, mask.long()


if __name__ == '__main__':
    batch_size = 2
    input_dim = 128
    seq_len_1 = 3
    seq_len_2 = 5
    
    x1 = torch.randn(seq_len_1, input_dim)            
    x2 = torch.randn(seq_len_2, input_dim)
    
    target_seq_len = 10    
    padded_x, mask = pad_sequences([x1, x2], target_seq_len)

    selfattention = SelfAttention(input_dim)
    # Pass the padding mask so attention to padded positions is suppressed
    attention = selfattention(padded_x, mask)
    print(attention)
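
As a quick check (continuing the script above), the mask returned by pad_sequences marks real tokens with 1 and padding with 0; it is this tensor that masked_fill compares against 0:

print(mask)
# tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
#         [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])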

Multi-Head Self-Attention
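
Multi-head attention splits the model dimension into num_heads subspaces of size head_dim = input_dim / num_heads, runs scaled dot-product attention in each subspace in parallel, and concatenates the per-head results. In the original Transformer the concatenation is followed by an output projection W^O; the simplified implementation below omits that final linear layer:

MultiHead(Q, K, V) = Concat(head_1, …, head_h) · W^O,  where head_i = Attention(Q·W_i^Q, K·W_i^K, V·W_i^V)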

import torch
import torch.nn as nn

# Multi-head self-attention with an optional padding mask
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads  # input_dim must be divisible by num_heads

        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)        

    def forward(self, x, mask=None):
        batch_size, seq_len, input_dim = x.shape

        # Project the input and split it into heads:
        # after transpose(1, 2) the shape is (batch_size, num_heads, seq_len, head_dim)
        q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Scaled dot-product attention scores per head: (batch_size, num_heads, seq_len, seq_len)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))

        # Apply the padding mask
        if mask is not None:
            # mask: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len) so it broadcasts over heads and query positions
            mask = mask.unsqueeze(1).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))        
        
        attn_scores = torch.softmax(attn_weights, dim=-1)

        # Weighted sum over the values, then merge the heads back: (batch_size, seq_len, input_dim)
        attended_values = torch.matmul(attn_scores, v).transpose(1, 2).contiguous().view(batch_size, seq_len, input_dim)

        return attended_values

# Pad variable-length sequences to a common length and build a padding mask
def pad_sequences(sequences, max_len=None):
    batch_size = len(sequences)
    input_dim = sequences[0].shape[-1]
    lengths = torch.tensor([seq.shape[0] for seq in sequences])
    max_len = max_len or lengths.max().item()
    
    padded = torch.zeros(batch_size, max_len, input_dim)
    for i, seq in enumerate(sequences):
        seq_len = seq.shape[0]
        padded[i, :seq_len, :] = seq

    # mask[i, j] is True for real tokens and False for padding positions
    mask = torch.arange(max_len).expand(batch_size, max_len) < lengths.unsqueeze(1)

    return padded, mask.long()

if __name__ == '__main__':
    heads = 2
    batch_size = 2
    seq_len_1 = 3
    seq_len_2 = 5
    input_dim = 128
    
    x1 = torch.randn(seq_len_1, input_dim)            
    x2 = torch.randn(seq_len_2, input_dim)
    
    target_seq_len = 10    
    padded_x, mask = pad_sequences([x1, x2], target_seq_len)
    
    multiheadattention = MultiHeadSelfAttention(input_dim, heads)
    
    attention = multiheadattention(padded_x, mask)    
    print(attention)
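
For comparison, the same padding-mask behaviour can be reproduced with PyTorch's built-in nn.MultiheadAttention. This is only a sketch, not part of the implementation above, and assumes PyTorch 1.9+ for batch_first=True; note the inverted mask convention: key_padding_mask expects True at the positions to ignore, i.e. where the mask above is 0.

import torch
import torch.nn as nn

# Reuses pad_sequences defined above; the inputs mirror the example in __main__.
input_dim, heads, target_seq_len = 128, 2, 10
x1 = torch.randn(3, input_dim)
x2 = torch.randn(5, input_dim)
padded_x, mask = pad_sequences([x1, x2], target_seq_len)

mha = nn.MultiheadAttention(embed_dim=input_dim, num_heads=heads, batch_first=True)
key_padding_mask = (mask == 0)  # (batch_size, max_len), True where padded
out, attn_weights = mha(padded_x, padded_x, padded_x, key_padding_mask=key_padding_mask)
print(out.shape)  # torch.Size([2, 10, 128])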
