cs224w课程学习笔记-第6课 训练GNN
前言
在前面五节课里我们已经学习了图嵌入的方法,接下来我们结合图学习的任务来端到端的进行模型训练,验证与性能度量.
一、任务类型适配(prediction head)
1、节点任务
这个就很好实现了因为前面的信息提取输出的就是节点嵌入,因此 y p r e = W ( H ) h v ( L ) y_{pre}=W^{(H)}h^{(L)}_v ypre=W(H)hv(L),从结构上来说就是节点嵌入输出可以接一个MLP就可实现目标预测输出
2、边的任务
输出的是节点嵌入,因此我们需要先得到边的嵌入,然后才能得到边的预测,其常见方法有
- 拼接+线性:边涉及两个节点,将两个节点的嵌入拼起来再做线性变换得到预测,可用于多分类预测,其公式如下所示
- 点积:将两个节点的嵌入进行点积,其反馈了节点的相似性,适用于二分类预测,对于多分类问题(例如预测边的类型或权重),可通过引入多个可训练的权重矩阵,将点积方法可以扩展到多分类问题,类似于多头注意力机制。
3、图的任务
还记得前面的课里有提到如何通过节点嵌入得到图的嵌入吗?通过对所有节点嵌入的加和或者构建虚拟节点得到图的嵌入,因此图的预测就可以通过对所有节点的聚合操作,转化得到预测结果.但是简单的对最后输出的节点嵌入做聚合会存在缺陷,如下图,使用加和对最后的节点嵌入聚合,可以看到不同图其聚合的结果一样,因此引入了分层聚合来解决该问题.
分层聚合的核心思想是通过联合训练两个GNN(GNN A 和 GNN B)来实现节点的聚类和特征聚合。
其步骤如下:
1、生成节点嵌入(GNN A):使用GNN A对图中的每个节点进行特征提取,生成节点嵌入 H.
2、生成聚类分配(GNN B):使用GNN B生成节点的聚类分配矩阵 S,其中 S ij表示节点 i 属于聚类 j 的概率。通常使用softmax函数对GNN B的输出进行归一化,以确保每个节点只属于一个聚类。
3、节点聚合:根据聚类分配矩阵 S,对GNN A生成的节点嵌入 H 进行聚合;
H p o o l e d = S T H H_{pooled}=S^TH Hpooled=STH
其中, H p o o l e d H_{pooled} Hpooled是池化后的新节点嵌入。
4、构建池化图:根据原图的邻接矩阵 A 和聚类分配矩阵 S,生成池化图的邻接矩阵:
A p o o l e d = S T A S A_{pooled}=S^TAS Apooled=STAS
其中, A p o o l e d A_{pooled} Apooled为新节点之间的连接关系。
5、联合训练:通过反向传播联合优化GNN A和GNN B的参数,使得池化后的图能够更好地完成下游任务(如分类或回归)。
二、训练
1、标签
分为自监督(标签来自数据本身)与有监督(标签来自外部)的学习.有监督好理解,自监督会稍微麻烦一些.
- 节点预测,其自监督的标签可以是聚类系数,pagerank
- 边预测,其自监督的标签可以隐藏节点之间的边,如何预测该边是否存在
- 图预测,其自监督的标签可以是两个图是否同构
2、损失函数
跟传统机器学习差不多,回归任务常用MSE,分类任务如交叉熵
3、评价系数
跟传统机器学习差不多,回归任务常用MAE,RMSE,分类任务如ACC,准确率,召回,F1,roc曲线等
4、数据集划分
通常选择设置随机种子,然后进行随机划分得到训练集,验证集与测试集;其中比较麻烦的是对图数据划分,在图任务中,一个图就是一个样本点,其样本之间是独立的,互不影响;但是在节点任务中,一个节点是一个数据样本,如果样本来自同一张图,数据之间将不再互相独立,如下图所示.
这种情况下的应用方式有两种
- 直推式:训练时用整张图计算嵌入,用1,2节点标签计算loss;验证时用整张图计算嵌入,用3,4节点标签计算loss;测试时用整张图计算嵌入,用5,6节点标签计算loss;
- 归纳式:切断不同数据集直接的边,使得其互相独立;这样的话,其应用如下图所示
两者的差异在于前者数据集之间不互相独立,不适合用于图的任务,因为无法处理未见过的图.
下面举个例子说明两种方法(以边任务为例,边任务有技巧,更复杂)
下面有三张图分别为训练,验证与测试集,其标签分别隐藏图上部分边来进行训练验证与测试.
接下来看归纳式的划分,在一张图上得到训练集,验证集与测试集;可以看到其隐藏的边是是越来少的,其不同环节的标签数是一致的,如下图,训练集隐藏了边(5,4),(3,4),(2,3),其中训练环节的标签是(3,4),到验证环节时,因为(3,4)在训练集已经出现过了,因此不再隐藏,此时标签为(5,4);再到测试环节,(5,4)在验证环节出现了,不再隐藏,测试标签为(2,3);该种方式符合模型逐渐学习的模型.
三、总结
本课程,描述了在节点嵌入后如何适用于不同任务类型,其中可直接用于节点任务,通过拼接点积等操作可用于边预测,最复杂的是图预测需要做分层的节点聚合才能得到的区分不同图的表示;然后介绍了自监督由数据结构和信息本身得到标签,与有监督的标签来源;随后介绍了图学习里的loss与评价指标,基本与机器学习中常见的方法一致;最后讲述了数据集的划分,分为直推式划分(训练集,验证集,测试集来源于一个图),归纳式(训练集,验证集,测试集来源于不同图),并介绍了不同任务类型的例子增进两种方法的理解,其中以边任务较为特征,值得注意.