PyTorch中.reshape(), .unsqueeze(), 和.squeeze()详解以及实战示例

发布于:2025-06-25 ⋅ 阅读:(16) ⋅ 点赞:(0)

在 PyTorch 中,.reshape().unsqueeze().squeeze() 是用于张量(Tensor)形状操作的常用函数。它们分别用于改变形状添加维度移除维度,是进行张量维度管理和模型数据预处理的基础工具。


1. .reshape()

功能:重新调整张量的形状(不改变数据内容)

  • 返回一个具有相同数据但不同维度的新张量。
  • .view() 类似,但更灵活(支持非连续内存的 Tensor)。

语法:

tensor.reshape(new_shape)

示例:

import torch

x = torch.arange(12)        # [0, 1, ..., 11]
print(x.shape)              # torch.Size([12])

y = x.reshape(3, 4)
print(y)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

2. .unsqueeze()

功能:在指定维度插入一个大小为1的新维度

  • 常用于增加 batch 维度或 channel 维度。

语法:

tensor.unsqueeze(dim)
  • dim: 插入的位置(维度索引)。

示例:

x = torch.tensor([1.0, 2.0, 3.0])  # shape: [3]

x1 = x.unsqueeze(0)   # 插入在前面 -> shape: [1, 3]
x2 = x.unsqueeze(1)   # 插入在后面 -> shape: [3, 1]

print(x1.shape)  # torch.Size([1, 3])
print(x2.shape)  # torch.Size([3, 1])

3. .squeeze()

功能:移除 所有大小为1的维度

  • 常用于从 [1, 3, 1, 32][3, 32]

语法:

tensor.squeeze()
tensor.squeeze(dim)  # 仅当 dim 的大小为1时才移除

示例:

x = torch.zeros(1, 3, 1, 5)    # shape: [1, 3, 1, 5]

y = x.squeeze()               # 删除所有为1的维度 → shape: [3, 5]
z = x.squeeze(0)              # 删除第0维(1) → shape: [3, 1, 5]
w = x.squeeze(1)              # 尝试删除第1维(3),不成功 → shape: [1, 3, 1, 5]

print(y.shape)  # torch.Size([3, 5])
print(z.shape)  # torch.Size([3, 1, 5])
print(w.shape)  # torch.Size([1, 3, 1, 5])

总结对比:

操作 功能 维度变化 用法示例
.reshape() 改变形状 任意 (12,) → (3,4)
.unsqueeze() 增加维度(size=1) n → n+1 [3] → [1, 3]
.squeeze() 删除维度(size=1) n → n-k [1, 3, 1] → [3]

实战小例子:图像处理中的通道扩展

img = torch.randn(28, 28)         # 单通道灰度图,无 batch 维度
img = img.unsqueeze(0)            # 添加 batch 维度 → [1, 28, 28]
img = img.unsqueeze(0)            # 添加 channel 维度 → [1, 1, 28, 28]

工程实战应用示例

分别以 图像输入、时间序列建模、Transformer 输入/输出处理 为例,展示 .reshape().unsqueeze().squeeze() 的实际使用场景和推荐写法,帮助你理解维度变换在实际建模中的应用。


1. 图像输入处理(CNN 模型)

PyTorch 中 CNN 的输入要求是 [B, C, H, W](批大小、通道、高、宽)

场景

你从图片加载的数据是 [H, W, C],需要变成 [1, C, H, W] 才能输入模型。

示例

import torch

# 假设从 PIL 加载 RGB 图像后转换为 tensor:shape [H, W, C]
img = torch.rand(224, 224, 3)               # shape: [224, 224, 3]

# 1. 转为 [C, H, W]
img = img.permute(2, 0, 1)                  # shape: [3, 224, 224]

# 2. 增加 batch 维度
img = img.unsqueeze(0)                     # shape: [1, 3, 224, 224]

2. 时间序列建模(RNN / LSTM / GRU)

RNN 接收输入形状为 [batch_size, seq_len, input_size]

场景

你有一个单一序列 [10, 5] 表示 10 个时间步,每步特征是 5 维,想输入模型。

示例

seq = torch.rand(10, 5)     # shape: [10, 5] → 单序列无 batch

# 增加 batch 维度(假设 batch size = 1)
seq = seq.unsqueeze(0)     # shape: [1, 10, 5]

RNN 输出:

import torch.nn as nn

rnn = nn.RNN(input_size=5, hidden_size=20, batch_first=True)
output, hn = rnn(seq)      # output: [1, 10, 20]

如果要提取最后时间步输出:

last_step = output[:, -1, :]       # shape: [1, 20]

3. Transformer 输入/输出处理

Transformer 中的输入通常是 [batch_size, seq_len][batch_size, seq_len, d_model]

场景

你有一个句子的 token 向量 [seq_len],需要输入模型(如 BERT)

示例

token_ids = torch.tensor([101, 2009, 2003, 1037, 2307, 2154, 102])  # shape: [7]

# 增加 batch 维度
token_ids = token_ids.unsqueeze(0)   # shape: [1, 7]

# 假设模型输出:[1, 7, 768] 表示每个 token 的 embedding
model_output = torch.rand(1, 7, 768)

# 取 [CLS] 位置的 embedding(第0个 token)
cls_embedding = model_output[:, 0, :]   # shape: [1, 768]

4.模型预测后 squeeze 应用

例如,回归模型输出为 [batch_size, 1],我们想要 [batch_size]

y_pred = torch.tensor([[2.3], [1.5], [0.8]])  # shape: [3, 1]
y_pred = y_pred.squeeze(1)                   # shape: [3]

实用口诀

操作 用途 示例
.unsqueeze(0) 加 batch 维度 [C,H,W] → [1,C,H,W] / [L,F] → [1,L,F]
.unsqueeze(1) 加 channel / time 维度等 [B,F] → [B,1,F]
.squeeze() 去除所有 shape=1 的维度 [1,C,H,W] → [C,H,W]
.reshape(-1) 展平 [B, C, H, W] → [B, C*H*W] for FC layer