来学学图卷积中的池化操作
目录
DataBatch
当进行图级别的任务时,首先的任务是把多个图合成一个batch。
在Transformer中,一个句子的维度是【单词数,词向量长度】。在一个batch内,batch_size个长度相同的句子(长度短了就做padding)的维度是【句子数,单词数,词向量长度】。
这里,在图任务中得到batch有两种策略。
Dense Batching
一个batch有batch_size个图,第i个图的x的特征维度为【,f】,那么先:
把所有的图做padding,然后合到一起,那么最后数据的维度就是【batch_size, m, f】。
这种方式通常用于需要固定大小输入的场景,例如某些图神经网络的实现或者特定的并行计算框架。
Dynamic Batching
这是PyG默认的批处理方式,它不要求所有图具有相同数量的节点。在这种模式下,每个图的节点特征被拼接在一起,形成一个大的特征矩阵【M,f】,其中:
同时,会有一个batch向量,它是一个长度为M的一维Tensor,记录每个节点属于哪个图。
DataBatch
前面提到过,Data对象是PyG数据的基本单元。我们先生成一个一个Data对象的list:
import torch
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
data_lst = [Data(x=torch.randint(0, 2, (5, 3)),
edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),
y=torch.randint(0, 1, (5,)))
for _ in range(1000)]
重写Dataset,然后将list[Data]转化为Dataset:
class MyDataset(Dataset):
def __init__(self, data_lst):
super(MyDataset, self).__init__()
self.data_lst = data_lst
def __len__(self):
return len(self.data_lst)
def __getitem__(self, idx):
return self.data_lst[idx]
dataset = MyDataset(data_lst)
dataset
# output
MyDataset(1000)
进一步做成Dataloader:
dataloader = DataLoader(dataset, batch_size=32, follow_batch=['x'], shuffle=True)
first_batch = list(dataloader)[0]
first_batch
# output
DataBatch(x=[160, 3], x_batch=[160], x_ptr=[33], edge_index=[2, 128], y=[160], batch=[160], ptr=[33])
x,y,edge_index都是由多个图拼接而成。x_batch就是用来记录每个节点属于哪个图。ptr用于记录每个图的位置信息(不用过多关注),大小正好是batch_size + 1,记录每个图的终点和起点。
不指定follow_batch=['x'],就没有了ptr,模型就会认为这是一个由很多图拼起来的一个大图,而不是视为很多图。这里不必深究,指定一下follow_batch就好了。
存取操作
可以继承重写PyG中一些与数据相关的类,做到存取的效果,不过有些难度可以看看这个:【图神经网络工具】PyTorch Geometric Tutorial 之Data Handling - 知乎
也可以看看这个的15~19集:5-数据集创建函数介绍_哔哩哔哩_bilibili
我们实现一个简单的存取方法:
from torch_geometric.data import Batch
batch = Batch.from_data_list(data_lst)
batch
# output
DataBatch(x=[5000, 3], edge_index=[2, 4000], y=[5000], batch=[5000], ptr=[1001])
可以看到,和我们Dataloader取出来的东西一样,都是DataBatch对象。然后我们把它存起来:
torch.save(batch, 'batch.pt')
loaded_batch = torch.load('batch.pt', map_location='cpu', weights_only=False)
data_lst = loaded_batch.to_data_list()
TopKPooling
先端上官方文档:
torch_geometric.nn.pool.TopKPooling — pytorch_geometric documentation
再端上一张网上随便一找就能看到的图:
p是要学习的参数。y的维度是(M, 1),计算出每一个点的“重要性”。除以二范数是为了标准化。
然后选取M个点中k个最重要的
根据这个topk,在X以及A中挑出对应的k个,得到,相应的邻接矩阵也只保留剩下的边之间的关系。
最后,由于y’本身记录了“重要性”的信息,那就把重要性加权到X中:
仅发表一下个人意见,出于归一化的想法,感觉用softmax挺合适:
好,搞定。
一个小问题,在做这个pool操作时,会不会导致某一个图的所有节点全部消失?
并不会,因为TopK是独立地在每个图中做topk操作。
GAP/GMP
global_mean_pool
(GAP)和global_max_pool
(GMP)是两种常用的全局池化(global pooling)操作,它们用于将整个图的信息聚合为一个固定大小的向量。
全局平均池化(GAP)操作将图中所有节点的特征向量求平均。简单说来就是,每一个图表示为自己所有节点求平均得到的向量。
全局最大池化(GMP)操作将图中所有节点的特征向量进行逐元素的最大值操作。简单来说就是,对于每一个图,拿出自己所有的节点,拿到每个特征的最大值,组成一个向量。
So,在维度上,都会有这样的特征:【M, f】-> 【batch_size,f】
这俩是两种常用的全局池化操作,它们用于将图中所有节点的特征聚合为一个全局特征向量。这两种操作通常在图神经网络的最后阶段使用,以便将图级别的表示用于图分类或其他下游任务。
一个例子
用PyG写个一个神经网络模型。
import torch
import torch.nn as nn
from torch_geometric.nn import TopKPooling, SAGEConv
from torch_geometric.nn import global_mean_pool as gap
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
torch.manual_seed(114514)
self.conv1 = SAGEConv(128, 128)
self.pool1 = TopKPooling(128, ratio=0.8)
self.conv2 = SAGEConv(128, 128)
self.pool2 = TopKPooling(128, ratio=0.8)
self.conv3 = SAGEConv(128, 128)
self.pool3 = TopKPooling(128, ratio=0.8)
self.embed = nn.Embedding(100, 128)
self.lin = nn.Sequential(
nn.Linear(128, 128),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(128, 64),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(64, 1),
)
self.bn = nn.BatchNorm1d(128)
self.bn2 = nn.BatchNorm1d(64)
def forward(self, data):
"""
x: [M, 1]
edge_index: [2, e]
batch: [M]
"""
x, edge_index, batch = data.x, data.edge_index, data.batch
x = x.squeeze(1) # [M, 1] -> [M] # 这里是大坑!在github评论区逛了一圈,还好一个老外和我一样的错误
x = self.embed(x) # [M] -> [M, 128]
x = self.conv1(x, edge_index) # [M, 128]
x = F.relu(x)
x, edge_index, _, batch, *_ = self.pool1(x, edge_index, None, batch) # [0.8*M, 128]
x1 = gap(x, batch) # [batch, 128]
x = self.conv2(x, edge_index) # [0.8*M, 128]
x = F.relu(x)
x, edge_index, _, batch, *_ = self.pool2(x, edge_index, None, batch) # [0.8*0.8*M, 128]
x2 = gap(x, batch) # [batch, 128]
x = self.conv3(x, edge_index) # [0.8*0.8*M, 128]
x = F.relu(x)
x, edge_index, _, batch, *_ = self.pool3(x, edge_index, None, batch) # [0.8*0.8*0.8*M, 128]
x3 = gap(x, batch) # [batch, 128]
out = x1 + x2 + x3 # [batch, 128]
out = self.lin(out) # [batch, 1]
out = out.squeeze(1) # [batch]
out = F.sigmoid(out)
return out
这个网络架构的设计意图是利用图卷积层提取局部图结构特征,通过池化层进行降采样以捕捉更全局的信息,然后通过全连接层和激活函数进行特征融合和分类。这种架构在图分类、节点分类等任务中很常见。
后话
代码中的SAGEConv是什么?它是众多卷积方式的一种。
PyG文档上有大量卷积层、池化层的类。确实,路漫漫其修远兮!
这个文章上有很多的卷积层和池化层的讲解,看看能不能在未来的时间里都弄懂它们的原理:转载 | 一文遍览GNN卷积与池化的代表模型 - 知乎