在使用 PyTorch 编写神经网络或处理张量数据时,我们常常会接触各种“矩阵乘法”操作,例如 torch.dot
、*
、mm
、bmm
、@
、matmul
等。虽然它们看起来都像“乘法”,但实际作用差异巨大!
本文将带你一次性理清这些操作的 适用场景、维度要求和本质区别,并配合代码示例,彻底掌握它们!
一图总览:矩阵乘法函数差异
操作符 / 函数 | 适用维度 | 操作类型 | 应用场景 |
---|---|---|---|
* |
任意相同形状 | 按元素相乘 | 点积 / Hadamard 积 |
dot |
1D × 1D (向量) | 标量点积 | 向量内积 |
mm |
2D × 2D | 矩阵乘法 | 全连接层、线性代数运算 |
bmm |
3D × 3D (batch) | 批量矩阵乘法 | Transformer、RNN 等 batch 乘法 |
@ |
支持广播 | 通用乘法符号 | 等价于 matmul |
matmul |
≥1D | 广播矩阵乘法 | 最推荐的通用矩阵乘法接口 |
*
:按元素相乘(Hadamard 乘积)
import torch
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[10, 20], [30, 40]])
print(a * b)
# 输出:
# tensor([[ 10, 40],
# [ 90, 160]])
* 是逐元素相乘,形状必须一致或可广播,非线性代数中的矩阵乘法!
dot
:1D 向量点积(内积)
a = torch.tensor([1, 2, 3])
b = torch.tensor([10, 20, 30])
print(torch.dot(a, b)) # 1*10 + 2*20 + 3*30 = 140
# 输出:tensor(140)
torch.dot 只能用于两个一维向量,输出为标量。如果是矩阵或更高维,请使用 matmul。
mm
:二维矩阵乘法(不支持 batch)
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
print(torch.mm(a, b))
# 输出:
# tensor([[19, 22],
# [43, 50]])
bmm
:三维批量矩阵乘法(batch matrix multiplication)
a = torch.randn(4, 2, 3) # 4个样本,每个是 2×3 的矩阵
b = torch.randn(4, 3, 5)
out = torch.bmm(a, b)
print(out.shape) # 输出: torch.Size([4, 2, 5])
bmm 要求输入是形状 [B, N, M] 和 [B, M, P] 的张量,对每个 batch 执行一次普通矩阵乘法。
@
:通用矩阵乘法符号(等价于 matmul
)
a = torch.randn(3, 4)
b = torch.randn(4, 2)
print(a @ b)
@ 是 PyTorch 中矩阵乘法的语法糖,功能等同于 torch.matmul(a, b),推荐日常使用。
✅ matmul
:最推荐的通用矩阵乘法接口
# 向量 × 向量
a = torch.tensor([1, 2])
b = torch.tensor([3, 4])
print(torch.matmul(a, b)) # 输出:11
# 矩阵 × 向量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([10, 20])
print(torch.matmul(a, b)) # 输出:tensor([ 50, 110])
# 批量矩阵乘法
a = torch.randn(8, 2, 3)
b = torch.randn(8, 3, 4)
print(torch.matmul(a, b).shape) # 输出: torch.Size([8, 2, 4])
torch.matmul 支持从 1D 到 nD 的张量自动广播,是最通用、最推荐的矩阵乘法函数。
常见错误警示
a = torch.randn(2, 3)
b = torch.randn(3, 2)
torch.dot(a, b) # ❌ 错误!dot 只支持 1D
torch.mm(a.unsqueeze(0), b) # ❌ 错误!mm 不支持 3D
正确做法
# 推荐使用 matmul 支持高维:
print(torch.matmul(a.unsqueeze(0), b)) # 正确
总结建议
需求 | 推荐用法 |
---|---|
向量点积 | torch.dot() |
普通二维矩阵乘法 | torch.mm() 或 @ |
批量矩阵乘法 | torch.bmm() 或 matmul() |
任意维度乘法 | torch.matmul() (推荐) |
按元素相乘 | * |
💬 如果你觉得这篇整理有帮助,欢迎点赞收藏!