注意力机制

发布于:2025-07-11 ⋅ 阅读:(16) ⋅ 点赞:(0)

 第一种注意力机制

# 注意力机制


import torch
import torch.nn as nn
import torch.nn.functional as F

class Attn(nn.Module):
    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        """初始化函数中的参数有5个
        query_size代表query的最后一维大小
        key_size代表key的最后一维大小, value_size1代表value的导数第二维大小

        value = (1, value_size1, value_size2)

        value_size2代表value的倒数第一维大小, output_size输出的最后一维大小
        """
        super(Attn, self).__init__()
        # 将以下参数传入类中
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        self.output_size = output_size

        # 初始化注意力机制实现第一步中需要的线性层
        self.attn = nn.Linear(self.query_size + self.key_size, value_size1)

        # 初始化注意力机制实现第三步中需要的线性层
        self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)

    def forward(self, Q, K, V):
        """forward函数的输入参数有三个
        分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的
        张量一般情况都是三维张量, 因此这里也假设Q, K, V都是三维张量
        """

        # 第一步, 按照计算规则进行计算,
        # 我们采用常见的第一种计算规则
        # 将Q,K进行纵轴拼接, 做一次线性变化, 最后使用softmax处理获得结果
        attn_weights = F.softmax(self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)

        print(' Q、K进行softmax后注意力权重长这样子\n', attn_weights)

        # 然后进行第一步的后半部分, 将得到的权重矩阵与V做矩阵乘法计算,
        # 当二者都是三维张量且第一维代表为batch条数时, 则做bmm运算
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)

        # 之后进行第二步, 通过取[0]是用来降维, 根据第一步采用的计算方法,
        # 需要将Q与第一步的计算结果再进行拼接
        output = torch.cat((Q[0], attn_applied[0]), 1)

        # 最后是第三步, 使用线性层作用在第三步的结果上做一个线性变换并扩展维度,得到输出
        # 因为要保证输出也是3维张量, 因此使用unsqueeze(0)扩展维度
        output = self.attn_combine(output).unsqueeze(0)

        return output, attn_weights


if __name__ == '__main__':
    query_size = 32
    key_size = 32
    value_size1 = 32
    value_size2 = 64
    output_size = 64

    attn = Attn(query_size, key_size, value_size1, value_size2, output_size)

    # 批次,行,列
    Q = torch.randn(1, 1, 32)
    print(' Q长这样子\n', Q)
    K = torch.randn(1, 1, 32)
    print(' K长这样子\n', K)
    V = torch.randn(1, 32, 64)
    print(' V长这样子\n', V.shape)
    print(' V长这样子\n', V)
    print('**************************************************************************************')

    print(Q[0])
    print(K[0])
    print(' Q、K拼接后长这样子\n', torch.cat((Q[0], K[0])))

    print('**************************************************************************************')

    out, attn_weights = attn(Q, K, V)
    print(' 最终输出结果\n', out.shape)
    print(out)
    print(' 注意力权重\n', attn_weights.shape)
    print(attn_weights)




 运行结果

 Q长这样子
 tensor([[[-0.0843,  0.6108, -0.5214,  0.4358,  1.6302, -0.6159,  1.6340,
           0.1276,  1.5854, -0.2922, -0.9621,  0.1989, -0.0558,  1.9234,
          -0.5138, -0.7876,  1.9724, -0.0659,  0.5300,  1.1414,  1.1585,
           0.9155,  0.0557, -0.7387, -0.5724, -0.0478,  0.4301,  1.2947,
          -0.4314, -0.0663,  0.3610,  0.6614]]])
 K长这样子
 tensor([[[-1.6176, -0.2296,  0.0839,  0.1775,  0.3062, -0.2145, -0.6811,
          -1.3397, -0.4235,  0.4637, -1.3447, -0.5441, -0.9798,  0.8265,
          -0.2740,  0.9446, -2.4202,  1.1822,  1.8531,  2.0389, -0.4581,
          -0.7546,  2.1168,  2.1271,  0.3378,  1.4806, -1.2704,  0.0628,
          -1.2798,  0.0615, -0.0730, -0.8597]]])
 V长这样子
 torch.Size([1, 32, 64])
 V长这样子
 tensor([[[ 1.0952,  0.3832,  0.1141,  ..., -1.3869, -0.0160, -1.3580],
         [ 0.3251,  0.3406,  0.1589,  ..., -0.8902,  2.0466, -0.5664],
         [-0.6364, -1.0243,  0.1915,  ...,  0.6893, -0.8892,  0.2788],
         ...,
         [ 0.3980,  1.6673,  0.4893,  ..., -0.7628, -0.0612, -0.2004],
         [ 0.2605,  0.6287, -2.1606,  ..., -0.8923, -0.4310,  1.8570],
         [-0.2593, -1.3517, -0.4209,  ...,  0.7520,  0.6580, -0.9260]]])
**************************************************************************************
tensor([[-0.0843,  0.6108, -0.5214,  0.4358,  1.6302, -0.6159,  1.6340,  0.1276,
          1.5854, -0.2922, -0.9621,  0.1989, -0.0558,  1.9234, -0.5138, -0.7876,
          1.9724, -0.0659,  0.5300,  1.1414,  1.1585,  0.9155,  0.0557, -0.7387,
         -0.5724, -0.0478,  0.4301,  1.2947, -0.4314, -0.0663,  0.3610,  0.6614]])
tensor([[-1.6176, -0.2296,  0.0839,  0.1775,  0.3062, -0.2145, -0.6811, -1.3397,
         -0.4235,  0.4637, -1.3447, -0.5441, -0.9798,  0.8265, -0.2740,  0.9446,
         -2.4202,  1.1822,  1.8531,  2.0389, -0.4581, -0.7546,  2.1168,  2.1271,
          0.3378,  1.4806, -1.2704,  0.0628, -1.2798,  0.0615, -0.0730, -0.8597]])
 Q、K拼接后长这样子
 tensor([[-0.0843,  0.6108, -0.5214,  0.4358,  1.6302, -0.6159,  1.6340,  0.1276,
          1.5854, -0.2922, -0.9621,  0.1989, -0.0558,  1.9234, -0.5138, -0.7876,
          1.9724, -0.0659,  0.5300,  1.1414,  1.1585,  0.9155,  0.0557, -0.7387,
         -0.5724, -0.0478,  0.4301,  1.2947, -0.4314, -0.0663,  0.3610,  0.6614],
        [-1.6176, -0.2296,  0.0839,  0.1775,  0.3062, -0.2145, -0.6811, -1.3397,
         -0.4235,  0.4637, -1.3447, -0.5441, -0.9798,  0.8265, -0.2740,  0.9446,
         -2.4202,  1.1822,  1.8531,  2.0389, -0.4581, -0.7546,  2.1168,  2.1271,
          0.3378,  1.4806, -1.2704,  0.0628, -1.2798,  0.0615, -0.0730, -0.8597]])
**************************************************************************************
 Q、K进行softmax后注意力权重长这样子
 tensor([[0.0566, 0.0220, 0.0399, 0.0829, 0.0233, 0.0175, 0.0365, 0.0313, 0.0234,
         0.0187, 0.0387, 0.0832, 0.0358, 0.0130, 0.0148, 0.0245, 0.0242, 0.0210,
         0.0087, 0.0236, 0.0509, 0.0279, 0.0473, 0.0297, 0.0309, 0.0546, 0.0165,
         0.0305, 0.0257, 0.0106, 0.0164, 0.0193]], grad_fn=<SoftmaxBackward0>)
 最终输出结果
 torch.Size([1, 1, 64])
tensor([[[-1.5883e-01, -9.7250e-02, -1.4577e-01,  6.2508e-02, -3.0917e-01,
          -4.8471e-01,  1.2058e-01, -3.9673e-01,  4.7531e-01,  2.4023e-01,
          -4.5470e-01,  9.8248e-02, -1.7717e-01,  3.3285e-01,  5.4367e-01,
          -1.0387e-01,  1.0913e-01,  1.9735e-01,  3.9441e-01, -4.1193e-01,
           5.5962e-02, -3.7915e-01, -1.1829e-01, -1.2722e-01,  1.2517e-01,
           4.2707e-01,  1.6100e-01, -3.4799e-02, -1.5643e-01, -2.1065e-02,
          -1.9389e-02,  8.9914e-05, -2.5389e-01, -1.1194e-01, -2.6804e-01,
           6.9662e-01, -3.6186e-01,  6.3613e-01,  1.2927e-01, -1.0210e+00,
          -9.3159e-01, -4.4763e-01, -3.8813e-01, -2.8905e-01,  5.0221e-01,
          -2.9630e-01,  1.9712e-01,  3.4796e-01,  2.0145e-01,  2.1066e-01,
           4.6304e-01,  3.5566e-01,  3.7207e-01,  2.1636e-01,  9.2869e-02,
          -3.1811e-01, -4.5739e-01, -4.8703e-01, -4.9259e-02,  3.0813e-01,
          -4.4769e-01,  2.3227e-01,  9.7959e-02, -3.2980e-02]]],
       grad_fn=<UnsqueezeBackward0>)
 注意力权重
 torch.Size([1, 32])
tensor([[0.0566, 0.0220, 0.0399, 0.0829, 0.0233, 0.0175, 0.0365, 0.0313, 0.0234,
         0.0187, 0.0387, 0.0832, 0.0358, 0.0130, 0.0148, 0.0245, 0.0242, 0.0210,
         0.0087, 0.0236, 0.0509, 0.0279, 0.0473, 0.0297, 0.0309, 0.0546, 0.0165,
         0.0305, 0.0257, 0.0106, 0.0164, 0.0193]], grad_fn=<SoftmaxBackward0>)

Process finished with exit code 0

 第二种注意力机制

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attn(nn.Module):
    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        """初始化函数中的参数有5个
        query_size代表query的最后一维大小
        key_size代表key的最后一维大小, value_size1代表value的导数第二维大小

        value = (1, value_size1, value_size2)

        value_size2代表value的倒数第一维大小, output_size输出的最后一维大小
        """
        super(Attn, self).__init__()
        # 将以下参数传入类中
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        self.output_size = output_size

        # 初始化注意力机制实现第一步中需要的线性层
        self.attn = nn.Linear(self.query_size + self.key_size, value_size1)

        # 初始化注意力机制实现第三步中需要的线性层
        self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)

    def forward(self, Q, K, V):
        """forward函数的输入参数有三个
        分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的
        张量一般情况都是三维张量, 因此这里也假设Q, K, V都是三维张量
        """

        # 按照公式计算注意力权重
        # 将Q,K进行纵轴拼接
        combined = torch.cat((Q[0], K[0]), 1)
        # 做一次线性变化
        linear_result = self.attn(combined)
        # 使用tanh函数激活
        tanh_result = torch.tanh(linear_result)
        # 进行内部求和
        sum_result = torch.sum(tanh_result, dim=1, keepdim=True)
        # 使用softmax处理获得结果
        attn_weights = F.softmax(sum_result, dim=1)

        print(' Q、K进行softmax后注意力权重长这样子\n', attn_weights)

        # 将得到的权重矩阵与V做张量乘法
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)

        # 之后进行第二步, 通过取[0]是用来降维, 根据第一步采用的计算方法,
        # 需要将Q与第一步的计算结果再进行拼接
        output = torch.cat((Q[0], attn_applied[0]), 1)

        # 最后是第三步, 使用线性层作用在第三步的结果上做一个线性变换并扩展维度,得到输出
        # 因为要保证输出也是3维张量, 因此使用unsqueeze(0)扩展维度
        output = self.attn_combine(output).unsqueeze(0)

        return output, attn_weights


if __name__ == '__main__':
    query_size = 32
    key_size = 32
    value_size1 = 32
    value_size2 = 64
    output_size = 64

    attn = Attn(query_size, key_size, value_size1, value_size2, output_size)

    # 批次,行,列
    Q = torch.randn(1, 1, 32)
    print(' Q长这样子\n', Q)
    K = torch.randn(1, 1, 32)
    print(' K长这样子\n', K)
    V = torch.randn(1, 32, 64)
    print(' V长这样子\n', V.shape)
    print(' V长这样子\n', V)
    print('**************************************************************************************')

    print(Q[0])
    print(K[0])
    print(' Q、K拼接后长这样子\n', torch.cat((Q[0], K[0])))

    print('**************************************************************************************')

    out, attn_weights = attn(Q, K, V)
    print(' 最终输出结果\n', out.shape)
    print(out)
    print(' 注意力权重\n', attn_weights.shape)
    print(attn_weights)

 第三种注意力机制(缩放点积注意力)

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attn(nn.Module):
    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        """初始化函数中的参数有5个
        query_size代表query的最后一维大小
        key_size代表key的最后一维大小, value_size1代表value的导数第二维大小

        value = (1, value_size1, value_size2)

        value_size2代表value的倒数第一维大小, output_size输出的最后一维大小
        """
        super(Attn, self).__init__()
        # 将以下参数传入类中
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        self.output_size = output_size

        # 这里原代码初始化的线性层在新公式中不再使用,可删除相关初始化,不过保留也不影响后续计算逻辑正确性
        # self.attn = nn.Linear(self.query_size + self.key_size, value_size1)
        # self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)

    def forward(self, Q, K, V):
        """forward函数的输入参数有三个
        分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的
        张量一般情况都是三维张量, 因此这里也假设Q, K, V都是三维张量
        """
        # 计算缩放点积注意力权重
        # 转置K,将其形状从(batch_size, seq_length, key_size)变为(batch_size, key_size, seq_length)
        K_transposed = K.transpose(1, 2)
        # 计算Q与K的转置的点积
        dot_product = torch.bmm(Q, K_transposed)
        # 除以缩放系数,这里缩放系数为键向量维度的平方根
        scaling_factor = torch.sqrt(torch.tensor(self.key_size, dtype=torch.float))
        scaled_dot_product = dot_product / scaling_factor
        # 使用softmax处理获得注意力权重
        attn_weights = F.softmax(scaled_dot_product, dim=2)

        print(' Q、K进行softmax后注意力权重长这样子\n', attn_weights)

        # 将得到的权重矩阵与V做张量乘法
        attn_applied = torch.bmm(attn_weights, V)

        return attn_applied, attn_weights

if __name__ == '__main__':
    query_size = 32
    key_size = 32
    value_size1 = 32
    value_size2 = 64
    output_size = 64

    attn = Attn(query_size, key_size, value_size1, value_size2, output_size)

    # 批次,行,列
    Q = torch.randn(1, 1, 32)
    print(' Q长这样子\n', Q)
    K = torch.randn(1, 1, 32)
    print(' K长这样子\n', K)
    V = torch.randn(1, 32, 64)
    print(' V长这样子\n', V.shape)
    print(' V长这样子\n', V)
    print('**************************************************************************************')

    print(Q[0])
    print(K[0])
    print(' Q、K拼接后长这样子\n', torch.cat((Q[0], K[0])))

    print('**************************************************************************************')

    out, attn_weights = attn(Q, K, V)
    print(' 最终输出结果\n', out.shape)
    print(out)
    print(' 注意力权重\n', attn_weights.shape)
    print(attn_weights)


网站公告

今日签到

点亮在社区的每一天
去签到