面向易混淆类的动态类间距域自适应方法

文档序号:31706387发布日期:2022-10-01 11:31阅读:127来源:国知局
面向易混淆类的动态类间距域自适应方法

1.本发明涉及机器学习领域的域自适应技术,具体涉及一种面向易混淆类的动态类间距自适应方法。


背景技术:

2.随着深度学习的不断发展,深度模型在分类预测等实际应用的性能得到极大提高,但训练出有效的深度模型需要大量的有标记数据。实际的场景中,大量的有标记数据难获取,需要花费大量的人力和物力。因此,怎样找到一个有效的方法,来解决标记数据缺失的问题就显得尤为关键。
3.在标记数据缺失的情况下,域自适应方法是有效的解决方案。域自适应是迁移学习的一个分支内容,其目的是希望能够借助源域中丰富的标记数据训练出一个有效的模型,并且能将这个模型有效的运用到缺乏标记数据的目标域样本,使得深度模型能够对目标域样本进行分类预测。
4.域自适应主要解决的是如何减小源域和目标域样本分布之间的差异,在此基础上提出了很多域自适应方法,具有代表性的方法很多都是基于对抗网络的域对齐方法,这些方法是借鉴了生成对抗网络(gan,generative adversarial networks),的基本思想,目的都是让源域和目标域的样本分布趋于一致,使得源域训练出的模型能够很好地应用到目标域中。但是,这些基于域对抗思想的方法都没有能够很好地解决目标域中易混淆类的分类问题,这些易混淆类的样本在特征空间中的分布很接近,传统的域对齐方法很难区分易混淆类的类边界,这样使得域自适应分类在面向易混淆类时分类效果很差,使得基于对抗思想的域自适应方法用于图像分类时会出现过于粗糙的问题。
5.传统的域对齐方法通过监督学习增加源域同类样本的汇聚程度,即减小同类样本的距离,然后,通过域对齐方法间接完成对目标域样本的类汇聚,这种方法对于目标域中一般类样本的分类效果很好,但是在面向目标域易混淆类时,因为易混淆类的类中心距离很近,目标域的类汇聚程度和源域相比,还有较大差距,所以目标域中易混淆类的边界往往重叠在一起,造成分类精度下降。检索发现公开号为cn112819098a的中国专利公开了一种基于三元组和差额度量的域自适应方法,该方法从目标域随机抽取样本形成目标域 batch,输入特征提取器获得样本特征,将样本特征输入多分类器和多二分类器,然后利用三元组损失函数筛选构建源域batch,训练多分类器和多二分类器,最后将目标域batch和源域batch送入域对抗网络极性域对其操作,在度量空间中进行一个聚类的操作。
6.综上所述,现有技术中对于基于对抗网络的域自适应方法如何有效解决在图像分类过程中目标域易混淆类样本的分类问题尚没有公开的披露。


技术实现要素:

7.本发明的目的在于,针对现有技术存在的缺陷,提出一种面向易混淆类的动态类间距域自适应方法,主要用于解决基于生成对抗网络的域自适应方法中,如何在特征空间
中拉开目标域类的样本特征中心距离,即增加类间距离的基础上,进一步增加目标域易混淆类的类间距离。
8.为了达到以上目的,本发明提供一种面向易混淆类的动态类间距域自适应方法,包括以下步骤:
9.步骤1、深度网络的训练过程中每个训练batch包括目标域 batch和源域batch;
10.步骤2、从目标域随机抽取样本形成目标域batch,并送入特征提取器提取样本特征,将样本特征送入多分类器进行熵最小化,同时将样本特征送入多二分类器,根据多二分类器输出最大值和次大值的差异判定出k个临界样本和对应的k对相似类,建立易混淆类集合;
11.步骤3、在建立的易混淆类集合中,根据易混淆类对在多二分类器的输出最大值和次大值的差值,计算出在一个目标域batch中易混淆类对应的类间差值;
12.步骤4、使用提取好的源域batch,送入特征提取器提取特征,将提取的特征送入到多分类器进行监督训练;
13.步骤5、根据源域batch中样本的类别信息和易混淆类的类间距值,计算损失函数,并回溯优化模型参数;
14.步骤6、将源域batch和目标域batch提取到的特征,分别送入域对抗网络中,进行域对齐操作。
15.本发明首先在目标域中计算出样本的多二分类器输出的最大值和次大值的差值,根据差值识别出目标域中的易混淆类,对于易混淆的类,利用角边距损失函数,动态增加源域中该易混淆类的类间距离,再利用域对抗网络进行域对齐,使得源域和目标域的样本分布趋于一致,间接地让目标域易混淆类间的中心距离增加,使得目标域中易混淆类的边界不再相互重叠,从而增加域自适应算法中目标域易混淆类样本的分类精度,提高域自适应算法的总体性能,有效提高了无监督域自适应网络对于目标域相似类样本的判别能力。
16.本发明进一步的采用如下技术方案:
17.所述步骤2中,从目标域中随机抽取样本形成目标域batch,送入特征提取器f提取特征后再送入多分类器cm进行熵最小化,目的是为了降低目标域样本的不确定性。损失函数如下:
[0018][0019]
式中,表示目标域batch的样本数,h(
·
)表示求熵,其目的是为了降低目标域样本的不确定性,使得未标记样本的决策边界通过密度最小的区域,提升样本的聚类效果;
[0020]
将目标域batch提取的特征送入多分类器的同时,也送入到多二分类器cb中,通过cb输出最大值和次大值的差异判定临界样本;临界样本的判定方法如下:
[0021]
对于目标域batch在cb上的输出,将目标域batch中每个样本在cb上输出的最大值与次大值之间的差值定义为分类距离,寻找分类距离最小的前k个样本,并将这k个样本判定为临界样本;
[0022]
记录下每个临界样本最大值和次大值所对应的类(易混淆类),分别将临界样本最
大值所对应的类记为a类,临界样本次大值所对应的类记为b类,针对其易混淆类a和b,建立易混淆类集合存储。
[0023]
上述步骤中,先从目标域的数据集中随机抽取一个batch的样本,送入特征提取器去提取特征并将提取到的特征送入多分类器,最后对于多分类器产生的输出取熵,使得目标域样本的熵尽可能小。在将特征送入多分类器的同时,也将特征送入多二分类器。对于目标域batch通过多二分类器产生对应的输出,求出该batch样本中每个样本在多二分类器输出的最大值与次大值之差,将差值最小的k个样本作为临界样本,也即认定目标域batch的这k个样本靠近分类边界,同时记录下这k个样本所对应的k对相似类,并建立易混淆类集合,也即用集合保存每个临界点在多二分类器输出的最大值和次大值所对应的k对类别。这样,针对一个batch中的易混淆类,有针对性的拉开易混淆类的类间距,直接作用在特征空间中。上述操作能够同时保证类内紧凑性以及类间差异性,提高了分类的精确程度,有效改善了分布在分类边界易混淆样本导致的分类误差。
[0024]
所述步骤3中,根据临界点在多二分类器的输出最大值和次大值的差值所对应的最小分类距离来确定边界阈值,边界阈值β
(a,b)
定义如下:
[0025][0026]
其中,表示目标域batch,表示目标域batch中找到的临界样本,yi表示一个样本在多分类器上输出概率所对应的类别,ya表示一个样本在多分类器上除了输出概率外最大概率所对应的类别, yb表示一个样本在多分类器上除了输出概率外次大概率所对应的类别,表示临界样本在多二分类器cb上的输出值,上带有箭号的目的主要是由于目标域样本没有标签信息,所以目标域样本通过cb输出的预测值为伪标签,表示临界样本被分类为类a的概率,表示临界样本被分类为类b的概率,α0为初始值,μ为恒定系数,表示临界点xi所对应的最小分类距离,越小说明在特征空间中两个类距离越接近,也就是说分类时越易混淆,对取倒数。边界阈值计算公式中加入对数的作用是为了防止过小,导致边界阈值过大。这样,在找到易混淆类之后,就可以通过最小分类距离计算其易混淆程度,并且用边界阈值来衡量找到的易混淆类对的易混淆程度。
[0027]
同时,为了提高边界阈值的计算准确率,降低一次计算的偶然性,对边界阈值采用不同batch求平均的方式,公式如下:
[0028][0029]
其中,t是batch数,表示a和b的两个相似类在第t个batch 的平均值;
[0030]
由于每一个batch中类间差值在训练过程中会呈现较大的波动,需要将其形式转换成sigmod形式,得到边界阈值的归一化形式,公式如下:
[0031][0032]
其中,是第二步中建立好的易混淆类,是在一个batch中易混淆类所对应的类间差值,即在一个batch中易混淆类的易混淆程度计算出的类间差值。
[0033]
上述步骤中,计算易混淆类的集合中记录的易混淆类的易混淆程度,并反映到类间损失中。在这里采取了将易混淆样本在多二分类器上的输出概率差值,即输出的最大值以及次大值之差,求倒数再取对数,然后在结合sigmod函数使计算出的类间差值更可控,并将其作为类间损失的类间差值。
[0034]
所述步骤4中,将源域batch提取的特征分别送入多分类器cm中,利用输出值和真实值之间的差异来优化损失函数损失函数如下:
[0035][0036]
式中,表示源域batch样本数,xi为源域batch的第i个样本,yi为样本xi的真实标签,ly为标准交叉熵误差函数,表示源域batch。
[0037]
上述步骤中,抽取源域样本,送入特征提取器中提取特征,然后将提取到的特征送入多分类器。
[0038]
所述步骤5中,将多分类器的最后一层全连接层的特征向量以及权重向量取出,归一化之后,优化激活函数softmax函数,计算加性角度损失,同时将计算归一化的门限阈值用于计算各个类别对应的权重夹角,计算优化类间损失,公式如下:
[0039]
式中,为源域样本,s为源域(source),i表示源域样本的第i个类,为yi类的特征向量与权重向量在经过正则化之后矢量点乘计算出夹角,m是设定的角度惩罚项,j为源域样本的第j个类,n为所有的类别数,θj为第j个类的权重及其样本特征的点积夹角,为第i个类所对应的权重,wj为第j个类所对应的权重,τ为类间差值的系数,为类间差
值。其中,
[0040]
表示加性角度损失,
[0041]
表示两个类的权夹角,与wj进行点积的结果就是他们夹角的余弦值,因此进行反余弦操作即可得到余弦值对应的夹角。本发明通过公式(1-6)优化类间损失。
[0042]
在上述步骤5中,根据步骤3计算得出的易混淆类的类间差值以及步骤4提取的源域batch,在源域batch中寻找易混淆类,并根据易混淆类的类间差值,计算损失函数,以优化回溯模型。上述操作是在加性角度损失的基础上进行改进,加性角度损失只能使得样本在分类时更加聚类,保证类内紧凑性,而面对在分类边界的易混淆类就没有办法拉开他们的距离。本发明的方法使得分类时易混淆类的类间距有效拉开,相比加性角度损失,易混淆类样本减少,以此来改善分类性能。
[0043]
上述步骤中,抽取多分类器的最后一层全连接层中的特征向量以及权重向量,将其归一化之后,对激活函数softmax进行优化,得到并计算加性角度损失,同时用取得的权重向量与步骤3中得到的类间差值计算类间损失。
[0044]
所述步骤6中,将源域batch和目标域batch提取到的特征送入域对抗网络中,进行域对齐操作,由于梯度反转层的存在,使得特征提取部分能够使域分类器d的loss值增大,从而使得提取到的特征具有混淆d的作用,降低源域和目标域特征分布之间的差异,损失函数如下:
[0045][0046]
式中,和分别是源域和目标域batch的样本数,xi是源域batch的第i个样本,xj是目标域batch的第j个样本,d(.)表示样本属于目标域还是源域,如果是源域d(.)等于1,否则,d(.)等于 0。
[0047]
上述步骤中,将源域batch和目标域batch样本送入特征提取器中,提取到的特征再分别送入域对抗网络中,借助该网络能在很大程度上消除源域和目标域的特征分布差异,帮助实现域对齐。
[0048]
本发明通过目标域样本来寻找易混淆类,然后在源域中有针对性地对这些易混淆类进行拉开的操作,其中并不涉及三元组样本。本发明提出的面向易混淆类的域自适应方法通过优化易混淆类的类间损失以此在特征空间中拉开易混淆类。
[0049]
与现有技术相比,本发明具有以下有益技术效果:
[0050]
1、本发明有效地解决了基于对抗思想的域自适应方法用于图像分类时过于粗糙的问题,针对目标域中易混淆类,优化最后一层全连接层的激活函数softmax函数,利用优化的加性角度损失以及类间损失在特征空间中对特征训练,保证了特征的类内紧凑型以及类间差异性,以保证分类的性能更加精细,提高了分类的精度。
[0051]
2、具有很好的创新性。本发明在源域样本训练时,通过建立易混淆类集合,可以有针对的对易混淆类进行训练,保证了易混淆类能够在训练中更能被损失函数影响拉开距
离。另外,本发明在类间损失的类间差值设计上具有很好的创新性,通过临界点的最小分类距离求倒数,并通过sigmod激活函数来计算类间差值,保证了相似程度越高的两类样本,它们之间的距离能被有效的拉开,同时通过sigmod 函数,使得类间差值始终在一个可控范围内,不会因为类间差值过大,导致类间损失过大,影响整个网络的训练。
[0052]
3、具有较好的适用性。本发明在无监督域自适应图像分类的过程中,有效提高了特征的可迁移性,同时极大提高了模型的泛化能力。
[0053]
4、具有简单性的特点。模型构造简单,物理意义直观,计算复杂度较小。
[0054]
本发明设计出一种能够实现多线程协作的蒙特卡罗算法,该算法在gpu上的加速具有非常重要的应用价值,可以大幅降低投入与开销,开发简单方便,有望解决大搜索空间分子对接速度的问题。
附图说明
[0055]
下面结合附图对本发明作进一步的说明。
[0056]
图1为本发明的训练流程图。
具体实施方式
[0057]
实施例一
[0058]
整个方法过程可以分为分类器训练过程与域对齐过程。具体流程如图1所示。
[0059]
第一步,深度网络的训练过程中每个训练batch包括目标域 batch和源域batch。
[0060]
第二步,首先随机抽取目标域样本构成目标域batch,依次送入特征提取器和多分类器,对输出值进行熵最小化操作,同时将目标域batch样本提取到的特征送入多二分类器,根据输出最大值和次大致的差异确定k个临界样本以及k对相似类,建立易混淆类对集合。
[0061]
先从目标域中随机抽取样本形成目标域batch,送入特征提取器 f提取样本特征后,将样本特征送入多分类器cm进行熵最小化,目的是为了降低目标域样本的不确定性。损失函数如下:
[0062][0063]
式中,表示目标域batch的样本数,h(
·
)表示求熵,其目的是为了降低目标域样本的不确定性,使得未标记样本的决策边界通过密度最小的区域,提升样本的聚类效果。
[0064]
将目标域batch提取的样本特征送入多分类器的同时,也送入到多二分类器cb中,通过cb输出的最大值和次大值的差异判定临界样本,临界样本的判定方法:对于目标域batch在cb上的输出,将目标域batch中每个样本在cb上输出的最大值与次大值之间的差值定义为分类距离,寻找分类距离d最小的前k个样本,并将这k个样本判定为临界样本,同时记录下每个临界样本最大值和次大值所对应的类(易混淆类),分别记为a类和b类。针对其易混淆类a和b 建立易混淆类集合存储。
[0065]
第三步,在建立的易混淆类集合中,根据易混淆类对在多二分类器的输出最大值
和次大值的差值,计算出在一个目标域batch中易混淆类对应的类间差值。
[0066]
根据第二步中建立的易混淆类集合,计算临界样本。为此,可以根据临界点在多二分类器的输出最大值和次大值的差值所对应的最小分类距离来确定边界阈值边界阈值定义如下:
[0067][0068][0069][0070]
其中,表示目标域batch,表示目标域batch中找到的临界样本,yi表示一个样本在多分类器上输出概率所对应的类别,ya表示一个样本在多分类器上除了输出概率外最大概率所对应的类别,yb表示一个样本在多分类器上除了输出概率外次大概率所对应的类别,与类别b相同,表示临界样本在多二分类器cb上的输出值,其上带有箭号的目的主要是由于目标域样本没有标签信息,所以目标域样本通过cb输出的预测值为伪标签,表示临界样本被分类为类a的概率,表示临界样本被分类为类b的概率,α0为初始值,μ为恒定系数,表示临界点xi所对应的最小分类距离,越小说明在特征空间中两个类距离越接近,也就是说分类时越易混淆,为了更容易将它们拉开,对取倒数在训练的过程中更容易分开。边界阈值计算公式中加入对数的作用是为了防止d过小,导致边界阈值过大。同时,为了提高边界阈值的计算准确率,降低一次计算的偶然性,对边界阈值采用不同 batch求平均的方式,公式如下:
[0071][0072]
其中,t是batch数量,表示a和b的两个相似类在第t个batch 的平均。由于每一个目标域batch中类间差值在训练过程中会呈现较大的波动,将其形式转换成sigmod形式,将边界阈值进行归一化,公式如下:
[0073][0074]
其中,是第二步中建立好的易混淆类,是在一个batch中易混淆类对应的类
间差值。
[0075]
第四步,随机抽取一个目标域batch的源域样本作为源域batch,使用提取好的源域batch,送入特征提取器提取特征,再将提取到的特征分别送入多分类器进行监督训练,利用输出值和真实值之间的差异来优化损失函数损失函数如下:
[0076][0077]
式中,表示源域batch样本数,xi为源域batch的第i个样本, yi为样本xi的真实标签,ly为标准交叉熵误差函数,表示源域 batch。
[0078]
第五步,根据源域batch中样本的类别信息和易混淆类的类间距值,计算损失函数,并回溯优化模型参数。即提取多分类器的最后一层全连接层的特征及权重,归一化之后,根据分类信息优化多分类器与特征提取器。
[0079]
将多分类器的最后一层全连接层的特征向量以及权重向量取出,归一化之后,优化激活函数softmax函数,计算加性角度损失,同时将计算好的归一化的门限阈值用于计算各个类别对应的权重夹角,计算优化类间损失,公式如下:
[0080][0081]
式中,为源域样本,s为源域(source),i表示源域样本的第i个类,为yi类的源域样本特征向量与权重向量在经过正则化之后,矢量点乘计算出夹角,m是设定的角度惩罚项,j为源域样本的第j 个类,n为所有的类别数,θj为第j个类的权重及其样本特征的点积夹角,为第i个类所对应的权重,wj为第j个类所对应的权重,τ为类间差值的系数,为类间差值。
[0082]
第六步,将源域batch和目标域batch提取到的特征,分别送入域对抗网络中,进行域对齐操作。
[0083]
将源域batch和目标域batch提取到的特征送入域对抗网络中,进行域对齐操作,由于梯度反转层的存在,使得特征提取部分能够使域分类器d的loss值增大,使得提取到的特征具有“混淆”d的作用,从而降低源域和目标域特征分布之间的差异。损失函数如下:
[0084][0085]
式中,和分别是源域batch和目标域batch的样本数,是源域batch的第i个样本,是目标域batch的第j个样本,d(.)表示样本属于目标域还是源域,如果是源域d(.)等于1,否则, d(.)等于0。
[0086]
采用本发明的方法对office31数据集的六个域进行迁移任务,将其结果与采用其他方法对office31数据集六个域的迁移任务进行对比,见表1。
[0087]
表1
[0088]
accuracy(%)on office-31 for unsupervised domain adaptation(resnet-50)
[0089][0090]
由表1可知,本发明相对其他方法来讲,性能更加优异。
[0091]
本发明针对多对易混淆类进行了操作,通过实验证明,易混淆类经过训练后,易混淆类之间的权重中心有了明显的拉开。
[0092]
除上述实施例外,本发明还可以有其他实施方式。凡采用等同替换或等效变换形成的技术方案,均落在本发明要求的保护范围。
当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1