AI学习记录 - torch 的 matmul和dot的关联,也就是点乘和点积的联系

发布于:2024-08-14 ⋅ 阅读:(66) ⋅ 点赞:(0)

有用大佬们点点赞

1、两个一维向量点积 ,求 词A 与 词A 之间的关联度

在这里插入图片描述

2、两个词向量之间求关联度,求 :

词A 与 词A 的关联度 =5
词A 与 词B 的关联度 = 11
词B 与 词A 的关联度= 11
词B 与 词B 的关联度= 25
在这里插入图片描述
刚刚好和矩阵乘法符合:
在这里插入图片描述

3、什么是矩阵乘法,举个例子

在这里插入图片描述

重点结论:矩阵乘法就是 求 词向量 与 词向量转置 之间的点积的集合,为什么是转置,注意要是把上面第二个矩阵当词向量的话,[5, 7]和[6, 8]分别是一个词向量。而不是 [5, 6]和[7, 8]。
3、如果是三维矩阵的点乘呢

在这里插入图片描述

重点结论:三维矩阵点乘就变成单独的二维矩阵进行点乘,所以点乘最多体现在二维矩阵,二维矩阵单独求点乘,求完再合并。
最终结论:点积和点乘没有关联性,但是点乘在特定场景下可以实现批量点积,有助于我们利用点乘的特性来批量求词与词之间的点积(关联性),在自注意力的时候可以使用。

试验代码:

一维向量点乘这样写会报错

import torch

# 定义两个二维矩阵
A = torch.tensor([[1, 2]])
B = torch.tensor([[1, 2]])

# 使用 matmul 计算展平向量的点积
dot_product = torch.matmul(A, B)

print(dot_product)

在这里插入图片描述

正确一维向量点乘

import torch

# 定义两个二维矩阵
A = torch.tensor([[1, 2]])
B = torch.tensor([[1], [2]])

# 使用 matmul 计算展平向量的点积
dot_product = torch.matmul(A, B)

print(dot_product)

在这里插入图片描述

二维矩阵点乘

import torch

# 定义两个二维矩阵
A = torch.tensor([[1, 2], 
                  [3, 4]])
B = torch.tensor([[1, 3], 
                  [2, 4]])

# 使用 matmul 计算展平向量的点积
dot_product = torch.matmul(A, B)

print(dot_product)

在这里插入图片描述

三维矩阵点乘

import torch

# 定义两个二维矩阵
A = torch.tensor([
                    [
                        [1, 2], 
                        [3, 4]
                    ],
                    [   
                        [1, 2], 
                        [3, 4]
                    ]
                ])
B = torch.tensor([
                    [
                        [1, 3], 
                        [2, 4]
                    ],
                    [   
                        [1, 3], 
                        [2, 4]
                    ]
                ])

# 使用 matmul 计算展平向量的点积
dot_product = torch.matmul(A, B)
print(dot_product)

在这里插入图片描述

点积应该这样写

import torch

vector_a = torch.tensor([1, 2, 3])
vector_b = torch.tensor([4, 5, 6])

# 计算点积
dot_product = torch.dot(vector_a, vector_b)

print(f"向量 A: {vector_a}")
print(f"向量 B: {vector_b}")
print(f"A 和 B 的点积: {dot_product}")

下面会报错,二维数组不可以点积

import torch

vector_a = torch.tensor([[1, 2, 3],[1, 2, 3]])
vector_b = torch.tensor([[1, 2, 3],[1, 2, 3]])

# 计算点积
dot_product = torch.dot(vector_a, vector_b)

print(f"向量 A: {vector_a}")
print(f"向量 B: {vector_b}")
print(f"A 和 B 的点积: {dot_product}")

所以点积和点乘不是一个东西,只是点乘在某些场景下可以代表批量点积而已。