Import the required libraries
import torch
import torch.nn as nn
from typing import Callable, Optional, Unpack  # Unpack needs Python 3.11+ (or typing_extensions)
from dataclasses import dataclass
from transformers.utils import TransformersKwargs
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS  # dispatch table for non-eager attention backends
from transformers.cache_utils import Cache  # type hint for past_key_values
Rotary position embedding helper functions
def rotate_half(x):
    """Rotate half of the hidden dims of the input (GPT-NeoX style pairing)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary embeddings to query and key states; unsqueeze_dim broadcasts cos/sin over the head dimension."""
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
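A small sanity check, assuming only the two helpers above: since each (cos, sin) pair implements a 2-D rotation, apply_rotary_pos_emb should leave the norm of every head vector unchanged. All shapes and the 10000.0 base below are illustrative values, not taken from any real config.

# Toy check: RoPE is a pure rotation, so per-vector norms are preserved.
_q = torch.randn(1, 2, 4, 8)          # (batch, heads, seq, head_dim)
_k = torch.randn(1, 2, 4, 8)
_pos = torch.arange(4, dtype=torch.float32)
_inv = 1.0 / (10000.0 ** (torch.arange(0, 8, 2, dtype=torch.float32) / 8))
_freqs = torch.outer(_pos, _inv)                  # (seq, head_dim/2)
_emb = torch.cat((_freqs, _freqs), dim=-1)        # (seq, head_dim)
_cos, _sin = _emb.cos()[None], _emb.sin()[None]   # add a batch dim -> (1, seq, head_dim)
_q_rot, _k_rot = apply_rotary_pos_emb(_q, _k, _cos, _sin)
print(torch.allclose(_q.norm(dim=-1), _q_rot.norm(dim=-1), atol=1e-5))  # True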
Key/value repetition function
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head n_rep times so grouped-query KV states match the number of query heads.
    Equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1)."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
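A hypothetical shape example for repeat_kv, matching the grouped-query setup used later (here 2 KV heads serving 8 query heads, so n_rep = 4); the tensor values themselves are random and only the shapes matter.

# Illustrative shapes only: 2 KV heads expanded to match 8 query heads.
_kv = torch.randn(1, 2, 5, 16)             # (batch, num_kv_heads, seq, head_dim)
_expanded = repeat_kv(_kv, n_rep=4)
print(_expanded.shape)                                 # torch.Size([1, 8, 5, 16])
print(torch.equal(_expanded[:, 0], _expanded[:, 3]))   # True: heads 0..3 are copies of KV head 0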
Eager attention forward function
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    # Expand grouped KV heads so they line up with the query heads.
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    # Scaled dot-product scores: (batch, num_heads, q_len, kv_len).
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # The additive mask may cover more positions than the current KV length, so slice it.
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        print("causal_mask:", causal_mask.shape)
        attn_weights = attn_weights + causal_mask
    # Softmax in float32 for numerical stability, then cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    print("attn_weights:", attn_weights.shape)
    attn_output = torch.matmul(attn_weights, value_states)
    # Back to (batch, q_len, num_heads, head_dim) for the output projection.
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
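A minimal standalone call, assuming only the function above; types.SimpleNamespace stands in for the attention module, and every value is made up. It confirms the output layout (batch, q_len, num_heads, head_dim) and that each softmax row sums to 1.

from types import SimpleNamespace

# Stand-in "module" providing the two attributes eager_attention_forward reads.
_mod = SimpleNamespace(num_key_value_groups=2, training=False)
_q = torch.randn(1, 4, 6, 16)              # (batch, num_heads, seq, head_dim)
_k = torch.randn(1, 2, 6, 16)              # (batch, num_kv_heads, seq, head_dim)
_v = torch.randn(1, 2, 6, 16)
_out, _w = eager_attention_forward(_mod, _q, _k, _v, attention_mask=None, scaling=16 ** -0.5)
print(_out.shape, _w.shape)                # torch.Size([1, 6, 4, 16]) torch.Size([1, 4, 6, 6])
print(torch.allclose(_w.sum(dim=-1), torch.ones(1, 4, 6)))  # True: each row is a probability distribution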
RoPE position embedding implementation
def default_rope_init(config, device=None):
    """Default RoPE initialization: inverse frequencies for the even dimensions of one head."""
    dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
    inv_freq = 1.0 / (
        config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    )
    print("inv_freq:", inv_freq.shape)
    return inv_freq.to(device), 1.0  # (inv_freq, attention_scaling)

ROPE_INIT_FUNCTIONS = {
    "default": default_rope_init,
}
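The registry makes it easy to sketch scaled RoPE variants. Below is a hedged sketch (not part of the Qwen3 config above) of a "linear" interpolation entry; the rope_scaling "factor" field is assumed purely for illustration. Scaling inv_freq down by the factor is equivalent to dividing position ids by it, since the angles are position * inv_freq.

def linear_rope_init(config, device=None):
    """Sketch: linear position interpolation, reusing the default init and shrinking inv_freq."""
    inv_freq, attention_scaling = default_rope_init(config, device)
    factor = config.rope_scaling.get("factor", 1.0)  # assumed field, e.g. {"rope_type": "linear", "factor": 4.0}
    return inv_freq / factor, attention_scaling

ROPE_INIT_FUNCTIONS["linear"] = linear_rope_init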
class Qwen3MoeRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor

    def __init__(self, config, device=None):
        super().__init__()
        # Pick the RoPE variant from config.rope_scaling if present, otherwise use the default init.
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    def forward(self, x, position_ids):
        # (batch, dim/2, 1) @ (batch, 1, seq) -> angles of shape (batch, seq, dim/2).
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()
        # Compute the angles in float32 with autocast disabled (mps falls back to the "cpu" autocast context).
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
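An optional check of the property RoPE is built for, assuming only the pieces above: the score between a rotated query at position m and a rotated key at position n depends only on m - n. The SimpleNamespace config below is a stand-in with arbitrary values, not the MockConfig defined later.

from types import SimpleNamespace

_cfg = SimpleNamespace(head_dim=8, rope_theta=10000.0, max_position_embeddings=32)
_rope = Qwen3MoeRotaryEmbedding(_cfg)
_x = torch.zeros(1, 6, 8)                              # only dtype/device are taken from x
_cos, _sin = _rope(_x, torch.arange(6).unsqueeze(0))   # each (1, 6, 8)
_q = torch.randn(1, 1, 1, 8).expand(1, 1, 6, 8)        # the same vector placed at every position
_k = _q.clone()
_q_rot, _k_rot = apply_rotary_pos_emb(_q, _k, _cos, _sin)
# score(m=3, n=1) should match score(m=5, n=3): only the offset m - n matters.
s1 = (_q_rot[0, 0, 3] * _k_rot[0, 0, 1]).sum()
s2 = (_q_rot[0, 0, 5] * _k_rot[0, 0, 3]).sum()
print(torch.allclose(s1, s2, atol=1e-5))               # True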
Qwen3Moe attention implementation
class Qwen3MoeAttention(nn.Module):
    """Multi-headed attention from the 'Attention Is All You Need' paper, with QK-norm and grouped-query KV heads."""

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        # Qwen3 applies RMSNorm to the per-head query and key states (Qwen3MoeRMSNorm is defined further below).
        self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.sliding_window = getattr(config, "sliding_window", None)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        # Project, reshape to (batch, num_heads, seq, head_dim), and normalize queries/keys per head.
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        print("query_states:", query_states.shape)
        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # Dispatch to the configured attention backend; this walkthrough uses the eager path.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
Mock config and RMSNorm implementation
@dataclass
class MockConfig:
    hidden_size: int = 512
    num_attention_heads: int = 8
    num_key_value_heads: int = 4
    head_dim: int = 64
    max_position_embeddings: int = 2048
    rope_theta: float = 10000.0
    rms_norm_eps: float = 1e-6
    attention_bias: bool = False
    attention_dropout: float = 0.0
    _attn_implementation: str = "eager"

class Qwen3MoeRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Compute the RMS statistic in float32, rescale, then cast back to the input dtype.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
Main: test code
if __name__ == "__main__":
    config = MockConfig()
    attention_layer = Qwen3MoeAttention(config, layer_idx=0)
    rotary_emb = Qwen3MoeRotaryEmbedding(config)

    batch_size = 2
    seq_length = 8
    hidden_size = config.hidden_size
    hidden_states = torch.randn(batch_size, seq_length, hidden_size)
    position_ids = torch.arange(seq_length).unsqueeze(0).expand(batch_size, -1)

    cos, sin = rotary_emb(hidden_states, position_ids)
    print("Position embeddings:")
    print(f" - cos shape: {cos.shape}")
    print(f" - sin shape: {sin.shape}")

    # Additive causal mask: 0 where attention is allowed, a large negative number elsewhere.
    attention_mask = torch.tril(torch.ones(batch_size, 1, seq_length, seq_length))
    attention_mask = (1.0 - attention_mask) * torch.finfo(torch.float32).min

    attention_output, attention_weights = attention_layer(
        hidden_states=hidden_states,
        position_embeddings=(cos, sin),
        attention_mask=attention_mask,
    )
    print("\nAttention results:")
    print(f" - Input shape: {hidden_states.shape}")
    print(f" - Output shape: {attention_output.shape}")
    print(f" - Attention weights shape: {attention_weights.shape}")

Running this produces the following output:
inv_freq: torch.Size([32])
Position embeddings:
- cos shape: torch.Size([2, 8, 64])
- sin shape: torch.Size([2, 8, 64])
query_states: torch.Size([2, 8, 8, 64])
causal_mask: torch.Size([2, 1, 8, 8])
attn_weights: torch.Size([2, 8, 8, 8])
Attention results:
- Input shape: torch.Size([2, 8, 512])
- Output shape: torch.Size([2, 8, 512])
- Attention weights shape: torch.Size([2, 8, 8, 8])
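As an optional follow-up (not reflected in the output above), here is a sketch of incremental decoding with a KV cache. It assumes transformers' DynamicCache, whose update(key, value, layer_idx, cache_kwargs) call is exactly what the forward pass above invokes; everything else reuses the classes defined earlier.

from transformers.cache_utils import DynamicCache

config = MockConfig()
attention_layer = Qwen3MoeAttention(config, layer_idx=0)
rotary_emb = Qwen3MoeRotaryEmbedding(config)
cache = DynamicCache()

# Prefill with 8 tokens; the layer writes their keys/values into the cache.
prompt = torch.randn(2, 8, config.hidden_size)
pos = torch.arange(8).unsqueeze(0).expand(2, -1)
cos, sin = rotary_emb(prompt, pos)
mask = (1.0 - torch.tril(torch.ones(2, 1, 8, 8))) * torch.finfo(torch.float32).min
attention_layer(prompt, (cos, sin), mask, past_key_values=cache, cache_position=torch.arange(8))

# Decode a single new token at position 8, reusing the 8 cached positions.
new_token = torch.randn(2, 1, config.hidden_size)
new_cos, new_sin = rotary_emb(new_token, torch.tensor([[8], [8]]))
step_out, _ = attention_layer(
    new_token, (new_cos, new_sin), None,   # no mask needed: one query may attend to all cached positions
    past_key_values=cache, cache_position=torch.tensor([8]),
)
print(step_out.shape)              # torch.Size([2, 1, 512])
print(cache.get_seq_length())      # 9: 8 prefill positions + 1 decoded token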