Define the ScaledDotProductAttention struct (together with the required imports):
use candle_core::{Result, Device, Tensor};
use candle_nn::{Linear, Module, linear_no_bias, VarMap, VarBuilder, ops};

struct ScaledDotProductAttention {
    // Separate projection matrices for queries, keys and values (no bias terms).
    wq: Linear,
    wk: Linear,
    wv: Linear,
    // Scaling factor: the query/key projection dimension d_k, stored as a scalar tensor.
    d_k: Tensor,
    // Device the module was created on (currently unused in forward).
    device: Device,
}
Implement the new method for ScaledDotProductAttention:
impl ScaledDotProductAttention {
    fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, device: Device) -> Result<Self> {
        Ok(Self {
            // vb.pp("...") pushes a name prefix so each weight is registered
            // under its own path in the VarMap.
            wq: linear_no_bias(embedding_dim, out_dim, vb.pp("wq"))?,
            wk: linear_no_bias(embedding_dim, out_dim, vb.pp("wk"))?,
            wv: linear_no_bias(embedding_dim, out_dim, vb.pp("wv"))?,
            // Scale by the key dimension d_k (= out_dim), per the standard formula.
            d_k: Tensor::new(out_dim as f32, &device)?,
            device,
        })
    }
}
Implement the Module trait's forward method for the struct:
impl Module for ScaledDotProductAttention {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Project the input into queries, keys and values: (batch, seq, out_dim).
        let q = self.wq.forward(xs)?;
        let k = self.wk.forward(xs)?;
        let v = self.wv.forward(xs)?;
        // Attention scores Q·K^T, shape (batch, seq, seq), scaled by sqrt(d_k).
        let attn_score = q.matmul(&k.t()?)?;
        let attn_score = attn_score.broadcast_div(&self.d_k.sqrt()?)?;
        // Softmax over the last dimension turns the scores into attention weights.
        let dim = attn_score.rank() - 1;
        let attn_weights = ops::softmax(&attn_score, dim)?;
        // Weighted sum of the values: (batch, seq, out_dim).
        let attn_output = attn_weights.matmul(&v)?;
        Ok(attn_output)
    }
}
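For reference, the forward pass above computes the standard scaled dot-product attention, where d_k is the query/key projection dimension (out_dim here):

Attention(Q, K, V) = softmax(Q · K^T / sqrt(d_k)) · V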
Fused QKV implementation:
Define the ScaledDotProductAttentionFusedQKV struct:
struct ScaledDotProductAttentionFusedQKV {
    // A single projection that produces Q, K and V in one matmul.
    w_qkv: Linear,
    d_k: Tensor,
    device: Device,
}
Implement the new method for the struct:
impl ScaledDotProductAttentionFusedQKV {
    fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, device: Device) -> Result<Self> {
        Ok(Self {
            // One weight of shape (3 * out_dim, embedding_dim) instead of three separate ones.
            w_qkv: linear_no_bias(embedding_dim, 3 * out_dim, vb.pp("w_qkv"))?,
            d_k: Tensor::new(out_dim as f32, &device)?,
            device,
        })
    }
}
Implement the forward method for the struct:
impl Module for ScaledDotProductAttentionFusedQKV {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // One projection for all three: (batch, seq, 3 * out_dim).
        let qkv = self.w_qkv.forward(xs)?;
        let (batch_size, seq_len, _) = qkv.dims3()?;
        // Split the last dimension into 3 chunks of out_dim: (batch, seq, 3, out_dim).
        let qkv = qkv.reshape((batch_size, seq_len, 3, ()))?;
        // Pick Q, K and V along dimension 2 and flatten back to (batch, seq, out_dim).
        let q = qkv.get_on_dim(2, 0)?;
        let q = q.reshape((batch_size, seq_len, ()))?;
        let k = qkv.get_on_dim(2, 1)?;
        let k = k.reshape((batch_size, seq_len, ()))?;
        let v = qkv.get_on_dim(2, 2)?;
        let v = v.reshape((batch_size, seq_len, ()))?;
        // Same attention computation as in the non-fused version.
        let attn_score = q.matmul(&k.t()?)?;
        let attn_score = attn_score.broadcast_div(&self.d_k.sqrt()?)?;
        let dim = attn_score.rank() - 1;
        let attn_weights = ops::softmax(&attn_score, dim)?;
        let attn_output = attn_weights.matmul(&v)?;
        Ok(attn_output)
    }
}
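As an aside, the fused projection can also be split without reshape and get_on_dim, by narrowing the last dimension into three out_dim-wide chunks. The sketch below is illustrative only: split_qkv is a hypothetical helper, and out_dim is assumed to be passed in by the caller rather than stored on the struct.

// Hypothetical helper: slice a fused (batch, seq, 3 * out_dim) projection into Q, K, V.
fn split_qkv(qkv: &Tensor, out_dim: usize) -> Result<(Tensor, Tensor, Tensor)> {
    // narrow returns strided views; contiguous() copies them into dense tensors
    // so they can be fed straight into matmul.
    let q = qkv.narrow(candle_core::D::Minus1, 0, out_dim)?.contiguous()?;
    let k = qkv.narrow(candle_core::D::Minus1, out_dim, out_dim)?.contiguous()?;
    let v = qkv.narrow(candle_core::D::Minus1, 2 * out_dim, out_dim)?.contiguous()?;
    Ok((q, k, v))
}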
Test:
fn main() -> Result<()> {
    // Use the GPU if one is available, otherwise fall back to the CPU.
    let device = Device::cuda_if_available(0)?;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
    // Toy input: a batch of 2 identical sequences, 6 tokens each, embedding_dim = 3.
    let input = Tensor::from_vec(
        vec![
            0.43f32, 0.15, 0.89,
            0.55, 0.87, 0.66,
            0.57, 0.85, 0.64,
            0.22, 0.58, 0.33,
            0.77, 0.25, 0.10,
            0.05, 0.80, 0.55,
            0.43, 0.15, 0.89,
            0.55, 0.87, 0.66,
            0.57, 0.85, 0.64,
            0.22, 0.58, 0.33,
            0.77, 0.25, 0.10,
            0.05, 0.80, 0.55,
        ],
        (2, 6, 3),
        &device,
    )?;
    // let model = ScaledDotProductAttention::new(vb.clone(), 3, 2, device.clone())?;
    let model = ScaledDotProductAttentionFusedQKV::new(vb.clone(), 3, 2, device.clone())?;
    let output = model.forward(&input)?;
    println!("output: {:?}\n", output);
    println!("output: {:?}\n", output.to_vec3::<f32>()?);
    Ok(())
}
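With embedding_dim = 3 and out_dim = 2, both variants produce an output of shape (2, 6, 2): one 2-dimensional context vector per token in each of the two sequences. Since the VarMap initializes the projection weights randomly, the printed values differ from run to run.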