PyTorch嵌入层(nn.Embedding)

发布于:2025-04-05 ⋅ 阅读:(21) ⋅ 点赞:(0)

在 PyTorch 中,nn.Embedding 层(即 model.user_embedding)除了 .weight 这个核心属性外,还有其他属性和方法。以下是完整的解析:


1. 主要属性

(1) weight(核心参数)
  • 作用:存储所有嵌入向量的可训练权重矩阵。
  • 形状(num_embeddings, embedding_dim)
  • 示例
    print(model.user_embedding.weight.shape)  # 输出:torch.Size([3, 4])
    
(2) num_embeddings
  • 作用:返回嵌入向量的总数(即用户/物品的数量)。
  • 示例
    print(model.user_embedding.num_embeddings)  # 输出:3
    
(3) embedding_dim
  • 作用:返回每个嵌入向量的维度。
  • 示例
    print(model.user_embedding.embedding_dim)  # 输出:4
    
(4) padding_idx(可选)
  • 作用:如果设置了 padding_idx,则对应的嵌入向量会被强制设为 0 且不参与训练。
  • 示例
    # 初始化时设置 padding_idx=0
    self.user_embedding = nn.Embedding(3, 4, padding_idx=0)
    print(model.user_embedding.padding_idx)  # 输出:0
    print(model.user_embedding.weight[0])    # 输出:tensor([0., 0., 0., 0.], grad_fn=<SelectBackward>)
    

2. 主要方法

(1) forward(input)
  • 作用:根据输入的 ID 返回对应的嵌入向量。
  • 示例
    input_ids = torch.tensor([0, 1, 2])  # 查询用户 0、1、2 的向量
    embeddings = model.user_embedding(input_ids)  # 返回 shape (3, 4)
    
(2) reset_parameters()
  • 作用:重新随机初始化权重(通常在训练前调用)。
  • 内部逻辑:默认使用均匀分布 U ( − k , k ) U(-\sqrt{k}, \sqrt{k}) U(k ,k ),其中 k = 1 embedding_dim k = \frac{1}{\text{embedding\_dim}} k=embedding_dim1
  • 示例
    model.user_embedding.reset_parameters()
    
(3) extra_repr()
  • 作用:返回层的额外信息(用于 print 时显示)。
  • 示例
    print(model.user_embedding.extra_repr())  
    # 输出:'num_embeddings=3, embedding_dim=4'
    

3. 其他底层属性(一般无需直接操作)

  • _parameters:存储所有可训练参数(包括 weight)。
  • _buffers:存储非可训练参数(如 BatchNorm 的 running_mean)。
  • training:布尔值,表示是否处于训练模式。

4. 完整属性/方法列表

可以通过 dir() 查看所有属性和方法:

print(dir(model.user_embedding))

输出示例:

['__class__', '__delattr__', '__dir__', ..., 'weight', 'num_embeddings', 'embedding_dim', 'padding_idx', 'forward', 'reset_parameters']

5. 关键总结

属性/方法 用途 示例值/调用方式
.weight 核心权重矩阵 shape=(3, 4)
.num_embeddings 嵌入向量的总数(用户数) 3
.embedding_dim 每个向量的维度 4
.padding_idx 指定填充索引(可选) None0
.forward(input) 查询嵌入向量 model.user_embedding([0, 1])
.reset_parameters() 重新初始化权重 model.user_embedding.reset_parameters()

6. 常见问题

Q:如何修改嵌入向量?
  • 直接操作 .weight
    # 将用户 0 的向量置零
    model.user_embedding.weight.data[0] = torch.zeros(4)
    
Q:如何冻结嵌入层?
  • 禁用梯度:
    model.user_embedding.weight.requires_grad = False
    
Q:padding_idx 和普通索引有什么区别?
  • padding_idx 对应的向量会固定为 0,且不参与梯度更新。

掌握这些属性和方法后,你可以更灵活地操作嵌入层! 🚀