核心参数与用法
nn.Embedding的核心参数:
num_embeddings:嵌入表的大小(即离散特征的总类别数,如词汇表大小)。
embedding_dim:每个嵌入向量的维度(输出向量的长度)。
padding_idx(可选):指定一个索引,其对应的嵌入向量将始终为 0(用于处理填充符号)。
import torch
import torch.nn as nn
# 定义嵌入层:词汇表大小为10(索引0-9),嵌入维度为3
embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)
# 输入:形状为(batch_size, seq_len)的整数张量(索引必须在[0, num_embeddings-1]范围内)
input_indices = torch.tensor([[1, 3, 5], [2, 4, 6]]) # 批量大小为2,序列长度为3
# 前向传播:获取嵌入向量
output_embeddings = embedding(input_indices)
print("输入形状:", input_indices.shape) # 输出:torch.Size([2, 3])
print("输出形状:", output_embeddings.shape) # 输出:torch.Size([2, 3, 3])(每个索引被映射为3维向量)
print("输出内容:\n", output_embeddings)
输入形状: torch.Size([2, 3])
输出形状: torch.Size([2, 3, 3])
输出内容:
tensor([[[ 0.5095, 0.3979, -1.7759],
[-0.1456, 1.6262, 0.3929],
[ 0.8530, -0.6685, 1.6823]],
[[ 1.0323, -0.0969, -0.6512],
[ 0.2309, -1.5649, 0.7431],
[-0.3285, -0.2512, -0.1028]]], grad_fn=<EmbeddingBackward0>)
Parameter containing:
tensor([[-1.8749, 0.2108, 0.4401],
[ 0.5095, 0.3979, -1.7759],
[ 1.0323, -0.0969, -0.6512],
[-0.1456, 1.6262, 0.3929],
[ 0.2309, -1.5649, 0.7431],
[ 0.8530, -0.6685, 1.6823],
[-0.3285, -0.2512, -0.1028],
[-0.1919, 0.2022, -0.2425],
[-0.7266, 1.3337, -0.7980],
[ 0.0791, -0.7093, 0.2264]], requires_grad=True)