一种联邦学习方法和系统与流程

文档序号:36644421发布日期:2024-01-06 23:28阅读:41来源:国知局
一种联邦学习方法和系统与流程

本说明书涉及人工智能领域,尤其涉及一种联邦学习方法和系统。


背景技术:

1、图形可以表示现实世界中的数据关系,比如,银行端设备的子图可以表示不同客户之间的金融交易关系、通信运营商端设备的子图可以表示不同客户之间的通讯关系,等等。通常,每个端设备的子图数据是高度保密的,禁止与其他端设备共享。

2、不同端设备的用户之间可能有通信往来,或者同一个用户可能在不同的端设备都有数据,因此不同端设备之间通常相互联系,不同端设备的子图之间相互耦合,共同形成了一个耦合图。而由于端设备的子图数据禁止与其他端设备共享,导致端设备在训练本地模型时无法从其他端设备获取到与自身相关的数据,导致数据缺失,本地模型的精度较低。因此,需要一种能提升模型精度的联邦学习方法。

3、背景技术部分的内容仅仅是发明人个人所知晓的信息,并不代表上述信息在本公开申请日之前已经进入公共领域,也不代表其可以成为本公开的现有技术。


技术实现思路

1、本说明书提供的联邦学习方法和系统,可以提升本地模型的精度。

2、第一方面,本说明书提供一种联邦学习方法,应用于目标端设备,所述方法包括:

3、确定目标子图,所述目标子图包括连接的目标内部节点,所述目标内部节点中包括与其他端设备连接的目标边缘节点;确定所述目标子图的目标扩展图,所述目标扩展图包括所述目标子图以及所述目标边缘节点与目标外部节点连接的图,所述目标外部节点为所述其他端设备的边缘节点,且所述目标边缘节点为所述其他端设备的外部节点;以及基于所述目标扩展图执行多次循环训练,得到目标本地模型,在每次循环中:确定所述目标边缘节点的聚合外部嵌入特征,所述聚合外部嵌入特征聚合了来自所述目标外部节点的非原始特征,基于所述聚合外部嵌入特征对全局模型进行训练,得到本地模型,所述全局模型的参数为上一次循环得到的全局参数,将所述本地模型的本地参数发送给服务器,其中,所述服务器基于多个端设备对应的多个本地模型的本地参数确定全局参数,所述多个端设备至少包括所述目标端设备和所述其他端设备,以及接收所述服务器下发的所述全局参数,并基于所述全局参数更新所述本地模型的本地参数。

4、在一些实施例中,所述确定所述目标边缘节点的聚合外部嵌入特征之前,还包括:基于所述目标扩展图对所述全局模型进行预训练,得到预训练本地模型。

5、在一些实施例中,所述基于所述目标扩展图对所述全局模型进行预训练,得到预训练本地模型,包括:基于所述目标扩展图对所述全局模型进行预设轮数的对比学习训练,得到所述预训练本地模型。

6、在一些实施例中,在所述预设轮数的每一轮训练中:确定对比模型,所述对比模型为上一轮训练得到的全局模型;以及基于所述目标扩展图和所述对比模型对所述全局模型进行对比学习训练,得到所述全局模型在当前轮的更新参数。

7、在一些实施例中,所述基于所述目标扩展图和所述对比模型对所述全局模型进行对比学习训练,得到所述全局模型在当前轮的更新参数,包括:将所述目标扩展图输入所述对比模型中进行编码,输出所述目标扩展图中每个目标内部节点的第一嵌入特征;将所述目标扩展图输入所述当前轮的全局模型中进行编码,输出所述目标扩展图中每个目标内部节点的第二嵌入特征;对所述每个目标内部节点:基于所述第二嵌入特征及其对应的真实标签,确定分类损失,基于当前目标内部节点的所述第二嵌入特征以及所述每个目标内部节点的所述第一嵌入特征,得到对比学习损失,所述对比学习损失约束所述当前目标内部节点的所述第二嵌入特征和与其类别相同的目标内部节点的第一嵌入特征的距离接近,并约束所述当前目标内部节点的所述第二嵌入特征和与其类别不同的目标内部节点的第一嵌入特征的距离远离,以及对所述分类损失和所述对比学习损失进行加权求和,得到第一损失;以及基于所述每个目标内部节点的所述第一损失确定所述更新参数。

8、在一些实施例中,所述服务器包括边缘服务器和云服务器,其中,所述边缘服务器与所述目标端设备的距离比所述云服务器相对于所述目标端设备的距离近。

9、在一些实施例中,所述确定所述目标边缘节点的聚合外部嵌入特征,包括:基于所述预训练本地模型,得到所述目标外部节点的嵌入特征,并发送给所述边缘服务器;以及接收所述边缘服务器发送的所述聚合外部嵌入特征,其中,所述聚合外部嵌入特征是所述边缘服务器对所述其他端设备发送的所述目标边缘节点的嵌入特征进行聚合得到的,所述其他端设备发送的所述目标边缘节点的嵌入特征包含了所述目标外部节点的特征。

10、在一些实施例中,所述基于所述聚合外部嵌入特征对全局模型进行训练,得到本地模型,包括:基于所述聚合外部嵌入特征对所述预训练本地模型进行训练,得到所述本地模型。

11、在一些实施例中,所述基于所述聚合外部嵌入特征对所述预训练模型进行训练,得到本地模型,包括:将所述目标扩展图输入至所述预训练本地模型,得到所述目标边缘节点的聚合内部嵌入特征,所述聚合内部嵌入特征聚合了来自与所述目标边缘节点连接的目标内部节点的特征;将所述聚合内部嵌入特征与所述聚合外部嵌入特征进行聚合,得到所述目标边缘节点的聚合嵌入特征;以及确定所述聚合嵌入特征与所述目标边缘节点的真实标签之间的第二损失,并基于所述第二损失更新所述预训练模型的参数,从而得到所述本地模型。

12、在一些实施例中,所述全局参数包括对所述多个本地模型的本地参数进行聚合得到的参数。

13、第二方面,本说明书提供一种联邦学习系统,包括目标端设备,所述目标端设备包括:至少一个存储介质,存储有至少一组指令集用于实现所述联邦学习;以及至少一个处理器,同所述至少一个存储介质通信连接,其中当所述联邦学习系统运行时,所述至少一个处理器读取所述至少一个指令集并实施第一方面中任一项所述的联邦学习方法。

14、第三方面,本说明书提供一种联邦学习方法,应用于服务器,所述方法包括多次循环训练,在每次循环中:接收多个端设备分别发送的外部节点的嵌入特征,对于所述多个端设备中的目标端设备:所述目标端设备包括目标外部节点和与其他端设备连接的目标边缘节点,所述目标端设备的目标外部节点为所述其他端设备的边缘节点,且所述目标端设备的目标边缘节点为所述其他端设备的外部节点;对所述多个端设备发送的所述嵌入特征进行聚合,得到每个端设备的边缘节点的聚合外部嵌入特征;将所述聚合外部嵌入特征发送给对应的端设备,其中,每个端设备基于所述聚合外部嵌入特征对全局模型进行训练得到本地模型,所述全局模型的参数为上一次循环得到的参数;接收所述多个端设备发送的多个本地模型的本地参数,并对所述多个本地模型的本地参数进行聚合,得到全局参数;以及向所述多个端设备分别发送所述全局参数,所述全局参数被配置为更新所述多个本地模型中的本地参数。

15、在一些实施例中,所述服务器包括边缘服务器和云服务器,其中,所述边缘服务器与所述目标端设备的距离比所述云服务器相对于所述目标端设备的距离近。

16、在一些实施例中,所述接收多个端设备分别发送的外部节点的嵌入特征,包括:所述边缘服务器接收所述多个端设备的所述嵌入特征。

17、在一些实施例中,所述对所述多个端设备发送的所述嵌入特征进行聚合,得到每个端设备的边缘节点的聚合外部嵌入特征,包括对于所述目标端设备:所述边缘服务器对所述其他端设备发送的所述目标端设备的目标边缘节点的嵌入特征进行聚合,得到所述目标端设备的目标边缘节点的聚合外部嵌入特征,所述其他端设备发送的所述目标端设备的目标边缘节点的嵌入特征包含了所述目标端设备的目标外部节点的特征。

18、在一些实施例中,所述聚合包括相加。

19、在一些实施例中,所述接收所述多个端设备发送的多个本地模型的本地参数,包括:所述云服务器接收所述边缘服务器转发的所述多个本地模型的本地参数。

20、在一些实施例中,所述对所述多个本地模型的本地参数进行聚合,得到全局参数,包括:确定所述多个本地模型对应的多个模型权重,所述模型权重与对应端设备的边缘节点的数量成反比;以及将所述多个模型权重与所述多个本地模型的本地参数进行加权聚合,得到所述全局参数。

21、第四方面,本说明书还提供一种联邦学习系统,包括服务器,所述服务器包括:至少一个存储介质,存储有至少一组指令集用于实现所述联邦学习;以及至少一个处理器,同所述至少一个存储介质通信连接,其中当所述联邦学习系统运行时,所述至少一个处理器读取所述至少一个指令集并实施第三方面中任一项所述的联邦学习方法。

22、由以上技术方案可知,本说明书提供的联邦学习方法,目标端设备由目标子图扩展出包含目标外部节点的目标扩展图,通过聚合来自目标外部节点的非原始特征从而得到其目标边缘节点的聚合外部嵌入特征,即在保护其他端设备数据隐私的前提下引入了其他端设备中与自身相关的信息,使得目标边缘节点的信息更全面,利用目标边缘节点训练得到的本地模型的精度更高。同时,目标端设备通过服务器下发的全局参数更新本地参数,即结合了其他端设备的本地参数共同更新本地模型的本地参数,提高了本地模型的泛化性。

23、本说明书提供的联邦学习方法和系统的其他功能将在以下说明中部分列出。根据描述,以下数字和示例介绍的内容将对那些本领域的普通技术人员显而易见。本说明书提供的联邦学习方法和系统的创造性方面可以通过实践或使用下面详细示例中所述的方法、装置和组合得到充分解释。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1