基于原型一致性和自标记的联邦无监督模型训练及分类方法

文档序号:37885059发布日期:2024-05-09 21:29阅读:13来源:国知局
基于原型一致性和自标记的联邦无监督模型训练及分类方法

本发明属于数据处理,更进一步涉及图像处理与分类中的一种基于原型一致性和自标记的联邦无监督模型训练及分类方法。本发明可用于客户端利用本地的无标签图像数据协同训练分类模型,以及利用训练好的模型对无标签的图像进行分类。


背景技术:

1、随着智能设备的快速增加,数据的规模和来源呈现出多样化和复杂化的特征。联邦学习作为一种新兴的机器学习范式,已经成为最常用的一种隐私保护模型共享方法。联邦学习的核心思想是在客户端上利用本地数据训练本地模型,并将模型参数发送到服务器以聚合全局模型。现有的联邦学习方法通常只考虑有监督的训练设置,其中客户端数据被完全标记。然而,在实际应用场景中,由于数据的多样性和复杂性,使其通常缺乏有效的标注信息,比如涉及用户隐私的数据,用户可能并不愿意对其进行标注和共享。联邦无监督学习的技术方案,尝试通过基于对比学习的无监督学习方式(如simclr)解决数据无标签问题,对比学习依赖于正负样本对从无标签数据中学习通用表示,然而,一些属于同一个类别的实例样本,在学习过程中会不可避免的被视为负样本对,这会造成类别冲突问题,阻碍模型学习到更好的表示,导致模型性能下降。此外,边缘设备上的数据在实际生产环境大多是非独立同分布的,如何解决各边端设备上数据的异质分布,从而提高模型精度,也是一个亟待解决的重要问题。

2、北京邮电大学在其申请的专利文献“基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备”(申请号:202310205865.0,申请公布号:cn 116310530a,公布日期2023.06.23)中提出了一种联邦无监督图像分类模型训练方法、分类方法。该方法的实现步骤包括以下:(1)获取本地数据集,本地数据集包含多个样本,每个样本包含一张图像。(2)获取初始模型,初始模型包括语义聚类模型和预训练得到的编码器模型。(3)采用所本地数据集对初始模型进行训练,并构建聚类损失,利用聚类损失对初始模型的参数进行迭代,以得到初始图像分类模型。(4)将初始图像分类模型的模型参数发送至全局服务器,生成共享模型:接收共享模型的参数,并采用指数移动平均更新初始图像分类模型,得到最终的图像分类模型。该联邦无监督图像分类模型方法存在的不足之处是:该方法在预训练编码器阶段使用了基于正负样本的对比学习,会将属于同一个类别的实例样本视为负样本对,存在类别冲突问题,导致编码器学习程度有限,无法学习到更好的表示。

3、平安科技有限公司在其申请的专利文献“基于联邦学习的图像分类方法、装置、计算机设备及介质”(申请号:202310499593.x,申请公布号:cn 116433986a,公布日期2023.05.06)中提出了一种联邦学习图像分类方法。该方法实现的步骤包括以下:(1)将获取的无标签图像切分,得到n个子图像,从n个子图像中随机选择一个子图像进行遮挡,得到遮挡子图像,确定所有未被遮挡的子图像组成正常子图像集合。(2)接收服务器发送的初始模型,初始模型包括编码器和解码器,将正常子图像集合中的所有子图像输入编码器中进行特征提取,得到特征向量,将特征向量和遮挡子图像输入解码器中进行特征重构,得到重构图像。(3)根据无标签图像和重构图像对初始模型进行训练优化,得到优化模型,将优化模型中的编码器发送至服务器。(4)接收服务器发送的参考编码器,将参考编码器与部署好的分类器组成分类模型,根据获取的有标签图像及其标签对分类模型进行训练,得到更新模型。(5)将获取的待处理图像输入更新模型进行图像分类,得到图像分类结果。该方法存在的不足之处是:该方法在客户端本地模型更新阶段简单的使用了联邦平均算法进行模型更新,而客户端的数据通常为非独立同分布,与全局模型之间的差异较大,简单的使用联邦平均算法更新客户端模型会不利于客户端模型的学习,导致训练的模型精度下降。


技术实现思路

1、本发明的目的在于针对上述现有技术存在的不足,提出一种基于原型一致性和自标记的联邦学习无监督模型训练及分类方法,用于解决联邦无监督学习存在的类别冲突问题、数据分布为非独立同分布情况下训练模型性能较低的问题。

2、实现本发明目的的思路是,本发明在客户端本地训练阶段,通过两种不同的数据增强方式,利用本地无标签数据构建训练集和缓存集。通过深度聚类的方式对缓存集数据进行聚类,获的缓存集数据的伪标签和本地类原型。接着通过约束训练集中同一样本不同增强视图之间的一致性,以及训练集中批次类原型和缓存集中本地类原型之间的一致性,使得模型学习到均匀一致的表示,克服了使用正负样本对造成的类别冲突问题。此外,本发明设计了一种基于相邻样本一致性的自标记策略,首先通过计算每个样本与其他样本之间的距离,找到每个样本的若干最近邻样本;其次,判断每个样本的最近邻样本中,邻居样本伪标签与其自身伪标签一致的数量是否大于设定的阈值,若大于,则将该样本标记为高置信度样本;最后,将所有高置信度样本与本地类原型进行原型对比学习,对本地模型学习到的表示进行优化,以解决深度聚类过程中由于样本误分类导致的误差累积问题,提高本地模型的鲁棒性。在客户端模型更新阶段,我们设计了一种基于模型相似度的客户端模型更新策略,根据客户端模型与全局模型的相似程度,动态的更新客户端本地模型,以解决客户端数据为非独立同分布情况下模型精度较低的问题。

3、实现本发明目的的具体步骤如下:

4、步骤1,构建由特征提取子网络和预测器子网络串联组成联邦无监督全局模型;

5、步骤2,生成训练集、缓存集和测试集;

6、步骤3,服务器将联邦无监督全局模型发送至各客户端;

7、步骤4,每个客户端利用本地模型更新公式,使用全局模型对该客户端的本地模型进行更新;

8、步骤5,客户端对缓存集数据进行聚类,获得缓存集伪标签和本地类原型;

9、步骤6,客户端通过约束批次类原型与本地类原型的一致性,以及基于相邻样本一致性的自标记对客户端的本地模型进行本地训练;

10、步骤7,将所有训练好的联邦学习客户端模型参数进行加权聚合,得到全局模型;

11、步骤8,判断聚合后的断联邦无监督全局模型是否满足训练终止条件,若是,则执行步骤8,否则,将当前迭代次数加1后执行步骤3;

12、所述的训练终止条件指的是满足下述条件之一的情形:

13、条件1,联邦无监督全局模型性能达到设定的预期目标;

14、条件2,联邦无监督全局模型的损失函数收敛。

15、步骤9,得到训练好的联邦学习全局模型;

16、步骤10,利用训练好的联邦无监督模型,将待分类的每张图像进行归一化的数据增强处理后进行分类,输出所述图像的类别。

17、本发明与现有技术相比有以下优点:

18、第一,本发明的训练方法在客户端本地模型训练阶段,通过深度聚类方法获得本地数据的伪标签和类原型,通过约束同一样本不同增强视图之间的一致性以及批次类原型和本地类原型之间的一致性,克服了现有技术使用正负样本进行对比学习导致的类别冲突问题,使得本发明的客户端模型可以学习到均匀一致的表示,提高了客户端模型的性能。

19、第二,本发明的训练方法在客户端本地模型训练阶段,设计了一种基于相邻样本一致性的自标记策略,选择出高置信度的样本,与客户端本地类原型进行原型对比学习,使得本发明可以对客户端模型学习到的无监督表示进行优化,提高了客户端模型的鲁棒性。

20、第三,本发明的训练方法在全局模型聚合与本地模型更新阶段,设计了一种基于全局模型与本地模型相似度的客户端模型更新策略,克服了现有技术在客户端模型更新阶段因客户端数据异质分布造成的模型精度下降问题,使得本发明可以根据客户端的不同的数据分布动态灵活的更新客户端模型,聚合的全局模型可以充分学习客户端模型的知识,提高了整体联邦无监督模型的泛化能力。

21、第四,本发明的分类方法是基于原型一致性和自标记的联邦无监督模型训练方法训练好的模型,对待分类的图像进行分类,与现有的联邦无监督图像分类方法相比,经本发明训练好的模型在图像分类方面较现有方法可以大大提高了图像分类的准确率。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1