本申请实施例涉及人工智能,尤其涉及一种分类模型训练方法、分类方法及装置和存储介质。
背景技术:
1、在图像分类任务或视频分类任务中,常会出现极端不平衡的现象,少数类别含有大量的样本,大多数类别仅有少量样本,各类别的样本分布遵循长尾分布。以图像分类任务为例,可能存在大量未标记的图像数据,例如来自社交媒体或互联网的图片。同时,经过标记的图像数据可能呈现出长尾分布,即其中一些类别的图像数量远远大于其他类别。长尾分布这种极端不平衡会导致分类训练难以得到很好的效果。
2、相关技术中,为缓解训练集中类别不平衡的问题,根据训练集类别频率对开放集样本赋伪标签参与模型训练,可实现训练集类别分布的再平衡,提升模型针对尾部类别的平坦性和泛化性,其中,开放集是指训练集中不包含的类别。
3、然而,通过赋伪标签方式将开放集样本直接参与训练,有导致语义混淆的风险,继而影响对训练集类别表征的学习,导致模型分类性能较低。
技术实现思路
1、本申请实施例提供一种分类模型训练方法、分类方法及装置和存储介质,可以充分利用训练集和开放集处理训练集中类别不平衡问题,并不是直接将开放集的样本参与训练,避免了对训练集类别表征学习的影响,提高了所训练得到的分类模型的性能。
2、第一方面,本申请实施例提供一种分类模型训练方法,包括:
3、获取训练集和开放集,所述训练集包括多个训练样本和每个训练样本的分类标签,所述训练集的类别不包括所述开放集的类别,所述开放集包括多个无分类标签的样本;
4、根据所述训练集进行分类模型的第一阶段训练,得到分类模型的第一模型参数;
5、根据所述第一模型参数,确定所述分类模型的模型参数的高斯分布;
6、根据所述第一模型参数、所述模型参数的高斯分布和所述训练集,进行所述分类模型的第二阶段训练,得到已训练的分类模型,其中,迭代训练过程待使用的所述训练集的类别权重根据所述开放集、所述模型参数的高斯分布和所述训练集的初始类别权重确定。
7、第二方面,本申请实施例提供一种分类方法,包括:
8、获取目标数据;
9、将所述目标数据输入已训练的分类模型,输出所述目标数据的预测分类标签,所述分类模型根据第一方面所述的方法训练得到。
10、第三方面,本申请实施例提供一种分类模型训练装置,包括:
11、获取模块,用于获取训练集和开放集,所述训练集包括多个训练样本和每个训练样本的分类标签,所述训练集的类别不包括所述开放集的类别,所述开放集包括多个无分类标签的样本;
12、第一训练模块,用于根据所述训练集进行分类模型的第一阶段训练,得到分类模型的第一模型参数;
13、第二训练模块,用于根据所述第一模型参数,确定所述分类模型的模型参数的高斯分布,根据所述第一模型参数、所述模型参数的高斯分布和所述训练集,进行所述分类模型的第二阶段训练,得到已训练的分类模型,其中,迭代训练过程待使用的所述训练集的类别权重根据所述开放集、所述模型参数的高斯分布和所述训练集的初始类别权重确定。
14、第四方面,本申请实施例提供一种分类装置,包括:
15、获取模块,用于获取目标数据;
16、处理模块,用于将所述目标数据输入已训练的分类模型,输出所述目标数据的预测分类标签,所述分类模型根据第一方面所述的方法训练得到。
17、第五方面,本申请实施例提供一种计算机设备,包括:处理器和存储器,该存储器用于存储计算机程序,该处理器用于调用并运行该存储器中存储的计算机程序,以执行第一方面或第二方面的方法。
18、第六方面,本申请实施例提供一种计算机可读存储介质,包括指令,当其在计算机程序上运行时,使得所述计算机执行如第一方面或第二方面的方法。
19、第七方面,本申请实施例提供一种包含指令的计算机程序产品,当所述指令在计算机上运行时,使得所述计算机执行如第一方面或第二方面的方法。
20、综上,在本申请实施例中,通过获取到训练集和开放集后,采用两阶段的模型训练,第一阶段使用训练集训练初始分类模型,得到分类模型的第一模型参数,使得模型可较好拟合训练集分布;接着以第一模型参数为第二阶段训练的初始模型参数,根据第一模型参数确定分类模型的模型参数的高斯分布,根据第一模型参数、模型参数的高斯分布和训练集进行分类模型的第二阶段训练,得到已训练的分类模型,第二阶段训练过程中,迭代训练过程待使用的训练集的类别权重根据开放集、模型参数的高斯分布和训练集的初始类别权重确定,从而,在第二阶段训练时,利用了开放集的样本动态评估得到训练集的类别权重以进一步加权训练分类模型,从而可利用开放集样本平衡长尾分布,充分利用训练集和开放集处理训练集中类别不平衡问题,并不是直接将开放集的样本参与训练,因此避免了语义混淆问题,避免了对训练集类别表征学习的影响,提高了所训练得到的分类模型的性能。
1.一种分类模型训练方法,其特征在于,包括:
2.根据权利要求1所述的方法,其特征在于,根据所述训练集进行分类模型的第一阶段训练,得到分类模型的第一模型参数,包括:
3.根据权利要求1或2所述的方法,其特征在于,根据所述第一模型参数,确定所述分类模型的模型参数的高斯分布,包括:
4.根据权利要求1或2所述的方法,其特征在于,根据所述第一模型参数、所述模型参数的高斯分布和所述训练集,进行所述分类模型的第二阶段训练,包括:
5.根据权利要求4所述的方法,其特征在于,所述获取所述迭代训练过程待使用的所述训练集的类别权重,包括:
6.根据权利要求5所述的方法,其特征在于,所述根据所述开放集、所述模型参数的高斯分布和所述训练集的初始类别权重,确定所述迭代训练过程待使用的所述训练集的类别权重,包括:
7.根据权利要求4所述的方法,其特征在于,所述根据所述训练样本的分类标签、所述训练样本的预测分类标签和所述训练集的类别权重,构建第二损失函数,包括:
8.根据权利要求1所述的方法,其特征在于,所述分类模型包括网络模型和分类子模型,所述第一模型参数包括所述网络模型的模型参数和所述分类子模型的模型参数,所述根据所述第一模型参数,确定所述分类模型的模型参数的高斯分布,包括:
9.根据权利要求8所述的方法,其特征在于,所述根据所述第一模型参数、所述模型参数的高斯分布和所述训练集,进行所述分类模型的第二阶段训练,包括:
10.一种分类方法,其特征在于,包括:
11.一种分类模型训练装置,其特征在于,包括:
12.一种分类装置,其特征在于,包括:
13.一种计算机设备,其特征在于,包括:
14.一种计算机可读存储介质,其特征在于,包括指令,当其在计算机程序上运行时,使得所述计算机执行如权利要求1至9或10中任一项所述的方法。
15.一种包含指令的计算机程序产品,其特征在于,当所述指令在计算机上运行时,使得所述计算机执行权利要求1至9或10中任一项所述方法。