基于多任务学习的图结构预测模型训练方法和相关装置与流程

文档序号:37152300发布日期:2024-02-26 17:08阅读:13来源:国知局
基于多任务学习的图结构预测模型训练方法和相关装置与流程

本申请涉及人工智能领域,更具体的说,是涉及基于多任务学习的图结构预测模型训练方法和相关装置。


背景技术:

1、随着经济的快速发展,越来越多的客户期望在金融产品中进行投资;但是目前金融产品种类繁多,客户无法快速找到适合自己的金融产品,因此,向客户推荐金融产品的技术应运而生。

2、目前向客户推荐金融产品的方法是通过训练得到的机器学习模型向客户推荐金融产品。但是目前训练得到的机器学习模型不是很准确,导致向客户推荐的金融产品也不是很准确。

3、综上,如何训练得到比较准确的机器学习模型是本领域技术人员急需解决的技术问题。


技术实现思路

1、有鉴于此,本申请提供了一种基于多任务学习的图结构预测模型训练方法和相关装置。

2、为实现上述目的,本申请提供如下技术方案:

3、根据本公开实施例的第一方面,提供一种基于多任务学习的图结构预测模型训练方法,包括:

4、获取邻接矩阵a,所述邻接矩阵a中的每一元素为两个节点的关联值,所述两个节点中一个节点为客户或产品,另一个节点为客户或产品;若所述两个节点具有关联,所述两个节点的关联值为1,若所述两个节点不具有关联,所述两个节点的关联值为0;

5、获取所述邻接矩阵a涉及的所有节点的属性信息,以得到特征向量矩阵z,所述特征向量矩阵z中每一行向量为一个所述节点的属性信息的向量表示;

6、将所述邻接矩阵a以及所述特征向量矩阵z输入至多任务学习网络模型的输入端,通过所述多任务学习网络模型的第一输出端输出所述所有节点分别属于各个预设类型的预测概率,以及通过所述多任务学习网络模型的第二输出端输出预测邻接矩阵s;

7、其中,所述多任务学习网络模型包括链接预测模块以及节点分类模块,所述节点分类模块的输出端为所述第一输出端,所述链接预测模块的输出端为第二输出端;

8、基于所述所有节点分别属于各个预设类型的预测概率与所述所有节点分别对应的标注节点类型,获得第一损失函数;

9、基于所述邻接矩阵a与所述预测邻接矩阵s,获得第二损失函数;

10、通过所述第一损失函数与所述第二损失函数训练所述多任务学习网络模型;

11、返回所述获取邻接矩阵a,直至迭代次数达到预设阈值,以得到训练完毕的所述多任务学习网络模型;

12、确定所述多任务学习网络模型的所述第一输出端和所述第二输出端与全连接层相连,以得到图结构预测模型;所述图结构预测模型的输出端为所述全连接层的输出端,所述图结构预测模型的输入端为所述多任务学习网络模型的输入端;所述图结构预测模型用于输出包含产品的属性信息和客户的属性信息作为节点的图结构,所述图结构中任意的两两节点具有边表征两两节点具有关联性,若两两节点之间无边表征两两节点不具有关联性。

13、根据本公开实施例的第二方面,提供一种基于多任务学习的图结构预测模型训练装置,包括:

14、第一获取模块,用于获取邻接矩阵a,所述邻接矩阵a中的每一元素为两个节点的关联值,所述两个节点中一个节点为客户或产品,另一个节点为客户或产品;若所述两个节点具有关联,所述两个节点的关联值为1,若所述两个节点不具有关联,所述两个节点的关联值为0;

15、第二获取模块,用于获取所述邻接矩阵a涉及的所有节点的属性信息,以得到特征向量矩阵z,所述特征向量矩阵z中每一行向量为一个所述节点的属性信息的向量表示;

16、第三获取模块,用于将所述邻接矩阵a以及所述特征向量矩阵z输入至多任务学习网络模型的输入端,通过所述多任务学习网络模型的第一输出端输出所述所有节点分别属于各个预设类型的预测概率,以及通过所述多任务学习网络模型的第二输出端输出预测邻接矩阵s;

17、其中,所述多任务学习网络模型包括链接预测模块以及节点分类模块,所述节点分类模块的输出端为所述第一输出端,所述链接预测模块的输出端为第二输出端;

18、第四获取模块,用于基于所述所有节点分别属于各个预设类型的预测概率与所述所有节点分别对应的标注节点类型,获得第一损失函数;

19、第五获取模块,用于基于所述邻接矩阵a与所述预测邻接矩阵s,获得第二损失函数;

20、训练模块,用于通过所述第一损失函数与所述第二损失函数训练所述多任务学习网络模型;

21、触发模块,用于触发所述第一获取模块,直至迭代次数达到预设阈值,以得到训练完毕的所述多任务学习网络模型;

22、第一确定模块,用于确定所述多任务学习网络模型的所述第一输出端和所述第二输出端与全连接层相连,以得到图结构预测模型;所述图结构预测模型的输出端为所述全连接层的输出端,所述图结构预测模型的输入端为所述多任务学习网络模型的输入端;所述图结构预测模型用于输出包含产品的属性信息和客户的属性信息作为节点的图结构,所述图结构中任意的两两节点具有边表征两两节点具有关联性,若两两节点之间无边表征两两节点不具有关联性。

23、根据本公开实施例的第三方面,提供一种服务器,包括:

24、处理器;

25、用于存储所述处理器可执行指令的存储器;

26、其中,所述处理器被配置为执行所述指令,以实现如第一方面所述基于多任务学习的图结构预测模型训练方法。

27、根据本公开实施例的第四方面,提供一种计算机可读存储介质,当所述计算机可读存储介质中的指令由服务器的处理器执行时,使得服务器能够执行如第一方面所述基于多任务学习的图结构预测模型训练方法。

28、经由上述的技术方案可知,本申请提供了一种基于多任务学习的图结构预测模型训练方法,获取邻接矩阵a,邻接矩阵a中每一元素表征两两节点之间的关联关系;获取特征向量矩阵z,特征向量矩阵z中每一行向量为一个节点的属性信息的向量表示。将邻接矩阵a以及特征向量矩阵z输入至多任务学习网络模型,多任务学习网络模型包括链接预测模块以及节点分类模块,链接预测模块可以预测节点和节点之间是否有边,即链接预测模块可以得到预测邻接矩阵s;节点分类模块可以预测所有节点分别属于各个预设类型的预测概率;可以基于所有节点分别属于各个预设类型的预测概率与所有节点分别对应的标注节点类型,获得第一损失函数;基于邻接矩阵a与预测邻接矩阵s,获得第二损失函数;通过第一损失函数与第二损失函数训练多任务学习网络模型,训练完毕后,确定多任务学习网络模型的第一输出端和第二输出端与全连接层相连,以得到图结构预测模型;可以通过图结构预测模型输出包含产品的属性信息和客户的属性信息作为节点的图结构,从而可以基于图结构预测模型输出的图结构为客户推荐产品。由于多任务学习网络模型包括链接预测模块以及节点分类模块,链接预测模块分析了节点和节点之间潜在的关联性,节点分类模块分析了节点的潜在的类型,所以得到的图结构预测模型输出的图结构比较准确,从而基于图结构预测模型输出的图结构为客户推荐产品比较准确。



技术特征:

1.一种基于多任务学习的图结构预测模型训练方法,其特征在于,包括:

2.根据权利要求1所述基于多任务学习的图结构预测模型训练方法,其特征在于,还包括:

3.根据权利要求1或2所述基于多任务学习的图结构预测模型训练方法,其特征在于,所述第一损失函数lnc为:

4.根据权利要求3所述基于多任务学习的图结构预测模型训练方法,其特征在于,所述第二损失函数llp为:

5.根据权利要求4所述基于多任务学习的图结构预测模型训练方法,其特征在于,所述通过所述第一损失函数与所述第二损失函数训练所述多任务学习网络模型步骤包括:

6.一种基于多任务学习的图结构预测模型训练装置,其特征在于,包括:

7.根据权利要求6所述基于多任务学习的图结构预测模型训练装置,其特征在于,还包括:

8.根据权利要求6或7所述基于多任务学习的图结构预测模型训练装置,其特征在于,所述第一损失函数lnc为:

9.一种服务器,其特征在于,包括:

10.一种计算机可读存储介质,当所述计算机可读存储介质中的指令由服务器的处理器执行时,使得服务器能够执行如权利要求1至5中任一项所述基于多任务学习的图结构预测模型训练方法。


技术总结
本申请公开了基于多任务学习的图结构预测模型训练方法和相关装置,可应用于人工智能领域。本申请可用于为客户推荐产品。获取邻接矩阵A以及特征向量矩阵Z,并输入至多任务学习网络模型,多任务学习网络模型中的链接预测模块可以得到预测邻接矩阵S,节点分类模块可以预测所有节点分别属于各个预设类型的预测概率;通过节点分类模块对应的第一损失函数以及链接预测模块对应的第二损失函数训练多任务学习网络模型,训练完毕后,将多任务学习网络模型与全连接层连接,以构建得到图结构预测模型;基于图结构预测模型输出的图结构为客户推荐产品。由于使用了多任务学习网路模型,所以得到的图结构比较准确,从而基于图结构为客户推荐产品比较准确。

技术研发人员:金沛璇
受保护的技术使用者:中国农业银行股份有限公司
技术研发日:
技术公布日:2024/2/25
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1