本发明涉及图像分类的类别不均衡学习,特别是涉及一种不均衡目标分类方法、产品、介质及设备。
背景技术:
1、在海上目标分类问题中,采集的数据往往是比例失衡的,类别的不平衡会影响到分类模型的训练。在现实类别不均衡数据的模型鲁棒学习中引入虚拟类别均衡数据可以增加各类目标的特征多样性,但由于现实数据与虚拟数据之间存在域偏移问题(将数据划分为现实域和虚拟域,域之间存在域偏移问题),会导致模型学习产生混淆。
2、随着生成对抗网络的流行,部分研究基于生成对抗思想实现从源域到目标域的数据生成以解决域偏移问题。其中,基于现实域与虚拟域数据混合的域自适应学习方法mdlt在研究中重点讨论数据不均衡分布对模型域自适应学习过程所带来的影响,在隐层特征角度对不同域同类特征间的信息进行了有效共享,而这也为基于隐层特征进行数据增强的方法提供了便利。
3、mdlt针对多域长尾分布的目标分类问题(可理解为现实数据与虚拟数据之间存在的域偏移问题),旨在从来自多域的不平衡数据中学习,解决每个域内的标签不平衡、不同域之间的不同标签分布,并且最终模型能够泛化到所有域的所有类别上。具体来说,首先提出了领域类别转移图,用来刻画不同<领域,类别>对之间的可转移性(相似程度),基于这种定义的可转移性,直接决定了模型在mdlt任务上的表现。然后提出了boda用来提升模型在mdlt问题上的性能。
4、虽然在面对多域长尾分布的目标识别任务时,mdlt为基于隐层特征进行数据增强的方法提供了便利,但是其未对不同域中类别不同分布的情况进行深入优化,仅仅通过简单的损失加权缓解模型的偏置。
5、综上,现有方法比如mdlt方法主要是从损失函数的角度出发进行优化。对于仅使用简单加权平衡的mdlt方法而言,利用其进行现实域海上目标分类无法大幅提升现实域海上目标分类准确率。
技术实现思路
1、本发明的目的是提供一种不均衡目标分类方法、产品、介质及设备,能够大幅提升现实域海上目标分类准确率。
2、为实现上述目的,本发明提供了如下方案。
3、一方面,本发明提供一种不均衡目标分类方法,包括:
4、步骤s1:获取虚拟域训练数据集以及与所述虚拟域训练数据集对应的现实域训练数据集;所述虚拟域训练数据集和所述现实域训练数据集均包括不同类别的海上目标图像样本集;每一类别的海上目标图像样本集均包括多个海上目标图像样本;
5、步骤s2:获取现实域的批次训练样本和虚拟域的批次训练样本;所述现实域的批次训练样本是从所述现实域训练数据集中采样一个批次的海上目标图像样本得到的;所述虚拟域的批次训练样本是从所述虚拟域训练数据集中采样一个批次的海上目标图像样本得到的;
6、步骤s3:利用所述现实域的批次训练样本和所述虚拟域的批次训练样本,经过分类模型的特征提取主干网络部分,分别得到现实域隐层特征与虚拟域隐层特征;
7、步骤s4:基于所述现实域隐层特征计算当前训练时刻现实域各类隐层特征均值和协方差,并更新现实域各类隐层特征均值和协方差;基于所述虚拟域隐层特征计算当前训练时刻虚拟域各类隐层特征均值,并更新虚拟域各类隐层特征均值;
8、步骤s5:利用所述现实域训练数据集中的所有海上目标图像样本计算现实域的类混淆矩阵,基于所述类混淆矩阵计算现实域各类隐层特征修正均值和修正协方差;现实域各类隐层特征修正均值和修正协方差分别用于修正现实域各类隐层特征均值和协方差;
9、步骤s6:基于现实域各类隐层特征修正均值和修正协方差,对所述现实域的批次训练样本中来自尾部类样本的特征进行隐层特征增强,生成增广特征,利用所述增广特征和所述现实域隐层特征得到均衡后的现实域隐层特征;
10、步骤s7:基于均衡后的现实域隐层特征更新现实域各类隐层特征均值,得到更新后的现实域各类隐层特征均值;
11、步骤s8:基于所述虚拟域隐层特征、均衡后的现实域隐层特征、虚拟域各类隐层特征均值和更新后的现实域各类隐层特征均值计算域自适应训练损失,在所述域自适应训练损失的基础上添加监督学习损失,得到训练总损失,利用所述训练总损失对训练过程中的分类模型进行优化,得到优化后的分类模型;
12、步骤s9:判断所述训练总损失是否降低到设定阈值;
13、若否,则返回步骤s2;
14、若是,则执行步骤s10:将待分类海上目标图像输入所述优化后的分类模型,利用所述优化后的分类模型输出所述待分类海上目标图像对应的类别;所述待分类海上目标图像为现实场景中获取的真实的海上目标图像。
15、可选地,所述虚拟域训练数据集为类别均衡分布的虚拟域海上目标分类训练数据集;所述现实域训练数据集为类别不均衡分布的现实域海上目标分类训练数据集;
16、类别均衡分布的虚拟域海上目标分类训练数据集中海上目标图像样本集的类别总数和类别种类与类别不均衡分布的现实域海上目标分类训练数据集中海上目标图像样本集的类别总数和类别种类相同;
17、类别均衡分布的虚拟域海上目标分类训练数据集中不同类别的海上目标图像样本集中海上目标图像样本的数量是均衡的;类别不均衡分布的现实域海上目标分类训练数据集中不同类别的海上目标图像样本集中海上目标图像样本的数量是不均衡的。
18、可选地,所述现实域的批次训练样本是从所述现实域训练数据集中,使用实例平衡采样分布,采样一个批次的海上目标图像样本得到的;所述虚拟域的批次训练样本是从所述虚拟域训练数据集中,使用实例平衡采样分布,采样一个批次的海上目标图像样本得到的。
19、可选地,所述分类模型为域自适应分类模型。
20、可选地,利用所述增广特征和所述现实域隐层特征得到均衡后的现实域隐层特征,具体包括:
21、获取所述增广特征和所述现实域隐层特征的并集,将所述增广特征和所述现实域隐层特征的并集作为均衡后的现实域隐层特征。
22、可选地,所述基于均衡后的现实域隐层特征更新现实域各类隐层特征均值,得到更新后的现实域各类隐层特征均值,具体包括:
23、基于均衡后的现实域隐层特征计算现实域各类隐层特征均值,并更新现实域各类隐层特征均值,得到更新后的现实域各类隐层特征均值。
24、另一方面,本发明提供一种计算机程序产品,包括计算机程序,该计算机程序被处理器执行时实现所述不均衡目标分类方法。
25、另一方面,本发明还提供一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现所述不均衡目标分类方法。
26、再一方面,本发明提供一种计算机设备,包括:存储器、处理器以及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述计算机程序以实现所述不均衡目标分类方法。
27、根据本发明提供的具体实施例,本发明公开了以下技术效果:
28、本发明公开的不均衡目标分类方法、产品、介质及设备,通过现实域数据与虚拟域数据的混合学习,即通过虚实数据混合训练,解决了现实域数据中类别不均衡导致的分类效果差的问题,提升了现实域中尾部类的分类准确率,从而提升了现实域数据中目标分类性能,同时,本发明在网络中通过隐层特征增强,解决了域偏移问题,并提升了现实域中目标分类精度,最终达到了大幅提升现实域海上目标分类准确率的效果。