🌈 个人主页:十二月的猫-CSDN博客
🔥 系列专栏: 🏀《Python/PyTorch极简课》_十二月的猫的博客-CSDN博客💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光
目录
1. 前言
- 👑《Python/PyTorch极简课》专栏持续更新中,未来最少文章数量为100篇。由于专栏刚刚建立,目前免费,后续将慢慢恢复原价至99.9🍉。
- 👑《Python/PyTorch极简课》专栏主要针对零基础入门的小伙伴。不需要Python基础,不需要深度学习基础,只要你愿意学,这一个专栏将真正让你做到零基础入门。
- 🔥每例项目都包括理论讲解、数据集、源代码。
正在更新中💹💹
🚨项目运行环境:
- 平台:Window11
- 语言环境:Python3.8
- 运行环境1:PyCharm 2021.3
- 运行环境2:Jupyter Notebook 7.3.2
- 框架:PyTorch 2.5.1(CUDA11.8)
2. 标签转独热编码函数
2.1 完整函数
def label2onehot(logits, labels):
"""
将标签转换为 one-hot 编码形式。
参数:
logits (torch.Tensor 或 np.ndarray): 模型的输出 logits,形状通常为 (batch_size, num_classes)。
labels (list 或 torch.Tensor): 对应的标签,形状为 (batch_size,)。
返回:
np.ndarray: 转换后的 one-hot 编码,形状与 logits 相同。
"""
# 创建一个与 logits 形状相同的全零张量
label_onehot = torch.zeros_like(torch.tensor(logits))
# 使用 scatter_ 函数将 labels 转换为 one-hot 编码
label_onehot.scatter_(1, torch.tensor(labels).long().view(-1, 1), 1)
# 将 one-hot 编码的张量转换为 numpy 数组并返回
return label_onehot.numpy()
- 将 logits 转换为 PyTorch 张量,并创建一个与其形状相同的全零张量。
- 将 labels 转换为适合 scatter_ 函数使用的格式(整数类型,二维形状)。
- 使用 scatter_ 函数将 labels 转换为 one-hot 编码。
- 将 one-hot 编码的张量转换为 numpy 数组并返回。
2.2 函数功能解释
1. label_onehot = torch.zeros_like(torch.tensor(logits)):
- 作用: 创建一个与 logits 形状相同的全零张量,用于存储最终的 one-hot 编码。
- torch.tensor(logits): 将输入的 logits 转换为 PyTorch 张量。logits 通常是一个二维数组,形状为 (batch_size, num_classes),表示模型的输出(每个样本对每个类别的得分)。
- torch.zeros_like(...): 创建一个与输入张量形状相同的全零张量。
- 示例:
logits = [[0.1, 0.2, 0.7], [0.9, 0.05, 0.05]] # 形状: (2, 3)
label_onehot = torch.zeros_like(torch.tensor(logits))
# 输出: tensor([[0., 0., 0.],
# [0., 0., 0.]])
模型输出的结果都是二维的:每一行是一个记录;每一列是对一个label的可能性评估
2. torch.tensor(labels).long().view(-1, 1):
- 作用: 将 labels 转换为适合 scatter_ 函数使用的格式。
- torch.tensor(labels): 将输入的 labels 转换为 PyTorch 张量。labels 通常是一个一维数组,形状为 (batch_size,),表示每个样本的真实类别索引。
- .long(): 将张量的数据类型转换为 torch.long(64 位整数类型),因为 scatter_ 函数要求索引必须是整数类型。
- .view(-1, 1): 将 labels 从一维 (batch_size,) 调整为二维 (batch_size, 1),以便与 scatter_ 函数兼容。
- 示例:
labels = [2, 0] # 形状: (2,)
index_tensor = torch.tensor(labels).long().view(-1, 1)
# 输出: tensor([[2],
# [0]])
3. label_onehot.scatter_(1, index_tensor, 1):
- 作用: 使用 scatter_ 函数将 labels 转换为 one-hot 编码。
- scatter_(dim, index, value): 在指定维度 dim 上,根据 index 将 value 填充到目标张量中。
1. dim=1: 表示在第二个维度(列)上进行填充。
2. index: 指定填充的位置(即 labels 转换后的张量)。
3. value=1: 在指定位置填充 1。 - 示例:
label_onehot = torch.zeros(2, 3) # 形状: (2, 3)
index_tensor = torch.tensor([[2], [0]]) # 形状: (2, 1)
label_onehot.scatter_(1, index_tensor, 1)
# 输出: tensor([[0., 0., 1.],
# [1., 0., 0.]])
看起来二维的Tensor在列上是隔开的,但正如线代的矩阵,Tensor在列和行上都是相连的。
4. return label_onehot.numpy():
- 作用: 将 one-hot 编码的 PyTorch 张量转换为 numpy 数组并返回。
- numpy(): 将 PyTorch 张量转换为 numpy 数组,便于与其他库(如 NumPy、SciPy)兼容。
- 示例:
label_onehot = torch.tensor([[0., 0., 1.], [1., 0., 0.]])
onehot_numpy = label_onehot.numpy()
# 输出: array([[0., 0., 1.],
# [1., 0., 0.]], dtype=float32)
3. 实战示例
import torch
# 定义函数
def label2onehot(logits, labels):
label_onehot = torch.zeros_like(torch.tensor(logits))
label_onehot.scatter_(1, torch.tensor(labels).long().view(-1, 1), 1)
return label_onehot.numpy()
# 示例数据
logits = [[0.1, 0.2, 0.7], [0.9, 0.05, 0.05]] # 形状: (2, 3)
labels = [2, 0] # 形状: (2,)
# 调用函数
onehot_labels = label2onehot(logits, labels)
print("One-hot 编码结果:")
print(onehot_labels)
运行结果:
总结:
- 该函数将类别标签(labels)转换为 one-hot 编码形式。
- 使用 scatter_ 函数高效地填充 one-hot 编码。
- 最终返回的是 numpy 数组,便于与其他库兼容。
4. 总结
【如果想学习更多深度学习文章,可以订阅一下热门专栏】
如果想要学习更多pyTorch/python编程的知识,大家可以点个关注并订阅,持续学习、天天进步你的点赞就是我更新的动力,如果觉得对你有帮助,辛苦友友点个赞,收个藏呀~~~