【模型学习之路】TopK池化,全局池化

发布于:2024-11-28 ⋅ 阅读:(14) ⋅ 点赞:(0)

来学学图卷积中的池化操作

目录

DataBatch

Dense Batching

Dynamic Batching

DataBatch

存取操作

TopKPooling

GAP/GMP

一个例子

后话


DataBatch

当进行图级别的任务时,首先的任务是把多个图合成一个batch。

在Transformer中,一个句子的维度是【单词数,词向量长度】。在一个batch内,batch_size个长度相同的句子(长度短了就做padding)的维度是【句子数,单词数,词向量长度】。

这里,在图任务中得到batch有两种策略。

Dense Batching

一个batch有batch_size个图,第i个图的x的特征维度为m_{i}f,那么先:

m = max(m_{1}, m_{2}, ..., m_{batchsize})

把所有的图做padding,然后合到一起,那么最后数据的维度就是【batch_size, m, f】。

这种方式通常用于需要固定大小输入的场景,例如某些图神经网络的实现或者特定的并行计算框架。

Dynamic Batching

这是PyG默认的批处理方式,它不要求所有图具有相同数量的节点。在这种模式下,每个图的节点特征被拼接在一起,形成一个大的特征矩阵【M,f】,其中:

M = \sum_{i=1}^{batchsize}m_{i}

同时,会有一个batch向量,它是一个长度为M的一维Tensor,记录每个节点属于哪个图。

DataBatch

前面提到过,Data对象是PyG数据的基本单元。我们先生成一个一个Data对象的list:

import torch
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader

data_lst = [Data(x=torch.randint(0, 2, (5, 3)), 
                 edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),
                 y=torch.randint(0, 1, (5,)))
                for _ in range(1000)]

重写Dataset,然后将list[Data]转化为Dataset:

class MyDataset(Dataset):
    def __init__(self, data_lst):
        super(MyDataset, self).__init__()
        self.data_lst = data_lst
    
    def __len__(self):
        return len(self.data_lst)
    
    def __getitem__(self, idx):
        return self.data_lst[idx]

dataset = MyDataset(data_lst)
dataset

# output
MyDataset(1000)

进一步做成Dataloader:

dataloader = DataLoader(dataset, batch_size=32, follow_batch=['x'], shuffle=True)
first_batch = list(dataloader)[0]
first_batch

# output
DataBatch(x=[160, 3], x_batch=[160], x_ptr=[33], edge_index=[2, 128], y=[160], batch=[160], ptr=[33])

x,y,edge_index都是由多个图拼接而成。x_batch就是用来记录每个节点属于哪个图。ptr用于记录每个图的位置信息(不用过多关注),大小正好是batch_size + 1,记录每个图的终点和起点。

不指定follow_batch=['x'],就没有了ptr,模型就会认为这是一个由很多图拼起来的一个大图,而不是视为很多图。这里不必深究,指定一下follow_batch就好了。

存取操作

可以继承重写PyG中一些与数据相关的类,做到存取的效果,不过有些难度可以看看这个:【图神经网络工具】PyTorch Geometric Tutorial 之Data Handling - 知乎

也可以看看这个的15~19集:5-数据集创建函数介绍_哔哩哔哩_bilibili

我们实现一个简单的存取方法:

from torch_geometric.data import Batch
batch = Batch.from_data_list(data_lst)
batch

# output
DataBatch(x=[5000, 3], edge_index=[2, 4000], y=[5000], batch=[5000], ptr=[1001])

可以看到,和我们Dataloader取出来的东西一样,都是DataBatch对象。然后我们把它存起来:

torch.save(batch, 'batch.pt')

loaded_batch = torch.load('batch.pt', map_location='cpu', weights_only=False)
data_lst = loaded_batch.to_data_list()

TopKPooling

先端上官方文档:

torch_geometric.nn.pool.TopKPooling — pytorch_geometric documentation

再端上一张网上随便一找就能看到的图:

p是要学习的参数。y的维度是(M, 1),计算出每一个点的“重要性”。除以二范数是为了标准化。

然后选取M个点中k个最重要的

根据这个topk,在X以及A中挑出对应的k个,得到,相应的邻接矩阵也只保留剩下的边之间的关系。

最后,由于y’本身记录了“重要性”的信息,那就把重要性加权到X中:

  

仅发表一下个人意见,出于归一化的想法,感觉用softmax挺合适:

 

好,搞定。

一个小问题,在做这个pool操作时,会不会导致某一个图的所有节点全部消失?

并不会,因为TopK是独立地在每个图中做topk操作。

GAP/GMP

global_mean_pool(GAP)和global_max_pool(GMP)是两种常用的全局池化(global pooling)操作,它们用于将整个图的信息聚合为一个固定大小的向量。

全局平均池化(GAP)操作将图中所有节点的特征向量求平均。简单说来就是,每一个图表示为自己所有节点求平均得到的向量。

全局最大池化(GMP)操作将图中所有节点的特征向量进行逐元素的最大值操作。简单来说就是,对于每一个图,拿出自己所有的节点,拿到每个特征的最大值,组成一个向量。

So,在维度上,都会有这样的特征:【M, f】-> 【batch_size,f】

这俩是两种常用的全局池化操作,它们用于将图中所有节点的特征聚合为一个全局特征向量。这两种操作通常在图神经网络的最后阶段使用,以便将图级别的表示用于图分类或其他下游任务。

一个例子

用PyG写个一个神经网络模型。

import torch
import torch.nn as nn
from torch_geometric.nn import TopKPooling, SAGEConv
from torch_geometric.nn import global_mean_pool as gap
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        torch.manual_seed(114514)
        
        self.conv1 = SAGEConv(128, 128)
        self.pool1 = TopKPooling(128, ratio=0.8)
        self.conv2 = SAGEConv(128, 128)
        self.pool2 = TopKPooling(128, ratio=0.8)
        self.conv3 = SAGEConv(128, 128)
        self.pool3 = TopKPooling(128, ratio=0.8)
        
        self.embed = nn.Embedding(100, 128)
        
        self.lin = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1), 
        )
        
        self.bn = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(64)
        
    def forward(self, data):
        """
        x: [M, 1]
        edge_index: [2, e]
        batch: [M]
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        x = x.squeeze(1)  # [M, 1] -> [M]  # 这里是大坑!在github评论区逛了一圈,还好一个老外和我一样的错误
        x = self.embed(x)  # [M] -> [M, 128]  
        
        x = self.conv1(x, edge_index)  # [M, 128]
        x = F.relu(x)
        x, edge_index, _, batch, *_ = self.pool1(x, edge_index, None, batch)  # [0.8*M, 128]
        
        x1 = gap(x, batch)  # [batch, 128]

        x = self.conv2(x, edge_index)  # [0.8*M, 128]
        x = F.relu(x)
        x, edge_index, _, batch, *_ = self.pool2(x, edge_index, None, batch)  # [0.8*0.8*M, 128]
        
        x2 = gap(x, batch)  # [batch, 128]
        
        x = self.conv3(x, edge_index)  # [0.8*0.8*M, 128]
        x = F.relu(x)
        x, edge_index, _, batch, *_ = self.pool3(x, edge_index, None, batch)  # [0.8*0.8*0.8*M, 128]
        
        x3 = gap(x, batch)  # [batch, 128]
        
        out = x1 + x2 + x3  # [batch, 128]
        out = self.lin(out)  # [batch, 1]
        out = out.squeeze(1)  # [batch]
        out = F.sigmoid(out)
        return out
        
        

这个网络架构的设计意图是利用图卷积层提取局部图结构特征,通过池化层进行降采样以捕捉更全局的信息,然后通过全连接层和激活函数进行特征融合和分类。这种架构在图分类、节点分类等任务中很常见。

后话

代码中的SAGEConv是什么?它是众多卷积方式的一种。

PyG文档上有大量卷积层、池化层的类。确实,路漫漫其修远兮!

这个文章上有很多的卷积层和池化层的讲解,看看能不能在未来的时间里都弄懂它们的原理:转载 | 一文遍览GNN卷积与池化的代表模型 - 知乎