一种基于孪生知识蒸馏与自监督学习的小样本分类方法

文档序号:29570809发布日期:2022-04-09 03:51阅读:646来源:国知局
一种基于孪生知识蒸馏与自监督学习的小样本分类方法

1.本发明属于计算机视觉领域,特别地涉及到了知识蒸馏技术与小样本分类任务。


背景技术:

2.小样本学习(few-shot learning,简称小样本)是指让机器学习模型在学习了一定类别的大量数据后,对于新的类别,只需要少量的样本就能快速学习。小样本学习中的分类问题主要指n-way k-shot问题,它是指:在训练阶段,在训练集中随机抽取n个类别,每个类别k个样本(共n
×
k个数据)作为模型的支持集输入,再从这n个类别中剩余的数据抽取q个样本作为模型的问题集,即要求模型从n
×
k个数据中学会如何区分这n个类别。
3.目前的小样本方法可以分为三类:基于梯度的方法;数据增强方法;基于度量的方法。
4.基于梯度的方法主要指基于元学习的方法,它是指在训练阶段,通过学习一系列的n-way k-shot问题来训练模型,然后将这个模型用于解决新的小样本任务。
5.数据增强方法是指通过生成目标类别额外的数据集来解决数据集的稀缺性,这些方法使用从基类学习的变化的模型或者直接使用对抗生成网络来从目标类别的少量样本去生成数据。
6.基于度量的方法是现在小样本任务中表现最好的方法,它旨在学习能够比较不同样本间相似性的距离度量函数。本发明的方案基于此方法提出。


技术实现要素:

7.本发明旨在提升通过特征提取器提取到的特征表示的质量,进而提升模型在小样本下游分类任务上的性能,同时并不增加网络的复杂性,主要在训练时保留特征表示间的区分信息来平滑特征表示,以便模型在目标类别上可以获得更好的泛化能力。
8.为了实现这个目的,本发明提出了一种孪生知识蒸馏网络,它正则化同一类别中一对输入样本通过同一个网络的分类预测,使它们保持一致,并且在此基础上加入自监督学习,因而增强了特征提取器的泛化性。同时,模型给予难样本更多的注意力,进一步提升模型的泛化能力。
9.本发明采用的技术方案为:
10.一种基于孪生知识蒸馏与自监督学习的小样本分类方法,包括以下步骤:
11.s1、获取用于训练孪生知识蒸馏网络模型与测试小样本分类任务的图像数据集;
12.s2、建立由特征提取网络和分类器网络构成的孪生知识蒸馏网络模型,并使用孪生知识蒸馏与自监督学习相结合的方法训练所述孪生知识蒸馏网络模型;
13.s3、将训练好的孪生知识蒸馏网络模型应用于小样本分类任务。
14.进一步的,步骤s1中,获取的用于训练孪生知识蒸馏网络模型的图像数据集包括c
base
个基类,每个基类包括满足预设数据量的大数量带标签样本;获取的用于测试小样本分类任务的图像数据集包括c
target
个测试类别,每个测试类别包括少量样本;且基类与测试
类别在标签空间中不相交,
15.进一步的,步骤s2具体包括:
16.s21、从用于训练孪生知识蒸馏网络的图像数据集中随机采样一个批量的图像样本其中批量大小n
bs
预先给定;
17.s22、对于基类中的一个样本x,通过特征提取网络与分类器网络得到它的类别预测向量,即z=[z1,...,zi,...,zm],再经过softmax分类器得到它属于第i类的预测概率
[0018][0019]
其中,τ>0是温度缩放参数,它控制softmax输出的归一化概率分布的平滑性;m为类别个数;
[0020]
对于多分类任务,目标损失函数为
[0021]
l
ce
(x)=h(y,p(x))
[0022]
其中,h表示交叉熵损失,p(x)是输入x的预测的概率分布,y是类别真实标签;
[0023]
s23、将相同类别的一对样本作为所述孪生知识蒸馏网络模型的输入,得到各自的概率分布ps(xi)与ps(xj),让p
t
(x)=(ps(xi)+ps(xj))/2作为教师,分别去蒸馏两个学生分布;同时,两个学生概率分布也要受类别真实标签的监督
[0024]
l
twinkd
(xi,xj)=h(αy+βp
t
(x),ps(xi))+h(αy+βp
t
(x),ps(xj)),
[0025]
其中,α和β是平衡教师监督与类别真实标签监督的两个超参数,并且有α+β=1;p
t
(x)不进行梯度传播;
[0026]
s24、使用自监督学习中的2d图像旋转来增强数据,旋转角度分别为r={90
°
,180
°
,270
°
},由于样本本身可以看作0
°
旋转,因此合起来的旋转集可以表示为r

={0
°
,90
°
,180
°
,270
°
};对于一个输入样本,先创建它的三个旋转图像其中表示把样本xi旋转r角度;然后用特征提取器f提取它们的特征表示,之后正则化输入样本对和它们的增强样本的特征表示
[0027][0028]
分类器网络g把特征表示映射入标签空间来预测标签,之后从平均概率值蒸馏知识到旋转的样本;相应的自监督知识蒸馏定义为
[0029][0030]
其中p
t
(x)是输入样本对预测概率ps(xi)与ps(xj)的均值;
[0031]
同时,由输入样本与它的旋转样本提取到的logit向量g(f(x)),旋转分类器去预测它们的旋转角度标签r;因此自监督损失公式为
[0032][0033]
其中q(x)=[q1(x),...,qr(x)]是旋转预测概率向量,通过下式获得
[0034][0035]
其中,u表示样本x旋转向量;
[0036]
综上所述,本网络的整个优化目标为
[0037][0038]
s26、根据得到的优化目标,使用带动量的随机梯度下降优化器,以及反向传播算法训练所述孪生知识蒸馏网络模型;
[0039]
s27、重复步骤s21至s26直至模型收敛。
[0040]
进一步的,步骤s3具体包括:
[0041]
s31、给定一个n-way k-shot分类任务,支持集是s,首先通过特征提取网络,计算各个类别的视觉原型
[0042][0043]
其中,c表示某个类别,sc和|sc|是类别c的支持集和支持集中样本数量;
[0044]
s32、对于查询集中的测试样本x
t
,它属于类别c的概率是
[0045][0046]
其中d是相似性度量函数;最终,根据测试样本属于n个类别的概率大小来预测到底属于哪个类别,概率最大的即为预测类别。
[0047]
本发明的基于孪生知识蒸馏与自监督学习的小样本分类方法,有以下优点:
[0048]
首先,提出了有效且容易实施的框架twinkd来解决小样本任务。twinkd在优化目标中加入kl散度,并且让同类别的输入图像对经过相同网络再相互蒸馏来产生一致的特征表示,这样可以使每个样本的概率分布被它同对样本平滑一下,以获得更好的泛化性。twinkd的优点在于不需要预先训练好的模型,因此效率更高也更节省计算资源。
[0049]
其次,在理论上证明了twinkd给予难样本更多的注意力,并且通过增强难样本数据来学习更丰富的特征表示,表现在把ssl加入twinkd框架中。这样模型会对难样本给予更多的注意力,进一步提升模型的泛化能力,并且它比现存的kd更灵活效率也更高。
[0050]
本发明的基于孪生知识蒸馏与自监督学习的小样本分类方法,在3个小样本分类任务基准(miniimagenet,tieredimagenet和cifar-fs)上均有很好的表现,证明了本发明在性能上的有效性与优越性。
附图说明
[0051]
图1为本发明的基于孪生知识蒸馏与自监督学习的小样本分类方法的流程示意图。
具体实施方式
[0052]
为了使本发明的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对
本发明进行进一步详细说明。应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。
[0053]
相反,本发明涵盖任何由权利要求定义的在本发明的精髓和范围上做的替代、修改、等效方法以及方案。进一步,为了使公众对本发明有更好的了解,在下文对本发明的细节描述中,详尽描述了一些特定的细节部分。对本领域技术人员来说没有这些细节部分的描述也可以完全理解本发明。
[0054]
参考图1,在本发明的较佳实施例中,基于孪生知识蒸馏与自监督学习的小样本分类方法,包括以下步骤:
[0055]
首先,获取用于训练孪生知识蒸馏网络模型与测试小样本分类任务的图像数据集。
[0056]
其中,获取的训练孪生知识蒸馏网络模型的数据集共有c
base
个基类,每个类别都有大量带标签样本。测试小样本分类任务的数据集共有c
target
个类别,每个测试类别仅有少量样本。基类与测试类别在标签空间中是不相交的,即具体符号表示为{(x1,y1),(x2,y2),...,(xn,yn)},其中n为数据集中的图像总数,xi和yi分别表示第i张图像及其对应的类标签,yi∈{1,...,m},m表示类别总数。
[0057]
然后,建立由特征提取网络和分类器网络构成的孪生知识蒸馏网络模型,并使用孪生知识蒸馏与自监督学习相结合的方法训练所述孪生知识蒸馏网络模型。具体包括以下几步:
[0058]
第一步,在训练深度神经网络时采用批量处理的方式,首先从图像数据集中随机采样一个批量的图像样本其中批量大小n
bs
预先给定。
[0059]
第二步,对于基类中的一个样本x,通过特征提取网络与全连接层得到它的logit向量,即z=[z1,...,zi,...,zm],再经过softmax分类器得到它属于第i类的预测概率
[0060][0061]
其中,τ>0是温度缩放参数,它控制softmax输出的归一化概率分布的平滑性;m为类别个数。
[0062]
对于分类任务,目标损失函数为
[0063]
l
ce
(x)=h(y,p(x))
[0064]
其中,h表示交叉熵损失,p(x)是输入x的预测的概率分布,y是类别真实标签。
[0065]
第三步,本发明使用孪生知识蒸馏模式(简称twinkd),让输入同一个模型的同类别样本对相互蒸馏。具体来讲,本发明只训练一个模型,让相同类别的一对样本作为输入,然后它们经过相同的网络,得到各自的概率分布ps(xi)与ps(xj),让p
t
(x)=(ps(xi)+ps(xj))/2作为教师,分别去蒸馏两个学生分布。当然,同时,两个学生概率分布也要受类别真实标签的监督
[0066]
l
twinkd
(xi,xj)=h(αy+βp
t
(x),ps(xi))+h(ay+βp
t
(x),ps(xj))
[0067]
其中,α和β是平衡教师监督与类别真实标签监督的两个超参数,并且有α+β=1。p
t
(x)不进行梯度传播。
[0068]
第四步,固定超参数β,优化目标l
twinkd
对logit向量zi的梯度为
[0069][0070]
经推导,可得
[0071][0072]
其中γi表示不正确分类的概率。
[0073]
当样本xi是难以分类的样本时,即γi大而小,则缩放因子大,意味着模型会给予难样本更多的注意力,反之亦然。
[0074]
第五步,本发明使用自监督学习(简称ssl)策略对难样本进行数据增强,即把ssl加入到twinkd中来使得特征表示多样化。
[0075]
使用ssl中的2d图像旋转来增强数据,旋转角度分别为r={90
°
,180
°
,270
°
}。由于样本本身可以看作0
°
旋转,因此合起来的旋转集可以表示为r

={0
°
,90
°
,180
°
,270
°
}。对于一个输入样本,先创建它的三个旋转图像其中表示把样本xi旋转r角度。然后用特征提取器f提取它们的特征表示,之后正则化输入样本对和它们的增强样本的特征表示
[0076][0077]
分类器网络g把特征表示映射入标签空间来预测标签,之后从平均概率值蒸馏知识到旋转的样本。相应的自监督知识蒸馏定义为
[0078][0079]
其中p
t
(x)是输入样本对预测概率ps(xi)与ps(xj)的均值。
[0080]
同时,由输入样本与它的旋转样本提取到的logit向量g(f(x)),旋转分类器去预测它们的旋转角度标签r。因此自监督损失公式为
[0081][0082]
其中q(x)=[q1(x),...,qr(x)]是旋转预测概率向量,通过下式获得
[0083][0084]
其中,u表示样本x旋转向量。
[0085]
综上所述,本网络的整个优化目标为
[0086][0087]
第六步,根据得到的总的损失函数,使用带动量的随机梯度下降优化器,以及反向
传播算法训练深度神经网络。
[0088]
最后,重复上述步骤直至模型收敛。
[0089]
最后,将训练好的孪生知识蒸馏网络模型应用于小样本分类任务。
[0090]
给定一个n-way k-shot分类任务,支持集是s,首先通过特征提取网络,计算各个类别的视觉原型
[0091][0092]
其中,c表示某个类别,sc和|sc|是类别c的支持集和支持集中样本数量。
[0093]
对于问题集中的测试样本x
t
,它属于类别c的概率是
[0094][0095]
其中d是相似性度量函数,本发明中使用的是cosine相似性函数。最终,根据测试样本属于n个类别的概率大小来预测到底属于哪个类别,概率最大的为预测类别。
[0096]
以上所述仅为本发明的较佳实施例而已,并不用以限制本发明,凡在本发明的精神和原则之内所作的任何修改、等同替换和改进等,均应包含在本发明的保护范围之内。
当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1