【模块化编程】数据标签 转 独热编码

发布于:2025-03-15 ⋅ 阅读:(18) ⋅ 点赞:(0)

🌈 个人主页:十二月的猫-CSDN博客
🔥 系列专栏: 🏀《Python/PyTorch极简课》_十二月的猫的博客-CSDN博客

💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光 

目录

 1. 前言

2. 标签转独热编码函数

2.1 完整函数

2.2 函数功能解释

3. 实战示例

4. 总结


 1. 前言

  • 👑《Python/PyTorch极简课》专栏持续更新中,未来最少文章数量为100篇。由于专栏刚刚建立,目前免费,后续将慢慢恢复原价至99.9🍉。
  • 👑《Python/PyTorch极简课》专栏主要针对零基础入门的小伙伴。不需要Python基础,不需要深度学习基础,只要你愿意学,这一个专栏将真正让你做到零基础入门。
  • 🔥每例项目都包括理论讲解、数据集、源代码

正在更新中💹💹

🚨项目运行环境:

  • 平台:Window11
  • 语言环境:Python3.8
  • 运行环境1:PyCharm 2021.3
  • 运行环境2:Jupyter Notebook 7.3.2
  • 框架:PyTorch 2.5.1(CUDA11.8)

2. 标签转独热编码函数

2.1 完整函数

def label2onehot(logits, labels):
    """
    将标签转换为 one-hot 编码形式。

    参数:
    logits (torch.Tensor 或 np.ndarray): 模型的输出 logits,形状通常为 (batch_size, num_classes)。
    labels (list 或 torch.Tensor): 对应的标签,形状为 (batch_size,)。

    返回:
    np.ndarray: 转换后的 one-hot 编码,形状与 logits 相同。
    """
    # 创建一个与 logits 形状相同的全零张量
    label_onehot = torch.zeros_like(torch.tensor(logits))
    
    # 使用 scatter_ 函数将 labels 转换为 one-hot 编码
    label_onehot.scatter_(1, torch.tensor(labels).long().view(-1, 1), 1)
    
    # 将 one-hot 编码的张量转换为 numpy 数组并返回
    return label_onehot.numpy()
  •  将 logits 转换为 PyTorch 张量,并创建一个与其形状相同的全零张量。
  • 将 labels 转换为适合 scatter_ 函数使用的格式(整数类型,二维形状)。
  • 使用 scatter_ 函数将 labels 转换为 one-hot 编码。
  • 将 one-hot 编码的张量转换为 numpy 数组并返回。  

2.2 函数功能解释

1. label_onehot = torch.zeros_like(torch.tensor(logits)):

  1. 作用: 创建一个与 logits 形状相同的全零张量,用于存储最终的 one-hot 编码。
  2. torch.tensor(logits): 将输入的 logits 转换为 PyTorch 张量。logits 通常是一个二维数组,形状为 (batch_size, num_classes),表示模型的输出(每个样本对每个类别的得分)。
  3. torch.zeros_like(...): 创建一个与输入张量形状相同的全零张量。
  4. 示例:
logits = [[0.1, 0.2, 0.7], [0.9, 0.05, 0.05]]  # 形状: (2, 3)
label_onehot = torch.zeros_like(torch.tensor(logits))
# 输出: tensor([[0., 0., 0.],
#               [0., 0., 0.]])

 模型输出的结果都是二维的:每一行是一个记录;每一列是对一个label的可能性评估

2. torch.tensor(labels).long().view(-1, 1):

  1.  作用: 将 labels 转换为适合 scatter_ 函数使用的格式。
  2. torch.tensor(labels): 将输入的 labels 转换为 PyTorch 张量。labels 通常是一个一维数组,形状为 (batch_size,),表示每个样本的真实类别索引。
  3. .long(): 将张量的数据类型转换为 torch.long(64 位整数类型),因为 scatter_ 函数要求索引必须是整数类型。
  4. .view(-1, 1): 将 labels 从一维 (batch_size,) 调整为二维 (batch_size, 1),以便与 scatter_ 函数兼容。
  5. 示例: 
labels = [2, 0]  # 形状: (2,)
index_tensor = torch.tensor(labels).long().view(-1, 1)
# 输出: tensor([[2],
#               [0]])

3. label_onehot.scatter_(1, index_tensor, 1):

  1. 作用: 使用 scatter_ 函数将 labels 转换为 one-hot 编码
  2. scatter_(dim, index, value): 在指定维度 dim 上,根据 index 将 value 填充到目标张量中。
    1. dim=1: 表示在第二个维度(列)上进行填充。
    2. index: 指定填充的位置(即 labels 转换后的张量)。
    3. value=1: 在指定位置填充 1。
  3. 示例:
label_onehot = torch.zeros(2, 3)  # 形状: (2, 3)
index_tensor = torch.tensor([[2], [0]])  # 形状: (2, 1)
label_onehot.scatter_(1, index_tensor, 1)
# 输出: tensor([[0., 0., 1.],
#               [1., 0., 0.]])

看起来二维的Tensor在列上是隔开的,但正如线代的矩阵,Tensor在列和行上都是相连的。

4. return label_onehot.numpy():

  1. 作用: 将 one-hot 编码的 PyTorch 张量转换为 numpy 数组并返回
  2. numpy(): 将 PyTorch 张量转换为 numpy 数组,便于与其他库(如 NumPy、SciPy)兼容
  3. 示例:
label_onehot = torch.tensor([[0., 0., 1.], [1., 0., 0.]])
onehot_numpy = label_onehot.numpy()
# 输出: array([[0., 0., 1.],
#              [1., 0., 0.]], dtype=float32)

3. 实战示例

import torch

# 定义函数
def label2onehot(logits, labels):
    label_onehot = torch.zeros_like(torch.tensor(logits))
    label_onehot.scatter_(1, torch.tensor(labels).long().view(-1, 1), 1)
    return label_onehot.numpy()

# 示例数据
logits = [[0.1, 0.2, 0.7], [0.9, 0.05, 0.05]]  # 形状: (2, 3)
labels = [2, 0]  # 形状: (2,)

# 调用函数
onehot_labels = label2onehot(logits, labels)
print("One-hot 编码结果:")
print(onehot_labels)

运行结果:

总结:

  1. 该函数将类别标签(labels)转换为 one-hot 编码形式。
  2. 使用 scatter_ 函数高效地填充 one-hot 编码。
  3. 最终返回的是 numpy 数组,便于与其他库兼容。 

4. 总结

【如果想学习更多深度学习文章,可以订阅一下热门专栏】

如果想要学习更多pyTorch/python编程的知识,大家可以点个关注并订阅,持续学习、天天进步你的点赞就是我更新的动力,如果觉得对你有帮助,辛苦友友点个赞,收个藏呀~~~


网站公告

今日签到

点亮在社区的每一天
去签到