PyTorch数据选取与索引详解:从入门到高效实践

发布于:2025-07-25 ⋅ 阅读:(19) ⋅ 点赞:(0)

PyTorch数据选取与索引详解:从入门到高效实践

在PyTorch中,从现有张量(Tensor)中高效、准确地选取所需的数据子集是日常操作的核心。无论是准备数据批次、提取模型输出的特定部分,还是根据条件过滤数据,都离不开强大的索引机制。本篇将系统介绍PyTorch中的基础索引、高级索引以及其他高效的数据选取函数。

一、基础索引与切片 (Basic Indexing and Slicing)

基础索引与切片是最直观、最常用的数据选取方式,其语法与Python的NumPy和列表非常相似。这种方式的主要特点是,它返回的是原始数据的 视图(View),意味着不产生新的内存开销,但修改视图会影响原始张量。

1. 单个元素索引

对于一个多维张量,可以通过提供每个维度的索引来精确定位到单个元素。

2. 数据切片

使用冒号 : 可以选取一个范围内的元素。语法为 start:stop:step,其中 start 是起始索引(包含),stop 是结束索引(不包含)。

关键点:

  • 返回视图:切片操作返回的是一个视图,与原始数据共享内存。这是一个非常重要的特性,因为它避免了不必要的数据复制,提升了效率。但反过来,如果修改了切片得到的数据,原始张量中对应的数据也会被修改。
  • 语法灵活:可以省略 start(表示从头开始)、stop(表示到末尾结束)或 step(默认为1)。单独一个 : 表示选取该维度的所有元素。
使用示例
import torch

# 创建一个 2x4 的张量
x = torch.arange(8).reshape(2, 4)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])
print("Original Tensor:\n", x)

# 1. 单个元素索引
# 获取第1行、第2列的元素 (从0开始计数)
element = x[1, 2]
print(f"\nElement at (1, 2): {element}") # tensor(6)

# 2. 数据切片
# 获取第0行所有元素
row_0 = x[0, :]
print("\nRow 0:\n", row_0) # tensor([0, 1, 2, 3])

# 获取第1列所有元素
col_1 = x[:, 1]
print("\nColumn 1:\n", col_1) # tensor([1, 5])

# 获取一个子矩阵: 第0行到第1行,第1列到第2列
sub_matrix = x[0:2, 1:3]
print("\nSub-matrix (0:2, 1:3):\n", sub_matrix)
# tensor([[1, 2],
#         [5, 6]])

# 3. 切片是视图的证明
# 修改切片
sub_matrix[0, 0] = 99
print("\nModified sub-matrix:\n", sub_matrix)
print("Original tensor after modification:\n", x) # 原始张量也被修改了
# tensor([[ 0, 99,  2,  3],
#         [ 4,  5,  6,  7]])

二、高级索引 (Advanced Indexing)

当选取规则变得不规则或依赖于条件时,就需要使用高级索引。高级索引使用一个LongTensor(整数张量)或BoolTensor(布尔张量)作为索引。与基础索引最大的不同在于,高级索引总是返回一个 副本(Copy),而不是视图。

1. 使用整数张量索引

通过传入一个包含索引值的张量,可以一次性地从目标张量中选取多个、非连续的元素。

  • 选取多个元素/行/列:提供一个一维LongTensor来指定要选取的索引。
  • 选取特定坐标点:提供多个一维LongTensor,每个对应一个维度,它们共同指定了要选取的元素的坐标。
2. 使用布尔张量索引(掩码索引)

这是基于条件筛选数据的最强大方法。需要创建一个与原张量形状相同(或可广播)的布尔张量,称为 掩码(Mask)

  • 工作原理:函数会选取原张量中所有在掩码中对应位置为True的元素。
  • 输出形状:无论原张量和掩码的形状如何,结果总是被“拉平”为一个 一维张量
使用示例
import torch

x = torch.arange(12).reshape(3, 4)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])
print("Original Tensor:\n", x)

# 1. 整数索引
# 选取第0行和第2行
int_indexed_rows = x[[0, 2]]
print("\nIndexed rows [0, 2]:\n", int_indexed_rows)
# tensor([[ 0,  1,  2,  3],
#         [ 8,  9, 10, 11]])

# 选取 (0, 1), (1, 2), (2, 3) 三个坐标的元素
rows = torch.tensor([0, 1, 2])
cols = torch.tensor([1, 2, 3])
int_indexed_points = x[rows, cols]
print("\nIndexed points (0,1), (1,2), (2,3):", int_indexed_points) # tensor([1, 6, 11])


# 2. 布尔索引 (Masking)
# 创建一个条件,选取所有大于5的元素
mask = x > 5
# tensor([[False, False, False, False],
#         [False, False,  True,  True],
#         [ True,  True,  True,  True]])
print("\nBoolean Mask (x > 5):\n", mask)
bool_indexed = x[mask]
print("Elements > 5:", bool_indexed) # tensor([ 6,  7,  8,  9, 10, 11])

# 高级索引返回的是副本
bool_indexed[0] = 99
print("\nOriginal tensor after modifying copy:", x) # 原始张量未改变

三、高效的数据选取函数

除了使用[]进行索引,PyTorch还提供了一些专用的函数,它们在某些场景下可读性更强或效率更高。

1. torch.gather()

gather是一个非常强大的函数,常用于根据索引从源张量中“收集”数据。其工作方式可以理解为:output[i][j] = input[i][index[i][j]] (当dim=1时)。

  • 参数: gather(input, dim, index)
    • input: 源张量。
    • dim: 在哪个维度上进行收集。
    • index: 一个LongTensor,其形状与最终输出的形状一致,指定了在dim维度上要收集的元素的索引。
  • 关键点: 输出张量的形状与index张量的形状完全相同。
2. torch.masked_select()

此函数的功能与布尔索引(x[mask])完全相同,但以函数形式提供。它根据一个布尔掩码mask从输入张量中选取元素。

  • 参数: masked_select(input, mask)
  • 关键点: 返回一个包含所有被选中元素的一维张量。mask的形状需要能广播到input的形状。
3. torch.take()

此函数将输入张量视为一个一维张量,然后根据提供的索引选取元素。

  • 参数: take(input, index)
  • 关键点: 输出张量的形状与index张量的形状完全相同。它忽略了输入张量的原始形状。
使用示例
import torch

# --- gather 示例 ---
scores = torch.tensor([[0.1, 0.6, 0.3],   # 样本0的分数
                       [0.8, 0.1, 0.1]])  # 样本1的分数
best_class_indices = torch.tensor([[1], [0]]) # 样本0选第1类, 样本1选第0类

# 我们想根据 best_class_indices 收集每个样本最高分
# dim=1 表示我们沿着列(类别)维度进行收集
best_scores = torch.gather(scores, 1, best_class_indices)
print("--- gather example ---")
print("Best scores:\n", best_scores)
# tensor([[0.6000], [0.8000]])

# --- masked_select 示例 ---
x = torch.arange(6).reshape(2, 3)
mask = x % 2 == 1 # 选取奇数
selected = torch.masked_select(x, mask)
print("\n--- masked_select example ---")
print("Selected odd numbers:", selected) # tensor([1, 3, 5])

# --- take 示例 ---
y = torch.arange(8).reshape(2, 4)
# 将 y 视为 [0, 1, 2, 3, 4, 5, 6, 7]
# 选取第0, 3, 7个元素,并按2x1形状排列
taken = torch.take(y, torch.tensor([[0], [3], [7]]))
print("\n--- take example ---")
print("Taken elements:\n", taken)
# tensor([[0], [3], [7]])

网站公告

今日签到

点亮在社区的每一天
去签到