在 PyTorch 中,.reshape()
、.unsqueeze()
和 .squeeze()
是用于张量(Tensor)形状操作的常用函数。它们分别用于改变形状、添加维度和移除维度,是进行张量维度管理和模型数据预处理的基础工具。
1. .reshape()
功能:重新调整张量的形状(不改变数据内容)
- 返回一个具有相同数据但不同维度的新张量。
- 与
.view()
类似,但更灵活(支持非连续内存的 Tensor)。
语法:
tensor.reshape(new_shape)
示例:
import torch
x = torch.arange(12) # [0, 1, ..., 11]
print(x.shape) # torch.Size([12])
y = x.reshape(3, 4)
print(y)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
2. .unsqueeze()
功能:在指定维度插入一个大小为1的新维度
- 常用于增加 batch 维度或 channel 维度。
语法:
tensor.unsqueeze(dim)
dim
: 插入的位置(维度索引)。
示例:
x = torch.tensor([1.0, 2.0, 3.0]) # shape: [3]
x1 = x.unsqueeze(0) # 插入在前面 -> shape: [1, 3]
x2 = x.unsqueeze(1) # 插入在后面 -> shape: [3, 1]
print(x1.shape) # torch.Size([1, 3])
print(x2.shape) # torch.Size([3, 1])
3. .squeeze()
功能:移除 所有大小为1的维度
- 常用于从
[1, 3, 1, 32]
→[3, 32]
语法:
tensor.squeeze()
tensor.squeeze(dim) # 仅当 dim 的大小为1时才移除
示例:
x = torch.zeros(1, 3, 1, 5) # shape: [1, 3, 1, 5]
y = x.squeeze() # 删除所有为1的维度 → shape: [3, 5]
z = x.squeeze(0) # 删除第0维(1) → shape: [3, 1, 5]
w = x.squeeze(1) # 尝试删除第1维(3),不成功 → shape: [1, 3, 1, 5]
print(y.shape) # torch.Size([3, 5])
print(z.shape) # torch.Size([3, 1, 5])
print(w.shape) # torch.Size([1, 3, 1, 5])
总结对比:
操作 | 功能 | 维度变化 | 用法示例 |
---|---|---|---|
.reshape() |
改变形状 | 任意 | (12,) → (3,4) |
.unsqueeze() |
增加维度(size=1) | n → n+1 |
[3] → [1, 3] |
.squeeze() |
删除维度(size=1) | n → n-k |
[1, 3, 1] → [3] |
实战小例子:图像处理中的通道扩展
img = torch.randn(28, 28) # 单通道灰度图,无 batch 维度
img = img.unsqueeze(0) # 添加 batch 维度 → [1, 28, 28]
img = img.unsqueeze(0) # 添加 channel 维度 → [1, 1, 28, 28]
工程实战应用示例
分别以 图像输入、时间序列建模、Transformer 输入/输出处理 为例,展示 .reshape()
、.unsqueeze()
、.squeeze()
的实际使用场景和推荐写法,帮助你理解维度变换在实际建模中的应用。
1. 图像输入处理(CNN 模型)
PyTorch 中 CNN 的输入要求是 [B, C, H, W]
(批大小、通道、高、宽)
场景:
你从图片加载的数据是 [H, W, C]
,需要变成 [1, C, H, W]
才能输入模型。
示例:
import torch
# 假设从 PIL 加载 RGB 图像后转换为 tensor:shape [H, W, C]
img = torch.rand(224, 224, 3) # shape: [224, 224, 3]
# 1. 转为 [C, H, W]
img = img.permute(2, 0, 1) # shape: [3, 224, 224]
# 2. 增加 batch 维度
img = img.unsqueeze(0) # shape: [1, 3, 224, 224]
2. 时间序列建模(RNN / LSTM / GRU)
RNN 接收输入形状为 [batch_size, seq_len, input_size]
场景:
你有一个单一序列 [10, 5]
表示 10 个时间步,每步特征是 5 维,想输入模型。
示例:
seq = torch.rand(10, 5) # shape: [10, 5] → 单序列无 batch
# 增加 batch 维度(假设 batch size = 1)
seq = seq.unsqueeze(0) # shape: [1, 10, 5]
RNN 输出:
import torch.nn as nn
rnn = nn.RNN(input_size=5, hidden_size=20, batch_first=True)
output, hn = rnn(seq) # output: [1, 10, 20]
如果要提取最后时间步输出:
last_step = output[:, -1, :] # shape: [1, 20]
3. Transformer 输入/输出处理
Transformer 中的输入通常是 [batch_size, seq_len]
或 [batch_size, seq_len, d_model]
场景:
你有一个句子的 token 向量 [seq_len]
,需要输入模型(如 BERT)
示例:
token_ids = torch.tensor([101, 2009, 2003, 1037, 2307, 2154, 102]) # shape: [7]
# 增加 batch 维度
token_ids = token_ids.unsqueeze(0) # shape: [1, 7]
# 假设模型输出:[1, 7, 768] 表示每个 token 的 embedding
model_output = torch.rand(1, 7, 768)
# 取 [CLS] 位置的 embedding(第0个 token)
cls_embedding = model_output[:, 0, :] # shape: [1, 768]
4.模型预测后 squeeze 应用
例如,回归模型输出为 [batch_size, 1]
,我们想要 [batch_size]
:
y_pred = torch.tensor([[2.3], [1.5], [0.8]]) # shape: [3, 1]
y_pred = y_pred.squeeze(1) # shape: [3]
实用口诀
操作 | 用途 | 示例 |
---|---|---|
.unsqueeze(0) |
加 batch 维度 | [C,H,W] → [1,C,H,W] / [L,F] → [1,L,F] |
.unsqueeze(1) |
加 channel / time 维度等 | [B,F] → [B,1,F] |
.squeeze() |
去除所有 shape=1 的维度 | [1,C,H,W] → [C,H,W] |
.reshape(-1) |
展平 | [B, C, H, W] → [B, C*H*W] for FC layer |