本发明涉及航空发动机振动信号故障诊断技术中,数据不充分导致的数据孤岛问题以及传统联邦学习缺乏个性化解决方案的问题。
背景技术:
1、航空发动机一直以来都被视为影响国家空中运输、国防安全和保持国家战略优势的核心技术,是衡量一个国家综合科技水平、科技工业基础实力和综合国力的重要标志。然而在整个航空航天行业,数据分散的储存在不同的工厂和企业中,各方数据在物联网节点之间的流通面临困难和挑战。联邦学习通过服务器聚合各客户端训练的本地模型,平均后再下发至各客户端,在航空航天领域故障诊断方面有广阔的发展前景。
2、然而联邦学习在高度异构的数据上会出现客户端漂移现象,导致fedavg等算法得到的全局模型并非是最优全局模型,存在偏差;下发至各客户端的全局模型缺乏个性化的解决方案,全局模型对于普通客户端来说适用,但对于特殊客户端不适用,最终导致振动故障诊断的效果不理想。
技术实现思路
1、为改善以上问题,本发明提出了一种基于知识蒸馏的个性化联邦学习方法。实验表明,该方法能够有效的解决各客户端数据异构导致模型收敛性差的问题,为各客户端训练独特的个性化模型。
2、一种基于个性化联邦学习的故障诊断模型优化方法,其特征在于,包括以下步骤:
3、(1)客户端初始化本地模型,所述本地模型为深度神经网络模型;
4、(2)所述客户端通过公共数据集补充其私有数据,然后进行本地训练,并将更新后的所述本地模型梯度参数传递给下一客户端,作为全局模型的梯度参数,全局模型为深度神经网络模型;
5、(3)所述下一客户端接收所述全局模型梯度参数,通过公共数据集补充其私有数据进行本地训练,基于知识提取算法更新全局模型的梯度参数,重复步骤(3),直到全局模型收敛或者达到指定训练次数;
6、(4)同类别客户端接收所述全局模型梯度参数,根据其私有数据进行本地训练,并将更新后的梯度参数传递给下一客户端,作为同类别全局模型的梯度参数,同类别全局模型为深度神经网络模型;
7、(5)所述下一客户端接收所述同类别全局模型梯度参数,根据其私有数据进行本地训练,基于知识提取算法更新同类别全局模型梯度参数,重复步骤(5),直到同类别全局模型收敛或者达到指定训练次数;
8、(6)各客户端根据所述同类别全局模型梯度参数,通过知识蒸馏算法对本地模型进行个性化。
9、进一步地,所述全局模型由全体客户端进行训练、所述同类别全局模型由同类别客户端进行训练,所述本地模型由同类别全局模型进行蒸馏。
10、进一步地,所述步骤(2)中公共数据集是将本地数据集通过不等份划分和频域置乱处理后拼接而成的,无法反向破解原始本地数据集的信息。
11、进一步地,所述步骤(3)中知识提取算法包括:根据本地训练集和交叉熵损失函数训练本地模型;根据所述全局模型、下一客户端的本地模型构建全局损失函数;下一客户端根据知识提取,可以很好的利用来自上一个客户端的知识,将其视为公共知识。
12、进一步地,所述全局损失函数为:
13、
14、其中:λ是知识迁移和对当前数据的权衡,lcls是交叉熵损失。fi=ci·gi,其中ci是分类层,gi是特征提取层。
15、进一步地,所述知识提取函数为:
16、
17、其中gtea是前一个联合的特征提取器,而gstu是当前训练联合的,x是当前联合的数据样本。
18、进一步地,所述步骤(5)中知识提取算法包括:根据本地训练集和交叉熵损失函数训练本地模型;根据所述全局模型、下一客户端的本地模型构建全局损失函数;下一客户端根据知识提取,可以很好的利用来自上一个客户端的知识,将其视为公共知识。
19、进一步地,所述全局损失函数为:
20、
21、其中:λ是知识迁移和对当前数据的权衡,lcls是交叉熵损失。fi=ci·gi,其中ci是分类层,gi是特征提取层。
22、进一步地,所述知识提取函数为:
23、
24、其中gtea是前一个联合的特征提取器,而gstu是当前训练联合的,x是当前联合的数据样本。
25、进一步地,所述步骤(6)中知识蒸馏函数为:
26、
27、当上一客户端的表现非常糟糕时,设置λ=0;当上一客户端的表现可以接受时,适当调整λ进行个性化。
28、发明效果
29、本发明方法在传统联邦学习框架的基础上,取消了中央服务器端的设置,在避免数据隐私泄露的同时,提高了各节点之间的传输效率。另外,本发明方法设立了公共数据集进行补充各节点缺失的故障类型,解决了高度异构数据集收敛性差的问题;相同故障类型的节点在公共知识积累完成后,再进行同类别知识积累,解决了不同节点缺少个性化方案的问题;同类别知识积累阶段,不同故障类别的节点是同时进行的,对于之前的各节点依序积累知识,大大缩短了耗时,提高了联邦学习的效率,实现了各客户端本地模型的个性化。
1.一种基于个性化联邦学习的故障诊断模型优化方法,其特征在于,包括以下步骤:
2.根据权利要求1所述的一种用于故障诊断的个性化联邦学习方法,其特征在于,所述全局模型由全体客户端进行训练、所述同类别全局模型由同类别客户端进行训练,所述本地模型由同类别全局模型进行蒸馏。
3.根据权利要求1所述的一种用于故障诊断的个性化联邦学习方法,其特征在于,所述步骤(2)中公共数据集是将本地数据集通过不等份划分和频域置乱处理后拼接而成的,无法反向破解原始本地数据集的信息。
4.根据权利要求1所述的一种用于故障诊断的个性化联邦学习方法,其特征在于,所述步骤(3)中知识提取算法包括:
5.根据权利要求4所述的一种用于故障诊断的个性化联邦学习方法,其特征在于,所述全局损失函数为:
6.根据权利要求4所述的一种用于故障诊断的个性化联邦学习方法,其特征在于,所述知识提取函数为:
7.根据权利要求1所述的一种用于故障诊断的个性化联邦学习方法,其特征在于,所述步骤(5)中知识提取算法与权利要求4相同。
8.根据权利要求1所述的一种用于故障诊断的个性化联邦学习方法,其特征在于,所述步骤(6)中知识蒸馏函数为: