张量拼接操作

发布于:2025-07-14 ⋅ 阅读:(19) ⋅ 点赞:(0)

一.前言

本章节来介绍一下张量拼接的操作,掌握torch.cat torch.stack使⽤,张量的拼接操作在神经⽹络搭建过程中是⾮常常⽤的⽅法,例如: 在后⾯将要学习到的残差⽹络、注意⼒机 制中都使⽤到了张量拼接。

二.torch.cat 函数的使用

torch.cat 函数可以将两个张量根据指定的维度拼接起来.

import torch


def test():
    data1 = torch.randint(0, 10, [3, 5, 4])
    data2 = torch.randint(0, 10, [3, 5, 4])
    print(data1)
    print(data2)
    print('-' * 50)
    # 1. 按0维度拼接
    new_data = torch.cat([data1, data2], dim=0)
    print(new_data.shape)
    print('-' * 50)
    # 2. 按1维度拼接
    new_data = torch.cat([data1, data2], dim=1)
    print(new_data.shape)
    print('-' * 50)
    # 3. 按2维度拼接
    new_data = torch.cat([data1, data2], dim=2)
    print(new_data.shape)

if __name__ == '__main__':
    test()

结果展示:

tensor([[[6, 7, 2, 6],
         [4, 6, 4, 3],
         [5, 3, 4, 9],
         [8, 8, 6, 7],
         [0, 3, 3, 0]],

        [[6, 1, 2, 0],
         [5, 6, 7, 0],
         [6, 4, 8, 0],
         [2, 2, 8, 3],
         [0, 1, 6, 8]],

        [[3, 5, 0, 8],
         [6, 2, 1, 7],
         [8, 9, 9, 8],
         [3, 8, 8, 0],
         [5, 8, 4, 4]]])
tensor([[[7, 2, 2, 1],
         [8, 0, 6, 6],
         [9, 0, 6, 5],
         [1, 3, 7, 7],
         [7, 0, 5, 1]],

        [[0, 7, 3, 1],
         [9, 2, 9, 0],
         [9, 6, 2, 1],
         [9, 3, 5, 0],
         [8, 8, 6, 2]],

        [[1, 8, 9, 9],
         [4, 3, 0, 9],
         [7, 3, 3, 8],
         [2, 4, 6, 9],
         [2, 1, 0, 5]]])
--------------------------------------------------
torch.Size([6, 5, 4])
--------------------------------------------------
torch.Size([3, 10, 4])
--------------------------------------------------
torch.Size([3, 5, 8])

 

三.torch.stack 函数的使用

torch.stack 函数可以将两个张量根据指定的维度叠加起来.

import torch

def test():
    data1 = torch.randint(0, 10, [2, 3])
    data2 = torch.randint(0, 10, [2, 3])
    print(data1)
    print(data2)
    print("="*50)
    new_data = torch.stack([data1, data2], dim=0)
    print(new_data.shape)
    print(new_data)
    print("=" * 50)
    new_data = torch.stack([data1, data2], dim=1)
    print(new_data.shape)
    print(new_data)
    print("=" * 50)
    new_data = torch.stack([data1, data2], dim=2)
    print(new_data.shape)
    print(new_data)

if __name__ == '__main__':
    test()

 结果展示:

tensor([[6, 9, 6],
        [3, 2, 7]])
tensor([[3, 3, 4],
        [9, 1, 4]])
==================================================
torch.Size([2, 2, 3])
tensor([[[6, 9, 6],
         [3, 2, 7]],

        [[3, 3, 4],
         [9, 1, 4]]])
==================================================
torch.Size([2, 2, 3])
tensor([[[6, 9, 6],
         [3, 3, 4]],

        [[3, 2, 7],
         [9, 1, 4]]])
==================================================
torch.Size([2, 3, 2])
tensor([[[6, 3],
         [9, 3],
         [6, 4]],

        [[3, 9],
         [2, 1],
         [7, 4]]])

这里十分的不好理解,大家拷贝完代码自己执行理解一下。

四.总结 

张量的拼接操作也是在后⾯我们经常使⽤⼀种操作。cat 函数可以将张量按照指定的维度拼接起来,stack 函数可以将张量按照指定的维度叠加起来。 

 

 


网站公告

今日签到

点亮在社区的每一天
去签到