本公开涉及计算机,特别涉及一种图神经网络模型的训练方法、装置、电子设备及存储介质。
背景技术:
1、随着计算机技术的发展,图网络应运而生。图网络包含多个节点和多个边。通常采用图神经网络模型来挖掘图网络中节点的信息,以便后续进行节点分类、节点聚类、链路预测等下游任务。下游任务的效果取决于图神经网络模型所挖掘的信息。因此,如何对图神经网络模型进行训练,以便能够准确获取节点的信息成为本领域的研究重点。
2、相关技术中,通常是先基于领域知识对原视图进行扩充,得到扩充图;然后,通过对比原视图和扩充图之间的差异,来训练图神经网络模型;然后,通过训练后的图神经网络模型来挖掘原视图中节点的信息。
3、但是,上述方式所得到的节点的信息局限于上述领域知识。由于每个节点所表示的对象无论是用户、视频,还是其他事物,都不太可能局限于一个领域,而是一般包括多个维度的信息。因此,上述方式所挖掘的信息并不准确。
技术实现思路
1、本公开提供一种图神经网络模型的训练方法、装置、电子设备及存储介质,实现了通过图生成对抗学习来提升图神经网络模型的图对比学习性能的目的,使得图神经网络模型能够更准确地学习到真视图中的节点特征。本公开的技术方案如下:
2、根据本公开实施例的一方面,提供一种图神经网络模型的训练方法,包括:
3、获取真视图,所述真视图包括多个节点和多个第一边,所述真视图中的节点用于表示视频,所述第一边用于表示所述第一边连接的节点所代表的视频之间的关系;
4、通过图神经网络模型中的生成器,基于所述真视图中的节点,生成伪视图,所述伪视图包括所述多个节点和多个第二边,所述伪视图中的节点用于表示视频,所述第二边用于表示所述第二边连接的节点所代表的视频之间的关系;
5、通过所述图神经网络模型中的判别器,对所述伪视图进行判别,得到判别结果,所述判别结果用于表示将所述伪视图判别为真视图的概率;
6、基于所述判别结果,确定所述生成器和所述判别器之间的对抗损失,所述对抗损失用于反映生成的所述伪视图的质量;
7、通过所述图神经网络模型中的编码器,对所述真视图和所述伪视图进行对比学习,得到对比损失,所述对比损失用于表示真视图与伪视图之间的节点特征差异;
8、基于所述对抗损失和所述对比损失,对所述图神经网络模型进行训练,训练后的所述图神经网络模型用于视频分类。
9、在一些实施例中,所述通过图神经网络模型中的生成器,基于所述真视图中的节点,生成伪视图,包括:
10、基于至少一种分布策略,确定所述图神经网络模型中所述生成器的参数矩阵,所述参数矩阵中的任一数值用于表示对应的两个节点之间存在边的概率;
11、基于所述参数矩阵和所述真视图中的节点,生成所述伪视图。
12、在一些实施例中,所述方法还包括:
13、基于所述生成器的所述参数矩阵和所述真视图中的边,确定所述生成器的边数损失,所述参数矩阵中的任一数值用于表示生成的所述伪视图中对应的两个节点之间存在边的概率,所述边数损失用于表示所述生成器生成的边数与所述真视图中原有的边数之间的差异;
14、基于所述边数损失,调整所述参数矩阵。
15、在一些实施例中,所述基于所述边数损失,调整所述参数矩阵,包括:
16、基于原视图的邻接矩阵和所述参数矩阵,确定新边损失,所述原视图用于提供属于真视图的训练样本,作为所述训练样本的真视图中的边是所述原视图中边的子集,所述新边损失用于表示所述生成器生成的边与所述原视图中原有的边之间的差异;
17、基于所述边数损失和所述新边损失中的至少一种,调整所述参数矩阵。
18、在一些实施例中,所述基于所述对抗损失和所述对比损失,对所述图神经网络模型进行训练,包括:
19、基于所述边数损失、所述新边损失以及所述对抗损失,调整所述生成器的所述参数矩阵;
20、基于所述对比损失,调整所述编码器的编码参数。
21、在一些实施例中,所述方法还包括:
22、在两个节点属于原视图中任一边的情况下,基于新边比率、所述真视图的边数以及所述原视图的边数,确定所述参数矩阵中所述两个节点对应的数值,所述原视图用于提供属于真视图的训练样本,作为所述训练样本的真视图中的边是所述原视图中边的子集,所述新边比率用于限制所述生成器生成新边的数量,所述新边比率指的是所述伪视图中新边的数量与所述伪视图的总边数之间的比值,所述新边为所述原视图中不存在的边;
23、在两个节点属于候选新边集合中的任一边的情况下,基于所述新边比率、所述真视图的边数以及所述候选新边集合中的边数,确定所述参数矩阵中所述两个节点对应的数值,所述候选新边集合为包含所有新边的新边集合的子集;
24、在两个节点不属于所述原视图,也不属于所述候选新边集合的情况下,确定所述参数矩阵中所述两个节点对应的数值为零。
25、在一些实施例中,所述候选新边集合的确定过程,包括:
26、基于所述原视图中的所述多个节点,确定第一集合,所述第一集合包括所述多个节点之间可以存在的所有边;
27、从所述第一集合中过滤掉所述原视图中的边,得到第二集合,所述第二集合中的边针对于所述原视图的边而言是新边;
28、从所述第二集合中,选择度数满足条件的节点之间的边,构成所述候选新边集合。
29、在一些实施例中,所述通过所述图神经网络模型中的判别器,对所述伪视图进行判别,得到判别结果,包括:
30、通过所述图神经网络模型中的所述编码器,对所述伪视图中的多个节点进行特征提取,得到所述伪视图的节点特征;
31、对所述伪视图的节点特征进行平均池化,得到第一特征;
32、对所述伪视图的节点特征进行最大池化,得到第二特征;
33、基于所述第一特征和所述第二特征,确定所述伪视图的图特征;
34、通过所述图神经网络模型中的所述判别器对所述伪视图的图特征进行判别,得到所述判别结果。
35、在一些实施例中,所述方法还包括:
36、为所述伪视图添加标签,所述标签用于表示所述伪视图为所述生成器生成的伪视图;
37、基于所述判别结果和所述伪视图的标签,确定所述判别器的判别损失,所述判别损失用于表示所述判别器判别错误的概率;
38、基于所述判别损失,调整所述判别器中的判别参数。
39、在一些实施例中,所述通过所述图神经网络模型中的编码器,对所述真视图和所述伪视图进行对比学习,得到对比损失,包括:
40、通过所述图神经网络模型中的编码器,对所述真视图中的多个节点进行特征提取,得到所述真视图的节点特征;
41、通过所述编码器,对所述伪视图中的多个节点进行特征提取,得到所述伪视图的节点特征;
42、基于所述真视图的节点特征和所述伪视图的节点特征,确定所述对比损失,所述对比损失用于反映相同节点在不同视图之间的互信息。
43、在一些实施例中,所述通过所述图神经网络模型中的编码器,对所述真视图中的多个节点进行特征提取,得到所述真视图的节点特征,包括:
44、对于所述真视图中的任一节点,通过所述图神经网络模型中的编码器,对所述节点的初始特征和至少一个第一节点的初始特征进行聚合,得到所述节点的节点特征,所述至少一个第一节点是与所述节点具有连接关系的节点,所述节点的节点特征用于表示所述节点所表示的视频的特征。
45、在一些实施例中,所述基于所述对抗损失和所述对比损失,对所述图神经网络模型进行训练,包括:
46、基于所述对抗损失,调整所述生成器的参数矩阵,所述参数矩阵为所述生成器中的生成参数;
47、对于所述真视图中的任一节点,获取所述节点的第一节点和第二节点,所述第一节点与所述节点之间具有连接关系,所述第二节点与所述节点之间不具有连接关系;
48、基于所述节点的节点特征与所述第一节点的节点特征,确定第一特征距离;
49、基于所述节点的节点特征与所述第二节点的节点特征,确定第二特征距离;
50、基于所述第一特征距离和所述第二特征距离,确定贝叶斯个性化排名损失;
51、基于所述对比损失和所述贝叶斯个性化排名损失,调整所述编码器的编码参数。
52、根据本公开实施例的另一方面,提供一种图神经网络模型的训练装置,包括:
53、第一获取单元,被配置为执行获取真视图,所述真视图包括多个节点和多个第一边,所述真视图中的节点用于表示视频,所述第一边用于表示所述第一边连接的节点所代表的视频之间的关系;
54、生成单元,被配置为执行通过图神经网络模型中的生成器,基于所述真视图中的节点,生成伪视图,所述伪视图包括所述多个节点和多个第二边,所述伪视图中的节点用于表示视频,所述第二边用于表示所述第二边连接的节点所代表的视频之间的关系;
55、判别单元,被配置为执行通过所述图神经网络模型中的判别器,对所述伪视图进行判别,得到判别结果,所述判别结果用于表示将所述伪视图判别为真视图的概率;
56、第一确定单元,被配置为执行基于所述判别结果,确定所述生成器和所述判别器之间的对抗损失,所述对抗损失用于反映生成的所述伪视图的质量;
57、对比单元,被配置为执行通过所述图神经网络模型中的编码器,对所述真视图和所述伪视图进行对比学习,得到对比损失,所述对比损失用于表示真视图与伪视图之间的节点差异;
58、训练单元,被配置为执行基于所述对抗损失和所述对比损失,对所述图神经网络模型进行训练,训练后的所述图神经网络模型用于视频分类。
59、在一些实施例中,所述生成单元,被配置为执行基于至少一种分布策略,确定所述图神经网络模型中所述生成器的参数矩阵,所述参数矩阵中的任一数值用于表示对应的两个节点之间存在边的概率;基于所述参数矩阵和所述真视图中的节点,生成所述伪视图。
60、在一些实施例中,所述装置还包括:
61、第二确定单元,被配置为执行基于所述生成器的所述参数矩阵和所述真视图中的边,确定所述生成器的边数损失,所述参数矩阵中的任一数值用于表示生成的所述伪视图中对应的两个节点之间存在边的概率,所述边数损失用于表示所述生成器生成的边数与所述真视图中原有的边数之间的差异;
62、第一调整单元,被配置为执行基于所述边数损失,调整所述参数矩阵。
63、在一些实施例中,所述第一调整单元,被配置为执行基于原视图的邻接矩阵和所述参数矩阵,确定新边损失,所述原视图用于提供属于真视图的训练样本,作为所述训练样本的真视图中的边是所述原视图中边的子集,所述新边损失用于表示所述生成器生成的边与所述原视图中原有的边之间的差异;基于所述边数损失和所述新边损失中的至少一种,调整所述参数矩阵。
64、在一些实施例中,所述第一调整单元,被配置为执行基于所述边数损失、所述新边损失以及所述对抗损失,调整所述生成器的所述参数矩阵;基于所述对比损失,调整所述编码器的编码参数。
65、在一些实施例中,所述装置还包括:
66、第三确定单元,被配置为执行在两个节点属于原视图中任一边的情况下,基于新边比率、所述真视图的边数以及所述原视图的边数,确定所述参数矩阵中所述两个节点对应的数值,所述原视图用于提供属于真视图的训练样本,作为所述训练样本的真视图中边是所述原视图中边的子集,所述新边比率用于限制所述生成器生成新边的数量,所述新边比率指的是所述伪视图中新边的数量与所述伪视图的总边数之间的比值,所述新边为所述原视图中不存在的边;在两个节点属于候选新边集合中的任一边的情况下,基于所述新边比率、所述真视图的边数以及所述候选新边集合中的边数,确定所述参数矩阵中所述两个节点对应的数值,所述候选新边集合为包含所有新边的新边集合的子集;在两个节点不属于所述原视图,也不属于所述候选新边集合的情况下,确定所述参数矩阵中所述两个节点对应的数值为零。
67、在一些实施例中,所述装置还包括:第二获取单元,被配置为执行基于所述原视图中的所述多个节点,确定第一集合,所述第一集合包括所述多个节点之间可以存在的所有边;从所述第一集合中过滤掉所述原视图中的边,得到第二集合,所述第二集合中的边针对于所述原视图的边而言是新边;从所述第二集合中,选择度数满足条件的节点之间的边,构成所述候选新边集合。
68、在一些实施例中,所述判别单元,被配置为执行通过所述图神经网络模型中的所述编码器,对所述伪视图中的多个节点进行特征提取,得到所述伪视图的节点特征;对所述伪视图的节点特征进行平均池化,得到第一特征;对所述伪视图的节点特征进行最大池化,得到第二特征;基于所述第一特征和所述第二特征,确定所述伪视图的图特征;通过所述图神经网络模型中的所述判别器对所述伪视图的图特征进行判别,得到所述判别结果。
69、在一些实施例中,所述装置还包括:第二调整单元,被配置为执行为所述伪视图添加标签,所述标签用于表示所述伪视图为所述生成器生成的伪视图;基于所述判别结果和所述伪视图标签,确定所述判别器的判别损失,所述判别损失用于表示所述判别器判别错误的概率;基于所述判别损失,调整所述判别器中的判别参数。
70、在一些实施例中,所述对比单元,包括:
71、特征提取子单元,被配置为执行通过所述图神经网络模型中的编码器,对所述真视图中的多个节点进行特征提取,得到所述真视图的节点特征;通过所述编码器,对所述伪视图中的多个节点进行特征提取,得到所述伪视图的节点特征;
72、确定子单元,被配置为执行基于所述真视图的节点特征和所述伪视图的节点特征,确定所述对比损失,所述对比损失用于反映相同节点在不同视图之间的互信息。
73、在一些实施例中,所述特征提取子单元,被配置为执行对于所述真视图中的任一节点,通过所述图神经网络模型中的编码器,对所述节点的初始特征和至少一个第一节点的初始特征进行聚合,得到所述节点的节点特征,所述节点的节点特征用于表示所述节点所表示的视频的特征。
74、在一些实施例中,所述训练单元,被配置为执行基于所述对抗损失,调整所述生成器的参数矩阵,所述参数矩阵为所述生成器中的生成参数;对于所述真视图中的任一节点,获取所述节点的第一节点和第二节点,所述第一节点与所述节点之间具有连接关系,所述第二节点与所述节点之间不具有连接关系;基于所述节点的节点特征与所述第一节点的节点特征,确定第一特征距离;基于所述节点的节点特征与所述第二节点的节点特征,确定第二特征距离;基于所述第一特征距离和所述第二特征距离,确定贝叶斯个性化排名损失;基于所述对比损失和所述贝叶斯个性化排名损失,调整所述编码器的编码参数。
75、根据本公开实施例的另一方面,提供一种电子设备,该电子设备包括:
76、一个或多个处理器;
77、用于存储该处理器可执行程序代码的存储器;
78、其中,该处理器被配置为执行该程序代码,以实现上述图神经网络模型的训练的方法。
79、根据本公开实施例的另一方面,提供一种计算机可读存储介质,当该计算机可读存储介质中的程序代码由电子设备的处理器执行时,使得电子设备能够执行上述图神经网络模型的训练方法。
80、根据本公开实施例的另一方面,提供了一种计算机程序产品,包括计算机程序/指令,该计算机程序/指令被处理器执行时实现上述图神经网络模型的训练方法。
81、本公开实施例提供了一种图神经网络模型的训练方法,通过图神经网络模型中的生成器,基于真视图生成伪视图,然后通过图神经网络模型中的判别器对伪视图进行判别,来确定该伪视图的真伪,从而计算生成器和判别器之间的对抗损失,使得能够确定生成器生成视图的质量,后续通过对抗损失来对图神经网络模型进行训练,使得生成器生成的伪视图越来越接近于真视图,判别器越来越难以区分视图的真伪,也即是,生成器生成的视图的质量越来越高,然后通过图神经网络模型中的编码器,对高质量的伪视图与真视图进行对比学习,来得到对比损失,通过对比损失对图神经网络模型进行训练,实现了通过图生成对抗学习来提升图神经网络模型的图对比学习性能的目的,使得图神经网络模型能够更准确地学习到真视图中的节点特征,利于后续基于学习到的更准确的节点特征开展节点分类(例如视频分类)、链接预测等下游任务,提高了下游任务的准确性。
82、应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本公开。