一种基于双重不平衡场景下的无监督域适应方法

文档序号:29943750发布日期:2022-05-07 15:23阅读:146来源:国知局
一种基于双重不平衡场景下的无监督域适应方法

1.本发明属于机器学习技术领域,尤其涉及一种基于双重不平衡场景下的无监督域适应方法。


背景技术:

2.大数据时代的来临使数据的产生速度不断加快,数据的体量发生了巨大增长,机器学习凭借强大的数据处理能力得到了越来越多的关注。数据的快速增多使得机器学习与深度学习可以依赖更多的数据,持续不断地训练和更新模型,模型的性能和适用性也越来越好,机器学习技术已经在许多实际应用中取得了巨大成功,但在某些现实场景中仍然存在一定的局限性;传统的机器学习需要利用足够多的有标记数据进行训练才能获得分类性能较好的模型,这就产生了机器学习领域的一个新问题,即如何利用有限的有标记数据获得一个较好的泛化模型,从而对无标记数据进行正确的预测。
3.迁移学习应运而生,迁移学习的含义是运用不同领域的数据之间存在的关联特性,将曾经在一个领域学得的信息用到一个全新的另一个不同的领域中去。两个领域之间的相似度越高,就越容易进行迁移,相反就越难进行迁移,容易起到反作用,产生负迁移的现象。迁移学习包括源域(source domain)和目标域(target domain)这两个领域,其中,源域是含有大量有标记数据的领域,是被迁移的对象,而目标域是只有无标记数据的域或者只有少量有标记数据的域,是需要对领域中的数据进行标签预测的数据域,是迁移知识的应用对象。在减小源域和目标域数据分布差异的同时,学习源域的知识结构或标记信息并将其运用到目标域,使得学习的模型能够对目标数据进行正确的预测,从而完成迁移学习,这种方法统称为无监督域适应学习,大致可以分为三类:基于距离的方法,基于对抗的方法和自训练的方法。通常情况下会假设源域和目标域中每个类别下的数据比例是相对平衡的,并且不同领域中的平衡比例是相似的。
4.然后,真实场景中的原始数据通常是不平衡的。这种不平衡通常发生在每个域内,并可能进一步导致领域间存在不同的类别比例或不平衡比率,称为跨域不平衡。实际应用中无监督域适应学习经常会遇到这种双重不平衡的场景,分类边界可能会偏向于源域中的多数类,然而这些类在目标域中是少数类,因此,它会导致对大多数目标域样本的错误分类,甚至是负迁移。目前主流的做法是在基本域适应方法的基础上结合使用数据重加权或者生成样本的策略,数据重加权方法试图通过对少数类样本进行过采样或对多数类样本进行低采样来平衡数据分布,生成样本的方法是通过生成靠近少数类样本的方法进行数据扩充,使得数据整体分布变得相对平衡。
5.但经检验发现,在双重不平衡场景下,普通的迁移学习域适应方法不能达到很好的效果,只考虑减少领域间的特征差异,可能会产生负迁移降低模型在目标域上的性能;数据重加权的方式在对多数类样本进行低采样时可能会丢失多数类的信息,而在加大少数类样本权重时存在对少数类样本过拟合的风险;更重要的是,因为存在跨域间的不平衡,源域的标签分布和目标域的标签分布甚至可能是相反的,模型在经过域对齐之后会对目标域中
的多数样本分类错误,达不到预期的效果。


技术实现要素:

6.为解决上述技术问题,本发明提供了一种基于双重不平衡场景下的无监督域适应方法,通过类对比知识迁移移、类关联知识迁移和判别型的特征对齐,在域间迁移能容忍不平衡的知识;源域数据样本和目标域数据样本都被输入到一个共同的特征提取器中,通过减少判别型的最大均值差异,使得跨域的相同类对齐和不同类之间分离;采用类对比损失来限制特定类样本的分类得分应该高于其他类样本的分类得分,减少了对标签分布的依赖从而减轻数据不平衡对模型的影响;同时从源域数据样本中学习类之间的相关性向量,将类间的关联知识也迁移到目标域中,进一步提高模型在目标域上的泛化性能。
7.本发明所述的一种基于双重不平衡场景下的无监督域适应方法,包括如下步骤:
8.步骤1、构造类对比损失,对于每一类,最大化该类样本和其他类样本之间的得分预测差异,使用可导的指数损失作为类对比损失的代理损失;
9.步骤2、对于源域数据样本使用类对比损失和交叉熵损失预训练一个分类模型,分类模型由特征提取器和分类器构成,计算源域数据样本的类关联向量;
10.步骤3、使用步骤2得到的多层的特征提取器和随机初始化的分类器构造解决双重不平衡场景的网络模型;
11.步骤4、对源域数据样本的每一类进行采样构造训练数据,将目标域数据样本输入到步骤3得到的网络模型中得到预测的伪标签,再根据伪标签对每一类采样构造目标域的训练数据,将两组数据一起输入到步骤3得到的网络模型中,得到源域特征和目标域的特征;
12.步骤5、使用源域特征,目标域特征,源域的真实标签和目标域的伪标签,构造判别型的最大均值差异,最小化类内的紧致性同时最大化全局距离,以加强跨域的相同类的特征对齐和不同类间的特征差异;
13.步骤6、对于源域数据样本使用类对比损失和交叉熵损失,对于目标域数据样本选择置信度高的样本使用类对比损失并约束分类器输出和对应类的关联向量之间的一致性;
14.步骤7、整体损失计算梯度,反向传播,迭代更新网络参数直至损失收敛,对目标域数据样本进行预测得到预测标签,与目标域数据样本的真实标签比较,对于每一类计算出该类的分类准确率,再计算所有类的平均分类准确率作为度量结果。
15.进一步的,在步骤1中对于样本对定义了k个类的类对比标签,一组样本如果所属不同类别则类对比标签为1,否则为0,使用类对比标签构造了类对比损失,使用可导的指数损失作为类对比损失的代理损失,来拟合类对比损失进行计算优化,并使用因式分解进行化简。
16.进一步的,步骤2中,对于源域的每一类,计算特定类中所有样本经过预训练模型输出的概率向量经过softmax之后的平均值作为这一类的关联向量,记做c
(k)
,类关联向量上反映了每个类别与其他类别的相关性,如果l、m≠k表示第k类相对于第m类来说更相近于第l类。
17.进一步的,在步骤5中构造判别型最大均值差异距离dmmd,传入步骤4中得到的源
域特征和目标域特征以及源域的真实标签和目标域对应的伪标签最小化dmmd损失,dmmd定义如下:
[0018][0019]
其中xs代表所有源域数据样本,x
t
代表所有目标域数据样本,表示源域数据样本属于第k类的概率,表示目标域数据样本属于第k类的概率,直接通过源域的标签获得,如果源域样本是第k类则为1,否则为0;为目标域样本经过分类器所输出的概率向量为源域样本经特征提取器所得到的特征,为目标域样本经特征提取器所得到的的特征。
[0020]
进一步的,在步骤6中,对源域数据样本使用交叉熵损失和类对比损失,保证模型能够对源域的数据正确分类,对于目标域数据样本根据阈值挑选置信度高的样本预测其伪标签,阈值会随着训练过程逐渐提高;将挑选出的目标域数据样本经过模型输出的概率向量pi通过放缩的softmax层,最小化类关联损失计算其与对应类的关联向量的一致性,其中表示目标域样本预测的伪标签,代表挑选出的所有目标域数据样本,代表每一个数据样本xi根据预测的伪标签所对应的类关联向量,将源域中类别间的关联信息迁移到目标域,使目标域也学习到类之间的相关信息。
[0021]
本发明所述的有益效果为:本发明提出了一种解决双重不平衡场景下的无监督域适应方法,针对域内不平衡和跨域不平衡,该模型学习到类对比知识和类关联知识,所提出判别型的域间对齐方式也能更好的减小跨域的每个类之间的距离,从而将源域中学到的信息迁移到目标域,减少模型受数据不平衡程度的影响,相比于其他模型能达到更高的分类准确率,使得模型在这种更接近真实场景的情况下具有更好的泛化性能。
附图说明
[0022]
为了使本发明的内容更容易被清楚地理解,下面根据具体实施例并结合附图,对本发明作进一步详细的说明。
[0023]
图1为方法流程图;
[0024]
图2为网络模型总体架构图;
[0025]
图3为实例的数据分布图;
[0026]
图4为本发明与其他算法的结果比较。
具体实施方式
[0027]
一种基于双重不平衡场景下的无监督域适应方法,如图1所示,包括以下步骤:
[0028]
一、数据处理
[0029]
在模型训练前,将用户提供的图片数据通过改变大小、随机裁剪等预处理方式统一成网络模型输入所要求的格式,同时源域数据是带有标签的,目标域数据没有标签的。
[0030]
二、模型训练
[0031]
这个阶段大致可以分为两个过程,即源域模型预训练、预测模型训练。
[0032]
源域模型预训练是使用源域数据训练的阶段,为了尽可能的学习到源域中的知识,计算每个类的关联向量,具体为:构造类对比损失,最大化特定类样本和其他类样本之间的得分预测差异,使用指数损失作为代理损失便于进行优化;其中,auc度量考虑了分类器对于正例和负例的分类能力,在数据不平衡的情况下,依然能够对分类器作出合理的评价。受此启发提出了类对比损失,对于k个类的多分类任务,采用一个具有k维概率输出的分类器,每一维描述了单个样本属于每一类的概率,对于第k类来说,如果有xi∈xk和样本xi属于第k类的分类得分fk(xi)应该高于xj在第k类的分类得分fk(xj);对于每个样本对(xi,xj),定义第k类的类对比标签为构造出类对比损失为构造出类对比损失为其中θf是特征提取器f的参数,θc是多分类器c的参数;该损失是非凸的不能直接使用,又引入指数损失作为代理损失,同时为了加速计算减少时间复杂度,通过因式分解方案化简,经验风险可以进一步表述为以下形式:
[0033][0034]
其中,β为指数损失中自带的参数,计算损失时只需先分别计算(ai)和(bi),然后执行乘法操作即可计算出类对比损失,计算损失时只需先分别计算(ai)和(bi),然后执行乘法操作即可计算出类对比损失。
[0035]
在步骤2中,对于源域的每一类,计算该类中所有样本经过预训练模型输出的概率向量经过softmax之后的平均值作为这一类的关联向量,记做c
(k)
,类关联向量反映了每个类别与其他类别的相关性,如果l、m≠k表示第k类相对于第m类来说更相近于第l类。
[0036]
预测模型训练阶段是训练真正要使用的模型,模型架构如图2所示,使用预训练模型的特征提取器作为预测模型特征提取器的初始化,使用判别型的特征对齐,加强源域和目标域中相同类的特征对齐和不同类间的特征差异,将类对比知识和类关联知识迁移到目标域中,减少模型受数据不平衡程度的影响,使模型在目标域上也有好的泛化性能;具体为:构造解决双重不平衡场景的网络模型,由多层的特征提取器和分类器组成,使用预训练好的特征提取器和随机初始化的分类器;对源域数据的每一类进行采样构造训练数据,将目标域数据输入到模型中得到预测伪标签,再根据伪标签对每一类采样构造目标域的训练
数据,将两组数据一起输入到模型中,得到源域特征和目标域的特征。
[0037]
使用源域特征,目标域特征,源域的真实标签和目标域的伪标签,构造判别型的最大均值差异,减少源域和目标域之间的差异;提出判别型最大均值差异距离dmmd,传入源域特征和目标域特征最小化dmmd损失,能够跨域最小化同一类的类内距离,并且最大化全局的整体距离,达到加强类内的特征对齐和不同类间的特征差异;dmmd定义如下:
[0038][0039]
其中xs代表所有源域样本,x
t
代表所有目标域样本,表示源域样本属于第k类的概率,表示目标域样本属于第k类的概率,直接通过源域的标签获得,如果源域样本是第k类则为1,否则为0,为目标域样本经过分类器所输出的概率向量为源域样本经特征提取器所得到的特征,为目标域样本经特征提取器所得到的的特征。
[0040]
对于源域数据样本使用类对比损失和交叉熵损失,对于目标域数据样本选择置信度高的样本使用类对比损失,预测其伪标签和之前计算出的对应类的关联向量进行约束;整体损失计算梯度,反向传播,迭代更新网络参数直至损失收敛,对目标域数据样本进行预测得到预测标签,与目标域数据样本的真实标签比较,对于每一类计算出该类的分类准确率,再计算所有类的平均分类准确率作为度量结果。
[0041]
下面以digits数据集为例,说明本发明实施例方法的处理流程:
[0042]
如图3所示,左边是源域数据样本的标签分布,右边是目标域数据样本的标签分布,两个域都有同样的10个类,源域数据是有标签的而目标域数据是没有标签的。
[0043]
1.将源域数据样本和目标域数据样本都转成大小为32*32的、三通道的图片;
[0044]
2.使用lenet模型作为源域预训练的模型,输入源域数据和对应的标签,利用模型输出的概率向量和真实标签计算类对比损失和交叉熵损失,更新预训练模型直至损失收敛,保存预训练模型的特征提取器;
[0045]
3.再次将源域数据样本和标签输入到训练好的预训练模型中,按照标签类别,计算特定类中所有样本经过预训练模型输出的概率向量经过softmax之后的平均值作为这一类的关联向量,如第1类的关联向量在第1维的值最大,第6维的值大于第2维的值,说明第1类与第6类之间更相似,与第2类更不相似;
[0046]
4.使用第2步中保存的预训练好的特征提取器和随机初始化的分类器构造真正使用的网络模型;
[0047]
5.将所有目标域数据输入到网络模型中得到目标域数据样本的伪标签;
[0048]
6.按照类别,从源域数据样本中每类选取12个样本,对目标域数据样本根据伪标签也在10个类中每个类选取12个样本构成一组训练数据输入网络模型得到对应的源域特征和目标域特征;
[0049]
7.对步骤6中得到的源域特征和目标域特征使用判别性的最大均值差异距离dmmd进行拉近,将同一类的源域特征和目标域特征进行拉近同时将两个域的整体特征进行拉远,计算dmmd损失;
[0050]
8.初始阈值设置为0.6,挑选模型输出的概率向量的最大值大于阈值的目标域样本为置信度高的样本,这个阈值在训练过程中逐渐增大至0.85;
[0051]
9.将所有的源域数据样本经过模型输出的概率向量和真实标签一起计算类对比损失和软标签损失;
[0052]
10.对挑选出来的高置信度的目标域样本,使用其经过模型输出的概率向量也计算类对比损失,同时计算和对应类的关联向量之间的关联损失;
[0053]
11.整体损失计算梯度,反向传播更新网络参数,每20轮更新一次所有目标域数据的伪标签,每50轮计算一次目标域数据每一类的平均准确率,直至损失收敛。
[0054]
保存最优的目标域模型,输出对目标域数据样本预测的标签。
[0055]
如图4所示,本方法称为titok,在digits数据集下的三个迁移任务中相较以前的方法都达到了更高的平均分类准确率,三个任务的平均性能上也是达到了最高值,相比于2021年的最新方法coal,三个任务上的平均性能高出了1.34个百分点,相比之前的一些其他方法,本方法在性能上有了显著提高。
[0056]
以上所述仅为本发明的优选方案,并非作为对本发明的进一步限定,凡是利用本发明说明书及附图内容所作的各种等效变化均在本发明的保护范围之内。
当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1