一种基于模型压缩和数据加密的联邦学习方法与流程

文档序号:37835352发布日期:2024-05-07 19:10阅读:8来源:国知局
一种基于模型压缩和数据加密的联邦学习方法与流程

本发明涉及人工智能,具体而言,涉及一种基于模型压缩和数据加密的联邦学习方法。


背景技术:

1、随着数字化进程加快,产生了大量数据。通过机器学习技术可以自动化地挖掘数据中蕴藏的宝藏,经过大量数据训练出来的机器学习模型已经应用在各类场景中,正在深刻改变着我们的世界,例如精准医疗、临床辅助诊断、新药研发、人像识别、声纹识别、千人千面推荐算法、图片、语音、自然语言等多模态学习。在应用中,模型的精度、泛化能力等至关重要,而这些都赖于机器对大量数据的学习。受限于法律法规、政策监管、商业机密、个人隐私等数据隐私安全上的约束,多个数据来源方无法直接交换数据,形成“数据孤岛”现象,制约着人工智能模型能力的进一步提高。

2、联邦学习的诞生即是为了解决这一问题。联邦学习(federated learning)是一种机器学习框架,能帮助多个机构在满足用户隐私保护、数据安全和政府法规的要求下,进行数据使用和机器学习建模。联邦学习可在不用将用户的数据集上传到云端的情况下,聚合多方的资源并进行深度学习训练。

3、但是,由于联邦学习需要在云端和边缘端进行数据的传输,所以存在数据泄露和通信消耗过大的问题。

4、在传统联邦学习中,云端将模型参数分发给边缘端后,由各个边缘端在本地进行训练,之后将梯度传输给云端,由云端进行安全聚合,当前采用的最普遍的做法是联邦平均算法fedavg,也就是将各个边缘端的梯度进行加权平均后,对云端的模型进行更新。

5、然而,在云端将梯度分发给边缘端时,可能会被恶意用户劫持,进而反推出用户的真实数据,导致数据泄露;并且在传统联邦学习中需要频繁传输梯度,在具有大量边缘端的情况下,该问题会更加严重。


技术实现思路

1、鉴于此,本发明的目的在于设计一种基于模型压缩和数据加密的联邦学习方法,针对数据传输过程中通信消耗过大的问题,采用模型压缩的方法尽可能地减少云端和边缘端之间的数据传输消耗;同时在数据传输过程中通过数据加密的方式保证传输安全性。

2、本发明提供一种基于模型压缩和数据加密的联邦学习方法,包括:

3、s1、在云端生成公钥和私钥,并由云端将私钥分发给各个边缘端,云端保留公钥;

4、s2、由云端选择需要进行联邦学习的边缘端设备,并将数据经过公钥加密后分发给所选择的边缘端;

5、由于边缘端设备并不是一直在线,所以由云端选择需要进行联邦学习的边缘端设备;

6、s3、被选择的边缘端利用自己的数据训练一个模型,通过自蒸馏提高模型的精度;使用所述模型对云端的数据进行解密,并进行预测,发送软标签给云端;

7、s4、由云端对各个边缘端发送的软标签进行加权,通过知识蒸馏提取各个边缘端模型的暗知识;知识蒸馏过程中将云端作为知识蒸馏中的学生,将各个边缘端作为教师;

8、知识蒸馏(knowledge distillation)是将学习能力强的复杂教师模型中的“知识”迁移到简单的学生模型中的一种模型压缩方法。

9、s5、由云端对训练完的模型进行结构化剪枝,并将剪枝后的模型数据使用公钥加密后分发给边缘端;

10、模型剪枝(model pruning)是一种通过减小模型权重规模或中间状态规模来减小模型大小和计算量的技术。剪枝后的模型参数量大幅减少,有效地降低了通信消耗。

11、s6、使用各个边缘端的本地数据更新模型,并对云端的数据进行预测,由各个边缘端再次发送软标签给云端;

12、s7、重复s3-s6步骤,实现云端集成各个边缘端的数据信息,将云端模型分发给各个边缘端。

13、进一步地,所述s3步骤的自蒸馏的方法包括:

14、设当前共有n个边缘端,每个边缘端的训练epoch为e,针对其中的一个边缘端,在第t个epoch时,模型对本轮epoch输出的概率作为软标签,t-1轮的概率作为教师模型输出的软标签,针对第t-1轮的目标标签进行纠正,保证学生学习到的知识一定是正确的;

15、对于其中的某一张样本标签,当第t轮和t-1轮都是错的时,则说明对于模型,该标签的学习难度较大,需尽可能地增强,将第t-1轮修正为one-hot标签,也即原始标签;

16、当第t轮是错的,第t-1轮是对的时,则说明模型对于该标签可能过拟合了,需降低第t-1轮标签对于true lable的预测概率,减少模型过拟合情况的出现;

17、纠正完之后,将第t-1轮的标签作为软标签输入给模型进行训练。

18、在传统的联邦学习中,当模型精度达到预期值时就会结束训练,但是由于边缘端会不断地产生数据,结束训练后的云端模型无法进行更新。本发明通过改进在线自蒸馏的技术能够持续地从边缘端进行学习,进而优化模型,达到更高的精度。

19、进一步地,所述s4步骤的知识蒸馏的方法包括:

20、由云端使用各个边缘端发送过来的软标签进行一轮知识蒸馏过程,这里以计算机视觉领域为例,设教师模型为a:resnet34,学生模型为b:resnet18,助理模型为c:resnet20;

21、s41、引入规模介于教师模型和学生模型之间的助理模型c,使用a对c进行蒸馏;

22、s42、使用教师模型a、助理模型c,采用feature蒸馏和logit蒸馏的方式对学生模型b进行蒸馏;

23、s43、计算模型包含知识蒸馏损失的总损失函数:

24、lossall=(1-α)×losskd+α×lossce  (1)

25、式(1)中,lossce表示传统的交叉熵函数,losskd表示知识蒸馏损失,ɑ表示超参数,用来平衡二者的大小。

26、进一步地,所述s42步骤的logit蒸馏的方式为:将教师模型a,c的logit传给b;

27、feature蒸馏的方式为:仅融入教师模型第一个模块和最后一个模块的知识。融入第一个模块的目的在于初步将教师模型知识传递给学生模型进行学习,融入最后一个模块知识的目的是再次强化学生模型学习结果。

28、现有的知识蒸馏技术通过软标签来学习教师模型,但是在网络较深时,学生模型无法有效地学习到教师模型浅层的暗知识。本发明通过改进知识蒸馏的方法,仅在云端和边缘端之间传输软标签,能够大量减少网络带宽的压力。

29、进一步地,所述s5步骤的结构化剪枝的方法包括:

30、s51、计算训练完成的云端模型中每个过滤器之间的相似度,并生成相似度矩阵;

31、s52、将相似度较高的过滤器使用聚类算法聚为一类,类的个数由剪枝率决定;

32、s53、在每个类中选择保留一个过滤器,其他的过滤器都修剪掉,对于只包含一个过滤器的类,不进行修剪,保证剩余的每个过滤器都能尽可能地提取到有差异性的特征,完成剪枝。

33、进一步地,所述s51步骤的计算过滤器之间的相似度的方法为:计算过滤器之间的皮尔逊(person)相关系数;

34、协方差在一定程度上能够表示两个过滤器的相似度,但是会受到量纲的影响,通过协方差的正负来判断两个过滤器的相似度并不准确,所以本发明采用皮尔逊(person)相关系数;

35、设存在两个过滤器a和b,则a和b的皮尔逊相关系数的计算方法是两个过滤器的协方差除以两个过滤器的标准差,计算公式为:

36、

37、式(2)中,ai代表过滤器a中的每一个元素,代表过滤器a的样本平均值;bi代表过滤器b中的每一个元素,代表过滤器b的样本平均值;n代表过滤器a或过滤器b中的元素总个数;

38、当皮尔逊相关系数为1时,两个过滤器完全正相关;当皮尔逊相关系数为-1时,两个过滤器完全负相关;皮尔逊相关系数的绝对值越大,两个过滤器的相似度越大;皮尔逊相关系数越接近于0,两个过滤器的相似度越小。

39、进一步地,所述s52步骤的将相似度较高的过滤器使用聚类算法聚为一类的方法包括:

40、本发明为了降低链式效应的影响,使用基于层次的聚类算法,具体的,采用凝聚聚类算法,所述凝聚聚类算法采用自底向上的思路对过滤器进行聚类。

41、进一步地,所述s53步骤的在每个类中选择保留一个过滤器时包括以下两种情况:

42、如果当前只存在一个过滤器,那么只能保留该过滤器,而不能进行剪枝;因为在经过聚类分类后,具有相似提取特征的过滤器已经被分为一类,所以删除唯一的一个过滤器将会导致模型无法提取到特定的特征。

43、如果当前存在多个过滤器,则计算这些过滤器的几何中位数,然后选择与几何中位数最近的过滤器进行保留。

44、比如,对于模型中的第i层,在经过s52步骤后生成了10个类,其中前3个类中只有1个过滤器,剩下的类中过滤器的个数多于1个,则对前3个类不进行修剪,直接保留,对于剩下的每一个类,分别计算类中的几何中位数,然后计算类中的所有过滤器到几何中位数的距离,选择与几何中位数最近的过滤器保留,修剪其他的过滤器。

45、现有的模型剪枝技术采用l1范数或者l2范数判断过滤器的重要性,而忽略了过滤器间的相似性。本发明采用的种剪枝方法首先计算皮尔逊相关系数来判断相似性,并且对过滤器进行聚类,然后通过几何中位数来判断过滤器的重要性进行剪枝,这样能够更加合理的修剪冗余的过滤器,降低了模型的参数,以及模型对于算力的要求,使得模型在资源受限的边缘端设备上都可以部署。

46、本发明还提供一种计算机可读存储介质,其上存储有计算机程序,所述程序被处理器执行时实现如上述所述的基于模型压缩和数据加密的联邦学习方法的步骤。

47、本发明还提供一种计算机设备,所述计算机设备包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现如上述所述的基于模型压缩和数据加密的联邦学习方法的步骤。

48、与现有技术相比,本发明的有益效果在于:

49、本发明基于模型压缩和数据加密的联邦学习方法通过引入助理模型减少教师模型和学生模型间的容量差距,在网络的浅层和深层进行特征蒸馏,进一步地传递教师模型的知识;改进知识蒸馏的方法,仅在云端和边缘端之间传输软标签,大量减少了网络带宽的压力;通过计算皮尔逊相关系数判断相似度,并且对过滤器进行聚类,通过几何中位数判断过滤器的重要性,进行结构化剪枝,更加合理地修剪冗余的过滤器,降低了模型的参数量以及模型对于算力的要求,使得模型在资源受限的边缘端设备上都可以部署;通过改进在线自蒸馏,能够持续地从边缘端进行学习,进而优化模型,使模型达到更高的精度;采用数据加密,避免了数据传输时被恶意外部用户劫持导致数据泄露的问题,有效提高了数据传输过程中的隐私性和安全性。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1