PyTorch数据选取与索引详解:从入门到高效实践
在PyTorch中,从现有张量(Tensor)中高效、准确地选取所需的数据子集是日常操作的核心。无论是准备数据批次、提取模型输出的特定部分,还是根据条件过滤数据,都离不开强大的索引机制。本篇将系统介绍PyTorch中的基础索引、高级索引以及其他高效的数据选取函数。
一、基础索引与切片 (Basic Indexing and Slicing)
基础索引与切片是最直观、最常用的数据选取方式,其语法与Python的NumPy和列表非常相似。这种方式的主要特点是,它返回的是原始数据的 视图(View),意味着不产生新的内存开销,但修改视图会影响原始张量。
1. 单个元素索引
对于一个多维张量,可以通过提供每个维度的索引来精确定位到单个元素。
2. 数据切片
使用冒号 :
可以选取一个范围内的元素。语法为 start:stop:step
,其中 start
是起始索引(包含),stop
是结束索引(不包含)。
关键点:
- 返回视图:切片操作返回的是一个视图,与原始数据共享内存。这是一个非常重要的特性,因为它避免了不必要的数据复制,提升了效率。但反过来,如果修改了切片得到的数据,原始张量中对应的数据也会被修改。
- 语法灵活:可以省略
start
(表示从头开始)、stop
(表示到末尾结束)或step
(默认为1)。单独一个:
表示选取该维度的所有元素。
使用示例
import torch
# 创建一个 2x4 的张量
x = torch.arange(8).reshape(2, 4)
# tensor([[0, 1, 2, 3],
# [4, 5, 6, 7]])
print("Original Tensor:\n", x)
# 1. 单个元素索引
# 获取第1行、第2列的元素 (从0开始计数)
element = x[1, 2]
print(f"\nElement at (1, 2): {element}") # tensor(6)
# 2. 数据切片
# 获取第0行所有元素
row_0 = x[0, :]
print("\nRow 0:\n", row_0) # tensor([0, 1, 2, 3])
# 获取第1列所有元素
col_1 = x[:, 1]
print("\nColumn 1:\n", col_1) # tensor([1, 5])
# 获取一个子矩阵: 第0行到第1行,第1列到第2列
sub_matrix = x[0:2, 1:3]
print("\nSub-matrix (0:2, 1:3):\n", sub_matrix)
# tensor([[1, 2],
# [5, 6]])
# 3. 切片是视图的证明
# 修改切片
sub_matrix[0, 0] = 99
print("\nModified sub-matrix:\n", sub_matrix)
print("Original tensor after modification:\n", x) # 原始张量也被修改了
# tensor([[ 0, 99, 2, 3],
# [ 4, 5, 6, 7]])
二、高级索引 (Advanced Indexing)
当选取规则变得不规则或依赖于条件时,就需要使用高级索引。高级索引使用一个LongTensor
(整数张量)或BoolTensor
(布尔张量)作为索引。与基础索引最大的不同在于,高级索引总是返回一个 副本(Copy),而不是视图。
1. 使用整数张量索引
通过传入一个包含索引值的张量,可以一次性地从目标张量中选取多个、非连续的元素。
- 选取多个元素/行/列:提供一个一维
LongTensor
来指定要选取的索引。 - 选取特定坐标点:提供多个一维
LongTensor
,每个对应一个维度,它们共同指定了要选取的元素的坐标。
2. 使用布尔张量索引(掩码索引)
这是基于条件筛选数据的最强大方法。需要创建一个与原张量形状相同(或可广播)的布尔张量,称为 掩码(Mask)。
- 工作原理:函数会选取原张量中所有在掩码中对应位置为
True
的元素。 - 输出形状:无论原张量和掩码的形状如何,结果总是被“拉平”为一个 一维张量。
使用示例
import torch
x = torch.arange(12).reshape(3, 4)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
print("Original Tensor:\n", x)
# 1. 整数索引
# 选取第0行和第2行
int_indexed_rows = x[[0, 2]]
print("\nIndexed rows [0, 2]:\n", int_indexed_rows)
# tensor([[ 0, 1, 2, 3],
# [ 8, 9, 10, 11]])
# 选取 (0, 1), (1, 2), (2, 3) 三个坐标的元素
rows = torch.tensor([0, 1, 2])
cols = torch.tensor([1, 2, 3])
int_indexed_points = x[rows, cols]
print("\nIndexed points (0,1), (1,2), (2,3):", int_indexed_points) # tensor([1, 6, 11])
# 2. 布尔索引 (Masking)
# 创建一个条件,选取所有大于5的元素
mask = x > 5
# tensor([[False, False, False, False],
# [False, False, True, True],
# [ True, True, True, True]])
print("\nBoolean Mask (x > 5):\n", mask)
bool_indexed = x[mask]
print("Elements > 5:", bool_indexed) # tensor([ 6, 7, 8, 9, 10, 11])
# 高级索引返回的是副本
bool_indexed[0] = 99
print("\nOriginal tensor after modifying copy:", x) # 原始张量未改变
三、高效的数据选取函数
除了使用[]
进行索引,PyTorch还提供了一些专用的函数,它们在某些场景下可读性更强或效率更高。
1. torch.gather()
gather
是一个非常强大的函数,常用于根据索引从源张量中“收集”数据。其工作方式可以理解为:output[i][j] = input[i][index[i][j]]
(当dim=1
时)。
- 参数:
gather(input, dim, index)
input
: 源张量。dim
: 在哪个维度上进行收集。index
: 一个LongTensor
,其形状与最终输出的形状一致,指定了在dim
维度上要收集的元素的索引。
- 关键点: 输出张量的形状与
index
张量的形状完全相同。
2. torch.masked_select()
此函数的功能与布尔索引(x[mask]
)完全相同,但以函数形式提供。它根据一个布尔掩码mask
从输入张量中选取元素。
- 参数:
masked_select(input, mask)
- 关键点: 返回一个包含所有被选中元素的一维张量。
mask
的形状需要能广播到input
的形状。
3. torch.take()
此函数将输入张量视为一个一维张量,然后根据提供的索引选取元素。
- 参数:
take(input, index)
- 关键点: 输出张量的形状与
index
张量的形状完全相同。它忽略了输入张量的原始形状。
使用示例
import torch
# --- gather 示例 ---
scores = torch.tensor([[0.1, 0.6, 0.3], # 样本0的分数
[0.8, 0.1, 0.1]]) # 样本1的分数
best_class_indices = torch.tensor([[1], [0]]) # 样本0选第1类, 样本1选第0类
# 我们想根据 best_class_indices 收集每个样本最高分
# dim=1 表示我们沿着列(类别)维度进行收集
best_scores = torch.gather(scores, 1, best_class_indices)
print("--- gather example ---")
print("Best scores:\n", best_scores)
# tensor([[0.6000], [0.8000]])
# --- masked_select 示例 ---
x = torch.arange(6).reshape(2, 3)
mask = x % 2 == 1 # 选取奇数
selected = torch.masked_select(x, mask)
print("\n--- masked_select example ---")
print("Selected odd numbers:", selected) # tensor([1, 3, 5])
# --- take 示例 ---
y = torch.arange(8).reshape(2, 4)
# 将 y 视为 [0, 1, 2, 3, 4, 5, 6, 7]
# 选取第0, 3, 7个元素,并按2x1形状排列
taken = torch.take(y, torch.tensor([[0], [3], [7]]))
print("\n--- take example ---")
print("Taken elements:\n", taken)
# tensor([[0], [3], [7]])