一种基于节点选择的异构图迁移学习方法

文档序号:36357389发布日期:2023-12-14 03:54阅读:32来源:国知局
一种基于节点选择的异构图迁移学习方法

本发明涉及迁移学习,具体涉及一种基于节点选择的异构图迁移学习方法。


背景技术:

1、图(网络)作为一种描述现实世界中不同对象间相互作用关系的数据,往往包含丰富的语义和结构信息,利用图神经网络可以对图结构数据进行更充分的信息挖掘。然而,训练图神经网络通常需要大量标记的图数据,收集大规模且高质量的图数据是一件既昂贵又具有挑战性的事情。而图迁移学习可以从含有丰富信息的源网络中学习可迁移的知识,用于提高图神经网络在信息匮乏的目标网络上的性能。但是,源网络中可能存在大量低质量且标注不正确的节点,这会导致源网络和目标网络之间存在较大差异,造成在目标网络上的负迁移。考虑到输入的源网络节点对目标任务的重要程度存在差异,通过使用数据评估的方法准确评估节点的价值,可以选择出高质量的节点,进而提高模型在目标网络上的性能。

2、在实际应用中,许多数据集可能包含低质量或不正确的样本,这些样本的存在可能是由于测量收集过程中的硬件问题或者人为标注过程中的错误造成的。如果源网络中存在低质量或标注不正确的节点,则可能会导致负迁移现象,从而对目标网络的性能产生负面影响。换句话说,由于存在低质量和标注不正确的节点,对源网络的学习可能会带来无用甚至不良的先验知识并损害模型在目标网络上的性能提升。因此,需要采取措施来避免或减轻这种负面影响。另外在某些任务场景下,因为不同的数据集属于不同的领域,所以在训练过程中存在分布不匹配的问题,会造成最终训练好的模型无法在目标网络上取得理想的性能表现。

3、因此在进行图神经网络的迁移学习时,如何在源网络中筛选出与目标网络最相关且具有高质量的节点是一个重要的挑战,为了解决这个问题,需要探索有效的方法来指导源网络中节点的选择,以缩小源网络和目标网络之间的差距,并确保图迁移学习的效果。针对图上的迁移学习,如何对源网络中的节点进行数值量化评估并选择与目标网络更相关且具有高质量的节点,来减小源网络和目标网络之间的差异,指导图迁移学习模型从源网络中学习相关知识,将其迁移到目标网络中,进而提高模型在目标网络上的性能,是亟待解决的问题。


技术实现思路

1、为此,本发明提出一种基于节点选择的异构图迁移学习方法。所述方法利用特定语义的特征提取器聚合基于元路径的邻居信息,用特定语义的分类器对不同语义的特征表示进行分类,同时使用最大均值差异距离和l2正则化来对齐源网络和目标网络的分布,将得到的选择向量加入到各损失函数中,来学习具有标签可分辨性和跨网络一致性的节点嵌入表示,用于对目标网络中的节点进行标签分类预测,所述方法包括三部分,即特征提取、节点标签分类和分层域对齐;具体包括以下步骤:

2、在节点估值器迭代训练中,按照以下步骤进行迭代训练:

3、步骤一、从整个源网络中随机取出数据gb=(vb,eb,ab,xb,yb),其中vb表示节点,eb表示边,ab和xb分别表示网络中节点的邻接矩阵和属性特征矩阵,yb表示网络中节点的标签;

4、步骤二、对每个样本,依据xb或者ab或者[xb,ab]计算每个节点的选择概率wi,则采样器得到基于选择概率wi的选择向量si;

5、步骤三、在分类器迭代训练中,

6、步骤三一、从源网络中随机取出数据对源网络和目标网络中的每个样本,分别学习节点语义级的特征嵌入表示、计算节点的标签分类分数,并聚合不同元路径下的标签分类分数;

7、步骤三二、计算损失函数和

8、步骤三三、更新分类器模型网络的参数θ;

9、步骤四、训练节点估值器模型,更新节点估值器模型网络的参数φ,参数φ表示异构图上的元路径集合;

10、步骤五、更新损失移动平均值δ。

11、进一步地,步骤二中利用下述公式计算每个节点的选择概率wi:

12、

13、式中,hφ表示节点估值函数;分别表示节点和对应的标签。

14、进一步地,步骤三一中利用下述公式学习节点语义级的特征嵌入表示:

15、

16、其中,表示节点vi在元路径φ下的节点嵌入表示;表示针对元路径φ的语义特征提取器;xi表示节点vi的属性特征;表示在元路径φ下与节点vi相连的邻居节点集合。

17、进一步地,步骤三一中利用下述公式计算节点的标签分类分数:

18、

19、其中,表示标签类别分数;clfφ表示分类器。

20、进一步地,步骤三一中按照下述公式聚合不同元路径下的标签分类分数:

21、

22、其中,pi表示聚合后的标签分类分数;attsem表示语义融合器;表示元路径φj对应的标签分类分数;n表示元路径总数;w表示权重矩阵,b表示偏移向量,q表示注意力向量,|v|是节点的数量,tanh是所用的激活函数。

23、进一步地,步骤三二中损失函数通过最小化来学习跨网络具有标签可分辨性的节点嵌入表示;其中,

24、表示源网络在一个元路径φ下的损失函数:

25、

26、表示源网络在语义融合后的损失函数:

27、

28、表示目标网络在语义融合后的损失函数:

29、

30、其中,表示源网络中节点vi在元路径φ下的标签预测结果,表示将多个语义下的标签分类分数融合后得到的节点vi的标签预测结果,表示源网络中的节点vi的真实标签,表示目标网络中节点vj的预测标签;ns表示源网络中的节点数量,nt表示目标网络中的节点数量。

31、进一步地,步骤三二中损失函数为语义内特征对齐损失,通过最小化减小在元路径φ下源网络和目标网络间的域差异:

32、

33、其中,特征映射函数φ(·)用于将节点嵌入表示投影到再生核希尔伯特空间。

34、进一步地,步骤三二中损失函数为语义间标签对齐损失,使用l2正则化来衡量不同元路径下节点的分类概率差异,减小不同分类器对目标网络上节点的分类差异,实现源网络和目标网络的域对齐:

35、

36、其中,表示期望。

37、进一步地,步骤三三中分类器模型网络的参数θ按照以下公式更新:

38、

39、其中,α表示学习率;bc表示样本总数;表示随机采样得到的选择向量,表示对分类器的损失函数求梯度;

40、步骤四中节点估值器模型网络的参数φ按照以下公式更新:

41、

42、其中,β表示学习率;l表示从目标网络提取的节点数量,表示节点估值器损失,fθ表示目标任务分类器模型,即任何带有参数θ的可训练函数;表示从目标网络提取的节点,表示从目标网络提取的节点的标签,表示对φ求梯度,表示根据hφ(gb)选择某个选择向量si的概率。

43、进一步地,步骤五中按照以下公式更新损失移动平均值δ:

44、

45、其中,t表示损失移动平均值窗口大小。

46、本发明的有益技术效果是:

47、针对图数据中的节点选择,考虑到输入的源网络节点对目标任务的重要程度存在差异,通过进行节点选择准确量化节点的价值对于提高模型性能具有很大的潜力,因此本发明提出了一个基于强化学习的图迁移节点选择框架(reinforcement learning-basednode selection framework for graph transfer learning,rlnsf),并将其与异构图上的迁移学习结合,提出了一种基于节点选择的异构图迁移学习方法,该方法利用特定语义的特征提取器和分类器,同时使用mmd(maximum mean discrepancy)-最大均值差异距离和l2正则化来对齐源网络和目标网络的分布,将基于强化学习的图迁移节点选择框架rlnsf得到的选择向量加入到各损失函数中,使用源网络中被选择的节点与目标网络作为异构图跨网络节点分类模型的输入,来学习具有标签类别可分辨性和跨网络一致性的节点嵌入表示。根据在异构引文网络数据集和互联网电影数据集上的大量实验结果可得,本发明所提出的基于节点选择的异构图迁移学习方法相对于选取的基线方法表现更为显著。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1