rust-candle学习笔记13-实现多头注意力

发布于:2025-05-11 ⋅ 阅读:(18) ⋅ 点赞:(0)

参考:about-pytorch

定义结构体:

use core::f32;

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{embedding, linear_no_bias, linear, ops, Dropout, Linear, Module, VarBuilder, VarMap};

/// Multi-head self-attention layer assembled from candle primitives.
struct MultiHeadAttention {
    /// Fused, bias-free projection producing queries, keys and values
    /// (`embedding_dim -> 3 * out_dim`).
    w_qkv: Linear,
    /// Projection applied after the heads are merged back together
    /// (`out_dim -> out_dim`, with bias).
    out_proj: Linear,
    /// Dropout applied to the attention weights.
    dropout: Dropout,
    /// Lower-triangular `(seq_len, seq_len)` U32 matrix built with `Tensor::tril2` in `new`.
    mask: Tensor,
    /// `embedding_dim` kept as a scalar f32 tensor.
    d_model: Tensor,
    /// Number of attention heads.
    num_heads: usize,
    /// Width of a single head (`out_dim / num_heads`).
    head_dim: usize,
    /// Total attention width across all heads.
    out_dim: usize,
    /// Device the layer's own tensors were created on.
    device: Device,
}

定义初始化方法:

impl MultiHeadAttention {
    fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, seq_len: usize, num_heads: usize, drop_p: f32, device: Device) -> Result<Self> {
        if out_dim % num_heads != 0 {
            return Err(candle_core::Error::msg("out_dim must be divisible by num_heads"));
        }
        Ok(Self { w_qkv: linear_no_bias(embedding_dim, 3*out_dim, vb.pp("w_qkv"))?, 
                    dropout: Dropout::new(drop_p), 
                    d_model: Tensor::new(embedding_dim as f32, &device)?, 
                    mask: Tensor::tril2(seq_len, DType::U32, &device)?, 
                    out_proj: linear(out_dim, out_dim, vb.pp("out_proj"))?, 
                    device, 
                    out_dim, 
                    num_heads, 
                    head_dim: out_dim / num_heads, 
        })
    }
}

定义forward方法(该方法使用 `&self`,应放在 `impl MultiHeadAttention` 块内):

fn forward(&self, x: &Tensor, train: bool) -> Result<Tensor> {
        let qkv = self.w_qkv.forward(x)?;
        let (batch_size, seq_len, _) = qkv.dims3()?;
        let qkv = qkv.reshape((batch_size, seq_len, 3, self.num_heads, self.head_dim))?;
        let q = qkv.get_on_dim(2, 0)?;
        // Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        let q = q.transpose(1, 2)?.contiguous()?;
        let k = qkv.get_on_dim(2, 0)?;
        let k = k.transpose(1, 2)?.contiguous()?;
        let v = qkv.get_on_dim(2, 0)?;
        let v = v.transpose(1, 2)?.contiguous()?;
        let attn_scores = q.matmul(&k.transpose(2, 3)?)?;
        let mask = self.mask.broadcast_as(attn_scores.shape())?;
        let attn_scores = masked_fill(&attn_scores, &mask, f32::NEG_INFINITY)?;
        let attn_scores = attn_scores.broadcast_div(&self.d_model.sqrt()?)?;
        let softmax_dim = attn_scores.rank() - 1;
        // let attn_weights = ops::softmax_last_dim(&attn_scores)?;  //如果是cpu,可以用这个
        let attn_weights = ops::softmax(&attn_scores, softmax_dim)?;
        let attn_weights = self.dropout.forward(&attn_weights, train)?;
        let attn_output = attn_weights.matmul(&v)?;
        let attn_output = attn_output.transpose(1, 2)?;
        let attn_output = attn_output.reshape(&[batch_size, seq_len, self.num_heads*self.head_dim])?;
        let attn_output = self.out_proj.forward(&attn_output)?;
        Ok(attn_output)
    }

测试:

fn main() -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
    
    let input = Tensor::from_vec(vec![0.43f32, 0.15, 0.89, 
                                                    0.55, 0.87, 0.66,
                                                    0.57, 0.85, 0.64,
                                                    0.22, 0.58, 0.33,
                                                    0.77, 0.25, 0.10,
                                                    0.05, 0.80, 0.55, 
                                                    0.43, 0.15, 0.89, 
                                                    0.55, 0.87, 0.66,
                                                    0.57, 0.85, 0.64,
                                                    0.22, 0.58, 0.33,
                                                    0.77, 0.25, 0.10,
                                                    0.05, 0.80, 0.55], (2, 6, 3), &device)?;
    let model = MultiHeadAttention::new(vb.clone(), 3, 4, 6, 2, 0.1, device.clone())?;
    let output = model.forward(&input, true)?;
    println!("output: {:?}\n", output);
    println!("output: {:?}\n", output.to_vec3::<f32>()?);
    Ok(())
}


网站公告

今日签到

点亮在社区的每一天
去签到