【PyTorch】矩阵乘法操作大全(dot、*、mm、bmm、@、matmul)详细讲解与对比

发布于:2025-04-10 ⋅ 阅读:(57) ⋅ 点赞:(0)

在使用 PyTorch 编写神经网络或处理张量数据时,我们常常会接触各种“矩阵乘法”操作,例如 torch.dot*mmbmm@matmul 等。虽然它们看起来都像“乘法”,但实际作用差异巨大!

本文将带你一次性理清这些操作的 适用场景、维度要求和本质区别,并配合代码示例,彻底掌握它们!


一图总览:矩阵乘法函数差异

操作符 / 函数 适用维度 操作类型 应用场景
* 任意相同形状 按元素相乘 点积 / Hadamard 积
dot 1D × 1D (向量) 标量点积 向量内积
mm 2D × 2D 矩阵乘法 全连接层、线性代数运算
bmm 3D × 3D (batch) 批量矩阵乘法 Transformer、RNN 等 batch 乘法
@ 支持广播 通用乘法符号 等价于 matmul
matmul ≥1D 广播矩阵乘法 最推荐的通用矩阵乘法接口

*:按元素相乘(Hadamard 乘积)

import torch

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[10, 20], [30, 40]])
print(a * b)
# 输出:
# tensor([[ 10,  40],
#         [ 90, 160]])

* 是逐元素相乘,形状必须一致或可广播,非线性代数中的矩阵乘法!


dot:1D 向量点积(内积)

a = torch.tensor([1, 2, 3])
b = torch.tensor([10, 20, 30])
print(torch.dot(a, b))  # 1*10 + 2*20 + 3*30 = 140
# 输出:tensor(140)

torch.dot 只能用于两个一维向量,输出为标量。如果是矩阵或更高维,请使用 matmul。


mm:二维矩阵乘法(不支持 batch)

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
print(torch.mm(a, b))
# 输出:
# tensor([[19, 22],
#         [43, 50]])

bmm:三维批量矩阵乘法(batch matrix multiplication)

a = torch.randn(4, 2, 3)  # 4个样本,每个是 2×3 的矩阵
b = torch.randn(4, 3, 5)
out = torch.bmm(a, b)
print(out.shape)  # 输出: torch.Size([4, 2, 5])

bmm 要求输入是形状 [B, N, M] 和 [B, M, P] 的张量,对每个 batch 执行一次普通矩阵乘法。


@:通用矩阵乘法符号(等价于 matmul

a = torch.randn(3, 4)
b = torch.randn(4, 2)
print(a @ b)

@ 是 PyTorch 中矩阵乘法的语法糖,功能等同于 torch.matmul(a, b),推荐日常使用。


matmul:最推荐的通用矩阵乘法接口

# 向量 × 向量
a = torch.tensor([1, 2])
b = torch.tensor([3, 4])
print(torch.matmul(a, b))  # 输出:11

# 矩阵 × 向量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([10, 20])
print(torch.matmul(a, b))  # 输出:tensor([ 50, 110])

# 批量矩阵乘法
a = torch.randn(8, 2, 3)
b = torch.randn(8, 3, 4)
print(torch.matmul(a, b).shape)  # 输出: torch.Size([8, 2, 4])

torch.matmul 支持从 1D 到 nD 的张量自动广播,是最通用、最推荐的矩阵乘法函数。


常见错误警示

a = torch.randn(2, 3)
b = torch.randn(3, 2)
torch.dot(a, b)  # ❌ 错误!dot 只支持 1D
torch.mm(a.unsqueeze(0), b)  # ❌ 错误!mm 不支持 3D

正确做法

# 推荐使用 matmul 支持高维:
print(torch.matmul(a.unsqueeze(0), b))  # 正确

总结建议

需求 推荐用法
向量点积 torch.dot()
普通二维矩阵乘法 torch.mm()@
批量矩阵乘法 torch.bmm()matmul()
任意维度乘法 torch.matmul()(推荐)
按元素相乘 *

💬 如果你觉得这篇整理有帮助,欢迎点赞收藏!


网站公告

今日签到

点亮在社区的每一天
去签到