文章目录
1. 相对位置矩阵2d
在swin-transformer中,我们会计算每个patch之间的相对位置,那么我们看到有一连串的拉伸和相减,直接贴代码:
import torch
import torch.nn as nn
torch.set_printoptions(precision=3, sci_mode=False,threshold=torch.inf)
if __name__ == "__main__":
run_code = 2
x_len = 5
y_len = 5
x_tensor = torch.arange(x_len)
y_tensor = torch.arange(y_len)
x_meshgrid, y_meshgrid = torch.meshgrid(x_tensor, y_tensor)
print(f"x_tensor=\n{x_tensor}")
print(f"y_tensor=\n{y_tensor}")
print(f"x_meshgrid=\n{x_meshgrid}")
print(f"x_meshgrid.shape=\n{x_meshgrid.shape}")
print(f"y_meshgrid.shape=\n{y_meshgrid.shape}")
print(f"y_meshgrid=\n{y_meshgrid}")
stack_meshgrid = torch.stack(torch.meshgrid(x_tensor, y_tensor))
print(f"stack_meshgrid.shape=\n{stack_meshgrid.shape}")
print(f"stack_meshgrid=\n{stack_meshgrid}")
stack_meshgrid_flatten = torch.flatten(stack_meshgrid, 1)
print(f"stack_meshgrid_flatten.shape=\n{stack_meshgrid_flatten.shape}")
print(f"stack_meshgrid_flatten=\n{stack_meshgrid_flatten}")
stack_meshgrid_flatten_1 = stack_meshgrid_flatten[:, None, :]
stack_meshgrid_flatten_2 = stack_meshgrid_flatten[:, :, None]
relative_coords_bias = stack_meshgrid_flatten_2 - stack_meshgrid_flatten_1
print(f"stack_meshgrid_flatten_1=\n{stack_meshgrid_flatten_1}")
print(f"stack_meshgrid_flatten_2=\n{stack_meshgrid_flatten_2}")
print(f"relative_coords_bias=\n{relative_coords_bias}")
relative_coords_bias[0, :, :] += x_len
relative_coords_bias[1, :, :] += y_len
print(f"relative_coords_bias=\n{relative_coords_bias}")
- result:
x_tensor=
tensor([0, 1, 2, 3, 4])
y_tensor=
tensor([0, 1, 2, 3, 4])
x_meshgrid=
tensor([[0, 0, 0, 0, 0],
[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],
[3, 3, 3, 3, 3],
[4, 4, 4, 4, 4]])
x_meshgrid.shape=
torch.Size([5, 5])
y_meshgrid.shape=
torch.Size([5, 5])
y_meshgrid=
tensor([[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]])
stack_meshgrid.shape=
torch.Size([2, 5, 5])
stack_meshgrid=
tensor([[[0, 0, 0, 0, 0],
[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],
[3, 3, 3, 3, 3],
[4, 4, 4, 4, 4]],
[[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]]])
stack_meshgrid_flatten.shape=
torch.Size([2, 25])
stack_meshgrid_flatten=
tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4,
4],
[0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3,
4]])
stack_meshgrid_flatten_1=
tensor([[[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4,
4, 4]],
[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2,
3, 4]]])
stack_meshgrid_flatten_2=
tensor([[[0],
[0],
[0],
[0],
[0],
[1],
[1],
[1],
[1],
[1],
[2],
[2],
[2],
[2],
[2],
[3],
[3],
[3],
[3],
[3],
[4],
[4],
[4],
[4],
[4]],
[[0],
[1],
[2],
[3],
[4],
[0],
[1],
[2],
[3],
[4],
[0],
[1],
[2],
[3],
[4],
[0],
[1],
[2],
[3],
[4],
[0],
[1],
[2],
[3],
[4]]])
relative_coords_bias=
tensor([[[ 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
-3, -3, -3, -4, -4, -4, -4, -4],
[ 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
-3, -3, -3, -4, -4, -4, -4, -4],
[ 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
-3, -3, -3, -4, -4, -4, -4, -4],
[ 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
-3, -3, -3, -4, -4, -4, -4, -4],
[ 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
-3, -3, -3, -4, -4, -4, -4, -4],
[ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2,
-2, -2, -2, -3, -3, -3, -3, -3],
[ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2,
-2, -2, -2, -3, -3, -3, -3, -3],
[ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2,
-2, -2, -2, -3, -3, -3, -3, -3],
[ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2,
-2, -2, -2, -3, -3, -3, -3, -3],
[ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2,
-2, -2, -2, -3, -3, -3, -3, -3],
[ 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1,
-1, -1, -1, -2, -2, -2, -2, -2],
[ 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1,
-1, -1, -1, -2, -2, -2, -2, -2],
[ 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1,
-1, -1, -1, -2, -2, -2, -2, -2],
[ 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1,
-1, -1, -1, -2, -2, -2, -2, -2],
[ 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1,
-1, -1, -1, -2, -2, -2, -2, -2],
[ 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, -1, -1, -1, -1, -1],
[ 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, -1, -1, -1, -1, -1],
[ 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, -1, -1, -1, -1, -1],
[ 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, -1, -1, -1, -1, -1],
[ 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, -1, -1, -1, -1, -1],
[ 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0],
[ 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0],
[ 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0],
[ 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0],
[ 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0]],
[[ 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1,
-2, -3, -4, 0, -1, -2, -3, -4],
[ 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0,
-1, -2, -3, 1, 0, -1, -2, -3],
[ 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1,
0, -1, -2, 2, 1, 0, -1, -2],
[ 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2,
1, 0, -1, 3, 2, 1, 0, -1],
[ 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3,
2, 1, 0, 4, 3, 2, 1, 0],
[ 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1,
-2, -3, -4, 0, -1, -2, -3, -4],
[ 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0,
-1, -2, -3, 1, 0, -1, -2, -3],
[ 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1,
0, -1, -2, 2, 1, 0, -1, -2],
[ 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2,
1, 0, -1, 3, 2, 1, 0, -1],
[ 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3,
2, 1, 0, 4, 3, 2, 1, 0],
[ 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1,
-2, -3, -4, 0, -1, -2, -3, -4],
[ 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0,
-1, -2, -3, 1, 0, -1, -2, -3],
[ 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1,
0, -1, -2, 2, 1, 0, -1, -2],
[ 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2,
1, 0, -1, 3, 2, 1, 0, -1],
[ 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3,
2, 1, 0, 4, 3, 2, 1, 0],
[ 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1,
-2, -3, -4, 0, -1, -2, -3, -4],
[ 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0,
-1, -2, -3, 1, 0, -1, -2, -3],
[ 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1,
0, -1, -2, 2, 1, 0, -1, -2],
[ 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2,
1, 0, -1, 3, 2, 1, 0, -1],
[ 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3,
2, 1, 0, 4, 3, 2, 1, 0],
[ 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1,
-2, -3, -4, 0, -1, -2, -3, -4],
[ 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0,
-1, -2, -3, 1, 0, -1, -2, -3],
[ 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1,
0, -1, -2, 2, 1, 0, -1, -2],
[ 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2,
1, 0, -1, 3, 2, 1, 0, -1],
[ 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3,
2, 1, 0, 4, 3, 2, 1, 0]]])
relative_coords_bias=
tensor([[[5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
1, 1],
[5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
1, 1],
[5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
1, 1],
[5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
1, 1],
[5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
1, 1],
[6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
2, 2],
[6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
2, 2],
[6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
2, 2],
[6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
2, 2],
[6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
2, 2],
[7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
3, 3],
[7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
3, 3],
[7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
3, 3],
[7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
3, 3],
[7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
3, 3],
[8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
4, 4],
[8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
4, 4],
[8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
4, 4],
[8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
4, 4],
[8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
4, 4],
[9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
5, 5],
[9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
5, 5],
[9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
5, 5],
[9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
5, 5],
[9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
5, 5]],
[[5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
2, 1],
[6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
3, 2],
[7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
4, 3],
[8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
5, 4],
[9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
6, 5],
[5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
2, 1],
[6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
3, 2],
[7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
4, 3],
[8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
5, 4],
[9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
6, 5],
[5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
2, 1],
[6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
3, 2],
[7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
4, 3],
[8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
5, 4],
[9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
6, 5],
[5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
2, 1],
[6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
3, 2],
[7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
4, 3],
[8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
5, 4],
[9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
6, 5],
[5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
2, 1],
[6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
3, 2],
[7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
4, 3],
[8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
5, 4],
[9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
6, 5]]])
2. kron运算
在结果中,我们发现很多重复的值,这就让我联想到kron运算。
- step1:形成子矩阵
- step2: kron
- pytorch
import torch
import torch.nn as nn
torch.set_printoptions(precision=3, sci_mode=False)
if __name__ == '__main__':
run_code = 0
height = 5
width = 5
a_vector = torch.arange(width).to(torch.float).reshape(-1, 1)
a_ones = torch.ones(1, width)
a_matrix = a_vector @ a_ones
print(f"a_matrix=\n{a_matrix}")
b_matrix = a_matrix - a_matrix.T
print(f"b_matrix=\n{b_matrix}")
b_matrix_ones = torch.ones_like(b_matrix)
ab_kron = torch.kron(b_matrix,b_matrix_ones)
print(f"ab_kron=\n{ab_kron}")
final_ab = ab_kron+5
print(f"final_ab=\n{final_ab}")
- result:
a_matrix=
tensor([[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.]])
b_matrix=
tensor([[ 0., -1., -2., -3., -4.],
[ 1., 0., -1., -2., -3.],
[ 2., 1., 0., -1., -2.],
[ 3., 2., 1., 0., -1.],
[ 4., 3., 2., 1., 0.]])
ab_kron=
tensor([[ 0., 0., 0., 0., 0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
-2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
[ 0., 0., 0., 0., 0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
-2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
[ 0., 0., 0., 0., 0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
-2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
[ 0., 0., 0., 0., 0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
-2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
[ 0., 0., 0., 0., 0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
-2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
[ 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., -1., -1., -1., -1.,
-1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
[ 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., -1., -1., -1., -1.,
-1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
[ 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., -1., -1., -1., -1.,
-1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
[ 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., -1., -1., -1., -1.,
-1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
[ 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., -1., -1., -1., -1.,
-1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
[ 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 0., 0., 0., 0.,
0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
[ 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 0., 0., 0., 0.,
0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
[ 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 0., 0., 0., 0.,
0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
[ 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 0., 0., 0., 0.,
0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
[ 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 0., 0., 0., 0.,
0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
[ 3., 3., 3., 3., 3., 2., 2., 2., 2., 2., 1., 1., 1., 1.,
1., 0., 0., 0., 0., 0., -1., -1., -1., -1., -1.],
[ 3., 3., 3., 3., 3., 2., 2., 2., 2., 2., 1., 1., 1., 1.,
1., 0., 0., 0., 0., 0., -1., -1., -1., -1., -1.],
[ 3., 3., 3., 3., 3., 2., 2., 2., 2., 2., 1., 1., 1., 1.,
1., 0., 0., 0., 0., 0., -1., -1., -1., -1., -1.],
[ 3., 3., 3., 3., 3., 2., 2., 2., 2., 2., 1., 1., 1., 1.,
1., 0., 0., 0., 0., 0., -1., -1., -1., -1., -1.],
[ 3., 3., 3., 3., 3., 2., 2., 2., 2., 2., 1., 1., 1., 1.,
1., 0., 0., 0., 0., 0., -1., -1., -1., -1., -1.],
[ 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2., 2.,
2., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[ 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2., 2.,
2., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[ 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2., 2.,
2., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[ 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2., 2.,
2., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[ 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2., 2.,
2., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]])
final_ab=
tensor([[5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
2., 2., 1., 1., 1., 1., 1.],
[5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
2., 2., 1., 1., 1., 1., 1.],
[5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
2., 2., 1., 1., 1., 1., 1.],
[5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
2., 2., 1., 1., 1., 1., 1.],
[5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
2., 2., 1., 1., 1., 1., 1.],
[6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
3., 3., 2., 2., 2., 2., 2.],
[6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
3., 3., 2., 2., 2., 2., 2.],
[6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
3., 3., 2., 2., 2., 2., 2.],
[6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
3., 3., 2., 2., 2., 2., 2.],
[6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
3., 3., 2., 2., 2., 2., 2.],
[7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
4., 4., 3., 3., 3., 3., 3.],
[7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
4., 4., 3., 3., 3., 3., 3.],
[7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
4., 4., 3., 3., 3., 3., 3.],
[7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
4., 4., 3., 3., 3., 3., 3.],
[7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
4., 4., 3., 3., 3., 3., 3.],
[8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
5., 5., 4., 4., 4., 4., 4.],
[8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
5., 5., 4., 4., 4., 4., 4.],
[8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
5., 5., 4., 4., 4., 4., 4.],
[8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
5., 5., 4., 4., 4., 4., 4.],
[8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
5., 5., 4., 4., 4., 4., 4.],
[9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
6., 6., 5., 5., 5., 5., 5.],
[9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
6., 6., 5., 5., 5., 5., 5.],
[9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
6., 6., 5., 5., 5., 5., 5.],
[9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
6., 6., 5., 5., 5., 5., 5.],
[9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
6., 6., 5., 5., 5., 5., 5.]])