torch.matmul和@区别

发布于:2024-09-19 ⋅ 阅读:(9) ⋅ 点赞:(0)

在 PyTorch 中,torch.matmul@ 操作符都用于执行矩阵乘法,但它们在某些特定情况下有略微不同的行为。下面我们详细介绍两者之间的区别与相同之处。

1. torch.matmul

torch.matmul 是一个通用的矩阵乘法函数,它能够处理不同维度的张量,依据输入张量的维度自动决定执行何种类型的乘法。

torch.matmul 的行为:
  • 如果输入是 1D 向量:执行 向量内积
  • 如果输入是 2D 矩阵:执行 标准的矩阵乘法
  • 如果输入是高于 2D 的张量:执行 批量矩阵乘法,即对批量中的每对矩阵分别执行矩阵乘法。
例子:
import torch

# 1D 张量(向量)的点积
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.matmul(a, b)
print(result)  # 输出: 32 (即 1*4 + 2*5 + 3*6)

# 2D 矩阵的标准矩阵乘法
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
result = torch.matmul(A, B)
print(result)  # 输出: tensor([[19, 22], [43, 50]])

# 3D 张量的批量矩阵乘法
C = torch.randn(10, 3, 4)  # 形状 (10, 3, 4) 的批量矩阵
D = torch.randn(10, 4, 5)  # 形状 (10, 4, 5) 的批量矩阵
result = torch.matmul(C, D)  # 批量中的每个 (3, 4) 矩阵和 (4, 5) 矩阵相乘,输出形状为 (10, 3, 5)
print(result.shape)  # 输出: torch.Size([10, 3, 5])

2. @ 操作符

@ 操作符是在 Python 3.5 中引入的用于矩阵乘法的快捷符号。它也可以处理 1D 向量、2D 矩阵 以及高维张量,但它的行为与 torch.matmul 是一致的。

@ 操作符的行为:
  • 如果输入是 1D 向量:执行 向量内积
  • 如果输入是 2D 矩阵:执行 矩阵乘法
  • 如果输入是高于 2D 的张量:执行 批量矩阵乘法(与 torch.matmul 相同)。
例子:
import torch

# 1D 向量的点积
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = a @ b
print(result)  # 输出: 32

# 2D 矩阵乘法
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
result = A @ B
print(result)  # 输出: tensor([[19, 22], [43, 50]])

# 3D 张量的批量矩阵乘法
C = torch.randn(10, 3, 4)
D = torch.randn(10, 4, 5)
result = C @ D
print(result.shape)  # 输出: torch.Size([10, 3, 5])

3. 相同点:

  • 矩阵和向量的乘法torch.matmul@ 对 1D 和 2D 张量的乘法行为是完全相同的。
  • 批量矩阵乘法:对于更高维度的张量,它们都支持批量矩阵乘法,并且表现一致。

4. 不同点:

唯一的区别在于可读性使用场景

  • @ 操作符:作为 语法糖,更简洁直观,在书写矩阵乘法时更具可读性,尤其是在 Python 代码中类似于数学符号。
  • torch.matmul:作为 PyTorch 提供的函数,功能上更明确,适合程序化或需要在函数式编程中灵活调用的场景。

总结:

  • torch.matmul 是 PyTorch 中的通用矩阵乘法函数,适用于从向量到批量矩阵的各种乘法场景。
  • @ 操作符是 Python 的矩阵乘法符号,行为与 torch.matmul 相同,但更简洁。
  • 二者在功能上基本一致,选择哪个取决于代码的风格和偏好。如果你想要书写更简洁的代码,@ 操作符是一个很好的选择;如果你需要在函数或复杂场景中调用矩阵乘法,torch.matmul 更为合适。

网站公告

今日签到

点亮在社区的每一天
去签到