一种基于混合增强对比的黑盒源域无监督领域自适应方法

文档序号:35965253发布日期:2023-11-09 04:41阅读:48来源:国知局
一种基于混合增强对比的黑盒源域无监督领域自适应方法

本发明属于机器学习下的迁移学习,涉及一种领域自适应模型方法,具体的说是涉及一种基于混合增强对比的黑盒无监督领域自适应方法。


背景技术:

1、随着大数据时代的到来,数据产生速度不断加快,数据规模呈现爆发式增长,这使得有能力处理庞大数据集的机器学习技术备受关注。大量数据为机器学习和深度学习提供了更多训练和优化的机会,从而提升了模型的性能和适用性。尽管机器学习在许多领域取得了令人瞩目的成功,但在现实场景中仍然存在着一些限制。传统的机器学习方法通常依赖于大量标记数据来构建模型,以实现较高的分类性能。然而,获取大规模标记数据并不总是容易或实际可行的。这就带来了一个新的挑战,即如何在有限的标记数据条件下训练出具有良好泛化能力的模型,并能够准确地预测未标记数据。

2、迁移学习旨在利用已经训练好的模型中的知识和特征,将其应用于新任务或领域中以提升性能。通过迁移学习,我们可以将一个领域中的知识和经验转移到另一个相关或类似的领域,从而节省大量时间和资源。领域自适应是迁移学习的一个分支,它关注的是不同领域之间的知识迁移。在现实场景中,不同领域的数据可能存在领域间差异,例如图像的拍摄环境、文本的语言风格等,这些领域差异会影响模型在目标领域上的性能。领域自适应旨在通过减小不同领域间差异,使模型能够在目标领域上具有较好的泛化能力,其中一种常见的领域自适应方法是无监督领域自适应,它利用目标领域中未标记的数据进行训练。无监督领域自适应通常通过学习领域间的共享特征或对抗性学习来实现,而无需目标领域的标记数据。

3、尽管无监督领域自适应取得了显著的成功,但人们对数据隐私的日益关注给这项任务带来了新的挑战。源域和目标域的数据通常储存在不同的设备上并包含私人信息,因此将源域数据暴露给目标域存在一定的风险,换言之,已经标记的源域数据可能无法为目标模型所用,这就使得一些现有的无监督领域自适应方法不再适用,因此便有了无源领域自适应方法,以促进模型迁移并保护源数据的隐私安全。无源领域自适应向未标记的目标域提供训练有素的源模型而非已经标记的源域数据,因此无源领域自适应也称为白盒领域自适应。

4、然而在实际应用中,白盒源域模型并不总是能获得的。常见的云服务模型如谷歌云,腾讯云,被封装为应用程序编程接口的形式提供给用户,其中只有模型的输入输出接口可用,模型本身被保存为黑盒接口,这使大量无源领域自适应方法在实践中变得不可用,为此黑盒领域自适应诞生。黑盒领域自适应方法只能使用源域模型的接口访问,在安全性提高的同时也给领域自适应任务带来了不小的挑战,无法获得源模型输出的样本特征使解决域偏移问题变得困难,源模型接口信噪比的不确定也使伪标签变得不可靠。


技术实现思路

1、为了解决上述技术问题,本发明提供了一种基于混合增强对比的黑盒源域无监督领域自适应方法,该方法在基于知识蒸馏模型的基础上,增加了改进的混合特征对比模块、早期学习正则化模块和随机混合增强模块,帮助学习源域和目标域间共享类的知识和目标域私有类的知识,有效地提高了目标模型的预测准确率。

2、为了达到上述目的,本发明是通过以下技术方案实现的:

3、本发明是一种基于混合增强对比的黑盒无监督领域自适应方法,包括如下步骤:

4、步骤1、将每个目标域样本输入黑盒源域,获得源域预测,代表样本属于源域中每个类的概率。根据源域预测计算每个类别的原型样本和学习难度阈值;

5、步骤2、将每个目标域样本输入目标模型,计算目标模型输出的互信息熵和与源域预测的相对熵作为蒸馏损失;

6、步骤3、计算并存储每个样本与类原型样本特征之间的距离作为非线性预测,增加早期学习正则化项,配合蒸馏损失初始化目标模型,迭代更新样本特征以保留模型训练早期的易学习特征;

7、步骤4、根据步骤1得到的源域预测以及步骤3得到的非线性预测计算伪标签,根据步骤1得到的学习难度阈值为目标样本筛选置信的非同类样本,将两者按相等比例混合增强后重新获得特征充当混合负样本;

8、步骤5、根据类原型特征及步骤4得到的混合负样本特征计算混合增强对比损失,目的是使得每个目标样本与类原型近,与其他类原型和混合负样本远;

9、步骤6、随机选择目标样本对按0.25和0.75的比例进行混合数据增强,根据其在目标模型的输出和其混合后的伪标签计算交叉熵;

10、步骤7、整体损失计算梯度,反向传播,迭代更新网络参数、类原型特征、学习难度阈值直至损失收敛,对目标域数据样本进行预测得到预测标签,与目标域数据样本的真实标签比较,对于每一类计算出该类的平均分类准确率作为度量结果。

11、进一步的,在步骤1中通过黑盒源域模型的输出计算每个类的原型样本和学习难度阈值,如下所示,

12、

13、

14、其中表示目标域样本;fsk表示源域模型预测第k类的概率;为超参数。

15、进一步的,在步骤2中构造了蒸馏损失,通过最小化蒸馏损失来更新目标模型,蒸馏损失由相对熵和互信息熵组成,定义如下:

16、

17、

18、lwarm=lkd-lim

19、其中dkl表示相对熵,ft表示目标模型,h(p)=-∑ipilogpi表示自信息熵。

20、进一步的,在步骤3中通过早期学习正则化项来正则化模型训练过程,保留模型早期记忆的具有正确标签的干净样本,防止噪声数据影响。储存器用于记录每个样本的非线性预测,并通过动量策略基于新的预测进行更新,非线性预测、动量策略和早期学习正则化项定义如下:

21、

22、

23、

24、其中l2()为l2范式,σ为softmax函数,表示类原型样本,oi表示样本在当前模型的中的非线性预测,β为超参数。

25、进一步的,目标模型的线性和非线性预测均有其局限性,在步骤4中综合考量两者获得伪标签,通过伪标签为目标样本筛选置信的非同类样本进行混合增强,定义如下:

26、

27、

28、其中xi表示与第i个样本拥有相同伪标签且置信度大于学习难度阈值的目标样本集合,将第i个样本与集合中的每个样本进行混合增强,获得增强后样本特征作为混合增强对比负样本,混合增强定义如下:

29、

30、mixλ(a,b)=λa+(1-λ)b

31、进一步的,在步骤5中,最小化目标域样本的infonce损失函数,其中样本特征作为锚点,类原型特征作为正样本,其他类原型特征及混合增强样本特征作为负样本,同时为了减小不确定数据对模型训练的影响,使用其置信度作为权重系数,定义如下:

32、

33、其中zi表示锚样本特征,zd表示当前类原型特征,zk表示各个类原型特征,zuk表示混合负样本特征,k表示类别数,表示该样本对应的混合负样本数量,wi表示置信度,τ表示温度系数,zi、zd、zk和zuk均经过归一化操作。

34、进一步的,在步骤6中,随机选择目标域数据样本对进行混合数据增强,根据其混合后的伪标签计算交叉熵,通过最小化交叉熵再次优化目标模型,交叉熵定义如下:

35、

36、其中lce表示交叉熵损失。

37、进一步的,在步骤7中,计算整体损失,反向传播,迭代更新网络参数,整体损失定义如下:

38、l=lkd-lim+αlelr+γlunicon+δlmix

39、其中α、γ和δ为超参数。

40、本发明的有益效果是:本发明提出了一种基于混合增强对比的黑盒源域无监督领域自适应方法,针对源域数据与模型参数均不可获得的情况,通过构造蒸馏损失使模型输出靠近源模型输出的同时鼓励标签分布均匀,防止出现类不平衡现象;通过添加早期正则化项,有效地防止噪声数据造成的错误积累;通过混合增强对比和随机混合数据增强优化目标数据的类间和类内结构,更准确地识别数据特征。相比于其他模型本发明在保证更好泛化性能的同时具备较强的安全性及隐私保护能力。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1