本申请涉及计算机视觉技术和人工智能,特别是涉及一种分类模型的训练方法、装置、计算机设备、存储介质和计算机程序产品。
背景技术:
1、随着人工智能技术的快速发展,人工智能被广泛应用在各行各业。以人工智能在图像处理上的应用为例,利用人工智能进行机器学习训练分类模型,能够提高图像分类的效率和精度。
2、其中,分类模型的精度还受限于训练样本,训练样本越多,标注结果越精确,训练得到的分类模型的预测精度也越精确。然而,在实际应用中,训练样本通常表现为长尾类分布。位于头部的一小部分类别含有较多数量的样本,剩下的类别含有较少数量的样本。样本数据据的长尾分布会,会使得模型在数量较多的样本上学习效果更好,而在数量较小的样本上学习效果更差,从而降低了模型的泛化性能,对模型性能造成影响。
技术实现思路
1、基于此,有必要针对上述技术问题,提供一种能够提高模型性能的分类模型的训练方法、装置、计算机设备、计算机可读存储介质和计算机程序产品。
2、第一方面,本申请提供了一种分类模型的训练方法。所述方法包括:
3、获取待训练数据的头部特征和尾部特征;
4、根据所述头部特征中的类别信息,提取所述头部特征的同类特征对;
5、根据所述头部特征的同类特征对,提取与类别无关的特征;
6、利用非全局注意力机制,将所述与类别无关的特征与所述尾部特征融合,得到自适应增广特征;
7、融合所述自适应增广特征和所述尾部特征,得到增广尾部特征;
8、对所述头部特征、所述尾部特征和所述增广尾部特征进行特征提取后并进行分类预测,得到分类结果;
9、根据分类损失和预设尾部增广处理损失,更新所述分类模型的参数,直到所述分类模型收敛时,得到训练好的分类模型。
10、第二方面,本申请还提供了一种分类模型的训练装置。所述装置包括:
11、特征获取模块,用于获取待训练数据的头部特征和尾部特征;
12、同类特征对获取模块,用于根据所述头部特征中的类别信息,提取所述头部特征的同类特征对;
13、特征提取模块,用于根据所述头部特征的同类特征对,提取与类别无关的特征;
14、自适应增广模块,用于利用非全局注意力机制,将所述与类别无关的特征与所述尾部特征融合,得到自适应增广特征;
15、尾部增广模块,用于融合所述自适应增广特征和所述尾部特征,得到增广尾部特征;
16、分类预测模块,用于对所述头部特征、所述尾部特征和所述增广尾部特征进行特征提取后并进行分类预测,得到分类结果;
17、调整模块,用于根据分类损失和预设尾部增广处理损失,更新所述分类模型的参数,直到所述分类模型收敛时,得到训练好的分类模型。
18、第三方面,本申请还提供了一种计算机设备。所述计算机设备包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现以下步骤:
19、获取待训练数据的头部特征和尾部特征;
20、根据所述头部特征中的类别信息,提取所述头部特征的同类特征对;
21、根据所述头部特征的同类特征对,提取与类别无关的特征;
22、利用非全局注意力机制,将所述与类别无关的特征与所述尾部特征融合,得到自适应增广特征;
23、融合所述自适应增广特征和所述尾部特征,得到增广尾部特征;
24、对所述头部特征、所述尾部特征和所述增广尾部特征进行特征提取后并进行分类预测,得到分类结果;
25、根据分类损失和预设尾部增广处理损失,更新所述分类模型的参数,直到所述分类模型收敛时,得到训练好的分类模型。
26、第四方面,本申请还提供了一种计算机可读存储介质。所述计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现以下步骤:
27、获取待训练数据的头部特征和尾部特征;
28、根据所述头部特征中的类别信息,提取所述头部特征的同类特征对;
29、根据所述头部特征的同类特征对,提取与类别无关的特征;
30、利用非全局注意力机制,将所述与类别无关的特征与所述尾部特征融合,得到自适应增广特征;
31、融合所述自适应增广特征和所述尾部特征,得到增广尾部特征;
32、对所述头部特征、所述尾部特征和所述增广尾部特征进行特征提取后并进行分类预测,得到分类结果;
33、根据分类损失和预设尾部增广处理损失,更新所述分类模型的参数,直到所述分类模型收敛时,得到训练好的分类模型。
34、第五方面,本申请还提供了一种计算机程序产品。所述计算机程序产品,包括计
35、获取待训练数据的头部特征和尾部特征;
36、根据所述头部特征中的类别信息,提取所述头部特征的同类特征对;
37、根据所述头部特征的同类特征对,提取与类别无关的特征;
38、利用非全局注意力机制,将所述与类别无关的特征与所述尾部特征融合,得到自适应增广特征;
39、融合所述自适应增广特征和所述尾部特征,得到增广尾部特征;
40、对所述头部特征、所述尾部特征和所述增广尾部特征进行特征提取后并进行分类预测,得到分类结果;
41、根据分类损失和预设尾部增广处理损失,更新所述分类模型的参数,直到所述分类模型收敛时,得到训练好的分类模型。
42、上述分类模型的训练方法、装置、计算机设备、存储介质和计算机程序产品,通过从头部特征的同类特征中,提取与类别无关的特征,引入非全局注意力机制,将与类别无关的特征与尾部特征融合,得到自适应增广特征,能够使得与类别无关特征和在不同空间位置上尾部特征,实现更为精细地自适应融合,提升了与类别无关特征与尾部特征的适配性,使得增广尾部特征与真实尾部数据相符合,进而将增广尾部特征与尾部特征融合,得到增广尾部特征。该方法使得在分类模型的训练阶段有效扩增了尾部特征空间,从而能够提升分类模型的分类性能。
1.一种分类模型的训练方法,其特征在于,所述方法包括:
2.根据权利要求1所述的方法,其特征在于,所述利用非全局注意力机制,将所述与类别无关的特征与所述尾部特征融合,得到自适应增广特征,包括:
3.根据权利要求1所述的方法,其特征在于,所述预设尾部增广处理损失包括不同类别特征之间的对比损失;
4.根据权利要求1所述的方法,其特征在于,所述预设尾部增广处理损失包括增广尾部特征的类别损失;
5.根据权利要求1、3或4所述的方法,其特征在于,所述预设尾部增广处理损失包括增广处理的循环重构损失;所述方法还包括:
6.根据权利要求5所述的方法,其特征在于,所述预设尾部增广处理损失包括增广处理的模式寻找损失;
7.根据权利要求6所述的方法,其特征在于,所述预设尾部增广处理损失包括所述不同类别特征之间的对比损失、增广处理的循环重构损失和增广处理的模式寻找损失;
8.根据权利要求1所述的方法,其特征在于,所述分类模型包括多个特征提取层,以及设置在两个特征提取层之间的尾部增广层和设置在最后一个特征提取层后的分类预测层,各特征提取层之间连接,
9.根据权利要求8所述的方法,其特征在于,所述方法还包括:
10.一种分类模型的训练装置,其特征在于,所述装置包括: