训练数据的增强方法、装置、计算机设备及存储介质与流程

文档序号:30754551发布日期:2022-07-13 10:16阅读:71来源:国知局
训练数据的增强方法、装置、计算机设备及存储介质与流程

1.本技术涉及模型训练领域,特别是涉及到一种训练数据的增强方法、装置、计算机设备及存储介质。


背景技术:

2.基于卷积神经网络的分类方法目前在多个任务上都表现出了良好的性能,但其泛化能力依然受限。在不同的工作中,模型的表现可能会有很大的差异,这影响了使用者对模型的信任度,而目前,对于解决模型识别结果出现错误的情况,大多采用增加训练数据量进行解决,无法从根本上对模型进行改进,导致模型的鲁棒性较低。


技术实现要素:

3.本技术的主要目的为提供一种训练数据的增强方法、装置、计算机设备及存储介质,旨在解决训练数据中存在明显影响模型结果错误数据导致模型的鲁棒性较低的问题。
4.为了实现上述发明目的,本技术提出一种训练数据的增强方法,所述方法包括:
5.获取训练数据集;
6.将所述训练数据集输入至分类模型进行预训练,获取分类错误的原始图像及分类数据;
7.将所述原始图像输入预设的残差网络,获取所述原始图像的特征;
8.根据所述分类数据计算所述特征的互信息,并根据所述互信息筛选所述特征中影响度大于预设值的虚假特征;
9.根据所述虚假特征对所述原始图像进行数据增强,得到目标图像;
10.根据所述目标图像生成增强训练数据集,以基于所述增强训练数据集对所述分类模型进行重新训练。
11.进一步地,所述根据所述互信息筛选所述特征中影响度大于预设值的虚假特征之后,还包括:
12.将所述虚假特征输入至决策树进行训练;
13.获取所述决策树中错误率最高的叶子节点,根据所述叶子节点确定目标虚假特征。
14.进一步地,所述根据所述虚假特征对所述原始图像进行数据增强,得到目标图像,包括:
15.将所述目标虚假特征进行归一化,并将归一化后的目标虚假特征放大至与所述原始图像相同尺寸,得到热力图;
16.将所述热力图与所述原始图像进行叠加融合以对所述原始图像进行数据增强,得到目标图像。
17.进一步地,所述将所述热力图与所述原始图像进行叠加融合以对所述原始图像进行数据增强,得到目标图像,包括:
18.根据所述热力图对所述原始图像中所述虚假特征对应区域进行掩蔽处理以对所述原始图像进行数据增强,得到目标图像。
19.进一步地,所述根据所述分类数据计算所述特征的互信息,包括:
20.获取所述分类数据中的错误标签数据;
21.计算每个所述特征与所述错误标签数据之间的互信息,得到每个所述特征的互信息。
22.进一步地,所述将所述原始图像输入预设的残差网络,获取所述原始图像的特征,包括:
23.将所述原始图像输入预设的残差网络,所述残差网络包括resnet50网络;
24.基于所述resnet50网络的最后一个卷积层提取所述原始图像的高维特征。
25.进一步地,所述根据所述目标图像生成增强训练数据集,以基于所述增强训练数据集对所述分类模型进行重新训练之后,还包括:
26.获取所述增强训练数据集的正确率;
27.当所述增强训练数据集的正确率大于预设值时,输出所述分类模型。
28.本技术还提供一种训练数据的增强装置,所述装置包括:
29.数据获取模块,用于获取训练数据集;
30.预训练模块,用于将所述训练数据集输入至分类模型进行预训练,获取分类错误的原始图像及分类数据;
31.特征提取模块,用于将所述原始图像输入预设的残差网络,获取所述原始图像的特征;
32.虚假特征模块,用于根据所述分类数据计算所述特征的互信息,并根据所述互信息筛选所述特征中影响度大于预设值的虚假特征;
33.数据增强模块,用于根据所述虚假特征对所述原始图像进行数据增强,得到目标图像;
34.重新训练模块,用于根据所述目标图像生成增强训练数据集,以基于所述增强训练数据集对所述分类模型进行重新训练。
35.本技术还提供一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现上述任一项所述训练数据的增强方法。
36.本技术还提供一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现上述任一项所述训练数据的增强方法。
37.本技术例提供了一种基于错误归因对分类模型的训练数据进行增强方法的方法,首先获取训练数据集,然后将未处理的训练数据集输入至分类模型进行预训练,通过预训练修正分类模型中的各个参数,使得分类模型能够以最大准确率识别出所述训练数据集中的图像对应的分类,而在预训练的过程中,分类模型对于训练数据集中的某些图像的分类产生错误,获取分类错误的原始图像以及获取所述原始图像的分类数据,将所述原始图像输入预设的残差网络,残差网络能够有效地识别不同的图像,并且提取图像中的能够为图像分类提供更多信息的特征,从而获取所述原始图像的特征,然后通过互信息表征特征对于原始图像分类错误的影响程度,根据所述分类数据计算所述特征的互信息,并根据所述互信息筛选所述特征中影响度大于预设值的虚假特征,根据所述虚假特征对所述原始图像
进行数据增强,得到目标图像,即根据所述目标图像以及该原始图像的正确分类生成增强训练数据集,然后将所述增强训练数据集再重新输入至分类模型中进行训练,以基于所述增强训练数据集对所述分类模型进行重新训练,分类模型根据增强后的增强训练数据集能够更加准确地调整分类模型的参数,从而准确地对图像进行分类,通过对分类模型出现识别错误的结果进行溯源,寻找与预测错误的结果具有相关性的虚假特征,并基于所述虚假特征对原始图像进行增强,从而能够得到特征混淆数据集即增强训练数据集,再基于增强训练数据集进行分类模型的训练,提高分类模型的训练数据的多样性,减少因环境或姿势改变而带来的预测结果的变化,从而提高模型的鲁棒性。
附图说明
38.图1为本技术训练数据的增强方法的一实施例流程示意图;
39.图2为本技术确定目标虚假特征的一实施例流程示意图;
40.图3为本技术对原始图像进行数据增强的一实施例流程示意图;
41.图4为本技术将所述热力图与所述原始图像进行叠加融合以对所述原始图像进行数据增强,得到目标图像的一实施例流程示意图;
42.图5为本技术计算特征的互信息的一实施例流程示意图;
43.图6为本技术获取所述原始图像的特征的一实施例流程示意图;
44.图7为本技术根据所述目标图像生成增强训练数据集,以基于所述增强训练数据集对所述分类模型进行重新训练之后的一实施例流程示意图;
45.图8为本技术训练数据的增强装置的一实施例结构示意图;
46.图9为本技术计算机设备的一实施例结构示意框图。
47.本技术目的的实现、功能特点及优点将结合实施例,参照附图做进一步说明。
具体实施方式
48.为了使本技术的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本技术进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本技术,并不用于限定本技术。
49.参照图1,本技术实施例提供一种训练数据的增强方法,所述训练数据的增强方法包括步骤s101-s106,对于所述训练数据的增强方法的各个步骤的详细阐述如下。
50.s101、获取训练数据集。
51.本实施例应用于模型训练的数据增强场景中,在该场景中,基于卷积神经网络的分类模型目前在多个任务被应用,但该分类模型的泛化能力依然受限,在不同的工作中,分类模型的表现可能会有很大的差异,分类模型的准确率差异较大,影响了使用者对模型的信任度。本实施例中,在对分类模型进行训练的过程,通过对训练集中的错误率比较高的集群进行数据增强,再基于数据增强后的训练集再次训练,具体的,首先获取训练数据集,在一种实施方式中,所述训练集中包含了相同分类的若干张图像,例如分类为“黑背信天翁”的100张图像,所述训练集中也包含了不同分类的图像,例如不仅包含分类为“黑背信天翁”的100张图像,还包含分类为“羊驼”的100张图像。
52.s102、将所述训练数据集输入至分类模型进行预训练,获取分类错误的原始图像
及分类数据。
53.本实施例中,在获取训练数据集之后,首先需要对分类模型进行训练,将未处理的训练数据集输入至分类模型进行训练的过程定义为预训练,即将所述训练数据集输入至分类模型进行预训练,通过预训练修正分类模型中的各个参数,使得分类模型能够以最大准确率识别出所述训练数据集中的图像对应的分类,而在预训练的过程中,分类模型对于训练数据集中的某些图像无法正确地进行识别与分类,即分类模型对于训练数据集中的某些图像的分类产生错误,此时,获取分类模型分类错误的图像,将分类错误的图像定义为原始图像,以及获取所述原始图像的分类数据,包括获取所述原始图像的正确的分类数据以及模型识别得到的错误的分类数据。
54.s103、将所述原始图像输入预设的残差网络,获取所述原始图像的特征。
55.本实施例中,在将所述训练数据集输入至分类模型进行预训练,获取分类错误的原始图像及分类数据之后,将所述原始图像输入预设的残差网络,所述预设的残差网络通过预先训练,能够识别图像中能对图像分类提供信息的特征,从而基于所述残差网络识别并获取所述原始图像的特征,在一种实施方式中,将所述残差网络在cub2011数据集上进行训练,训练过程使用交叉熵损失函数度量两个概率分布间的差异性信息,并采用adam优化器优化所述残差网络中的初始参数,并且设置所述残差网络的学习率根据结果准确率自适应调整,其中,初始学习率为0.1,训练好的残差网络能够有效地识别不同的图像,并且提取图像中的能够为图像分类提供更多信息的特征。
56.s104、根据所述分类数据计算所述特征的互信息,并根据所述互信息筛选所述特征中影响度大于预设值的虚假特征。
57.本实施例中,在将所述原始图像输入预设的残差网络,获取所述原始图像的特征之后,当提取了分类错误的原始图像包含的特征后,需要确定原始图像中的特征对分类错误的影响程度,具体的,通过互信息表征特征对于原始图像分类错误的影响程度,互信息(mutual information)是用来评价一个事件的出现对于另一个事件的出现所贡献的信息量,即根据所述分类数据计算所述特征的互信息,并根据所述互信息筛选所述特征中影响度大于预设值的虚假特征,在一种实施方式中,从所述特征中筛选出影响度排名前10个的特征作为虚假特征,所述虚假特征是分类模型对原始图像的识别分类出现错误的重要特征。
58.s105、根据所述虚假特征对所述原始图像进行数据增强,得到目标图像。
59.本实施例中,在根据所述分类数据计算所述特征的互信息,并根据所述互信息筛选所述特征中影响度大于预设值的虚假特征之后,根据所述虚假特征对所述原始图像进行数据增强,得到目标图像,在一种实施方式中,所述虚假特征对所述原始图像对应区域进行数据增强,所述数据增强包括所述原始图像对应区域的对比度,饱和度,替换所述原始图像对应区域为其他图像等,从而得到目标图像。
60.s106、根据所述目标图像生成增强训练数据集,以基于所述增强训练数据集对所述分类模型进行重新训练。
61.本实施例中,在根据所述虚假特征对所述原始图像进行数据增强,得到目标图像之后,根据所述目标图像生成增强训练数据集,即根据所述目标图像以及该原始图像的正确分类生成增强训练数据集,然后将所述增强训练数据集再重新输入至分类模型中进行训
练,以基于所述增强训练数据集对所述分类模型进行重新训练,由于对原始图像中的虚假特征进行增强,分类模型根据增强后的增强训练数据集能够更加准确地调整分类模型的参数,从而准确地对图像进行分类,通过对分类模型出现识别错误的结果进行溯源,寻找与预测错误的结果具有相关性的虚假特征,并基于所述虚假特征对原始图像进行增强,从而能够得到特征混淆数据集即增强训练数据集,再基于增强训练数据集进行分类模型的训练,提高分类模型的训练数据的多样性,减少因环境或姿势改变而带来的预测结果的变化,从而提高模型的鲁棒性。
62.本实施例提供了一种基于错误归因对分类模型的训练数据进行增强方法的方法,首先获取训练数据集,然后将未处理的训练数据集输入至分类模型进行预训练,通过预训练修正分类模型中的各个参数,使得分类模型能够以最大准确率识别出所述训练数据集中的图像对应的分类,而在预训练的过程中,分类模型对于训练数据集中的某些图像的分类产生错误,获取分类错误的原始图像以及获取所述原始图像的分类数据,将所述原始图像输入预设的残差网络,残差网络能够有效地识别不同的图像,并且提取图像中的能够为图像分类提供更多信息的特征,从而获取所述原始图像的特征,然后通过互信息表征特征对于原始图像分类错误的影响程度,根据所述分类数据计算所述特征的互信息,并根据所述互信息筛选所述特征中影响度大于预设值的虚假特征,根据所述虚假特征对所述原始图像进行数据增强,得到目标图像,即根据所述目标图像以及该原始图像的正确分类生成增强训练数据集,然后将所述增强训练数据集再重新输入至分类模型中进行训练,以基于所述增强训练数据集对所述分类模型进行重新训练,分类模型根据增强后的增强训练数据集能够更加准确地调整分类模型的参数,从而准确地对图像进行分类,通过对分类模型出现识别错误的结果进行溯源,寻找与预测错误的结果具有相关性的虚假特征,并基于所述虚假特征对原始图像进行增强,从而能够得到特征混淆数据集即增强训练数据集,再基于增强训练数据集进行分类模型的训练,提高分类模型的训练数据的多样性,减少因环境或姿势改变而带来的预测结果的变化,从而提高模型的鲁棒性。
63.在一个实施例中,如图2所示,所述根据所述互信息筛选所述特征中影响度大于预设值的虚假特征之后,还包括步骤s201-202:
64.s201,将所述虚假特征输入至决策树进行训练;
65.s202,获取所述决策树中错误率最高的叶子节点,根据所述叶子节点确定目标虚假特征。
66.本实施例中,在根据所述互信息筛选所述特征中影响度大于预设值的虚假特征之后,为了筛选更加准确地影响分类模型预测错误的特征,将所述虚假特征输入至决策树进行训练,然后获取所述决策树中错误率最高的叶子节点,即将每一个所述虚假特征输入至决策树中,基于决策树对数据和特征进行划分,即对分类数据与虚假特征进行划分,将分类数据作为根节点,将每一个虚假特征作为叶子节点进行训练,再获取每一个叶子节点的错误率,然后选取错误率最高的叶子节点,根据所述叶子节点确定目标虚假特征,从而筛选出对分类模型的错误预测结果影响最大的虚假特征,后续根据该虚假特征对原始图像进行数据增强,通过对分类模型出现识别错误的结果进行溯源,寻找与预测错误的结果具有相关性的虚假特征,并基于所述虚假特征对原始图像进行增强,从而提高训练数据的多样性。
67.在一个实施例中,如图3所示,所述根据所述虚假特征对所述原始图像进行数据增
强,得到目标图像,包括步骤s301-s302:
68.s301,将所述目标虚假特征进行归一化,并将归一化后的目标虚假特征放大至与所述原始图像相同尺寸,得到热力图;
69.s302,将所述热力图与所述原始图像进行叠加融合以对所述原始图像进行数据增强,得到目标图像。
70.本实施例中,在根据所述虚假特征对所述原始图像进行数据增强,得到目标图像的过程中,将所述目标虚假特征进行归一化,在一种实施方式中,将所述目标虚假特征归一化到[0,1]区间,然后将归一化后的目标虚假特征进行缩放,具体的,将所述归一化后的目标虚假特征放大至与所述原始图像相同尺寸,得到热力图,从而能够准确地确定所述目标虚假特征在所述原始图像中的位置,再将所述热力图与所述原始图像进行叠加融合,从而对所述原始图像进行数据增强,得到目标图像,从而提高训练数据的多样性。
[0071]
在一个实施例中,如图4所示,所述将所述热力图与所述原始图像进行叠加融合以对所述原始图像进行数据增强,得到目标图像,还包括步骤s401:
[0072]
s401,根据所述热力图对所述原始图像中所述虚假特征对应区域进行掩蔽处理以对所述原始图像进行数据增强,得到目标图像。
[0073]
本实施例中,在根将所述热力图与所述原始图像进行叠加融合以对所述原始图像进行数据增强,得到目标图像的过程中,根据所述热力图对所述原始图像中所述虚假特征对应区域进行掩蔽处理以对所述原始图像进行数据增强,得到目标图像,在一种实施方式中,根据所述热力图对所述原始图像中所述虚假特征对应区域进行高斯模糊处理,从而对所述原始图像进行数据增强,掩蔽后的虚假特征能够降低对分类模型的分类结果的影响程度,从而提高训练数据的多样性。
[0074]
在一个实施例中,如图5所示,所述根据所述分类数据计算所述特征的互信息,还包括步骤s501-s502:
[0075]
s501,获取所述分类数据中的错误标签数据;
[0076]
s502,计算每个所述特征与所述错误标签数据之间的互信息,得到每个所述特征的互信息。
[0077]
本实施例中,在根据所述分类数据计算所述特征的互信息的过程中,获取所述分类数据中的错误标签数据,若干张属于同一个分类结果的原始图像,由于分类模型识别过程中出现错误,导致该些原始图像得到的错误分类结果多种多样,将每一样的错误分类结果转化为错误标签数据,然后计算每个所述特征与所述错误标签数据之间的互信息,得到每个所述特征的互信息,基于所述互信息对所述特征进行筛选,能够准确、高效地确定与预测错误的结果具有相关性的虚假特征,并基于所述虚假特征对原始图像进行增强,从而提高训练数据的多样性。
[0078]
在一个实施例中,如图6所示,所述将所述原始图像输入预设的残差网络,获取所述原始图像的特征,还包括步骤s601-s602:
[0079]
s601,将所述原始图像输入预设的残差网络,所述残差网络包括resnet50网络;
[0080]
s602,基于所述resnet50网络的最后一个卷积层提取所述原始图像的高维特征。
[0081]
本实施例中,在将所述原始图像输入预设的残差网络,获取所述原始图像的特征的过程中,将所述原始图像输入预设的残差网络,所述残差网络包括resnet50网络,然后基
于所述resnet50网络的最后一个卷积层提取所述原始图像的高维特征,通过resnet50网络的多层卷积层对所述原始图像进行特征提取,并且仅仅保留所述resnet50网络的最后一个卷积层提取所述原始图像的高维特征,能够有效地提取所述原始图像中为分类结果提供信息的特征,从而减少提取的特征的数量,减少计算量,提高特征提取的准确性。
[0082]
在一个实施例中,如图7所示,所述根据所述目标图像生成增强训练数据集,以基于所述增强训练数据集对所述分类模型进行重新训练之后,还包括步骤s701-s702:
[0083]
s701,获取所述增强训练数据集的正确率;
[0084]
s702,当所述增强训练数据集的正确率大于预设值时,输出所述分类模型。
[0085]
本实施例中,在根据所述目标图像生成增强训练数据集,以基于所述增强训练数据集对所述分类模型进行重新训练之后,获取所述增强训练数据集的正确率,当所述增强训练数据集的正确率大于预设值时,输出所述分类模型,而当所述增强训练数据集的正确率低于预设值时,可以再次筛选其他虚假特征后重新生成目标图像与增强训练数据集,减少因环境或姿势改变而带来的预测结果的变化,从而提高模型的鲁棒性。
[0086]
参照图8,本技术还提供一种训练数据的增强装置,包括:
[0087]
数据获取模块101,用于获取训练数据集;
[0088]
预训练模块102,用于将所述训练数据集输入至分类模型进行预训练,获取分类错误的原始图像及分类数据;
[0089]
特征提取模块103,用于将所述原始图像输入预设的残差网络,获取所述原始图像的特征;
[0090]
虚假特征模块104,用于根据所述分类数据计算所述特征的互信息,并根据所述互信息筛选所述特征中影响度大于预设值的虚假特征;
[0091]
数据增强模块105,用于根据所述虚假特征对所述原始图像进行数据增强,得到目标图像;
[0092]
重新训练模块106,用于根据所述目标图像生成增强训练数据集,以基于所述增强训练数据集对所述分类模型进行重新训练。
[0093]
如上所述,可以理解地,本技术中提出的所述训练数据的增强装置的各组成部分可以实现如上所述训练数据的增强方法任一项的功能。
[0094]
在一个实施例中,所述根据所述互信息筛选所述特征中影响度大于预设值的虚假特征之后,还包括:
[0095]
将所述虚假特征输入至决策树进行训练;
[0096]
获取所述决策树中错误率最高的叶子节点,根据所述叶子节点确定目标虚假特征。
[0097]
在一个实施例中,所述根据所述虚假特征对所述原始图像进行数据增强,得到目标图像,包括:
[0098]
将所述目标虚假特征进行归一化,并将归一化后的目标虚假特征放大至与所述原始图像相同尺寸,得到热力图;
[0099]
将所述热力图与所述原始图像进行叠加融合以对所述原始图像进行数据增强,得到目标图像。
[0100]
在一个实施例中,所述将所述热力图与所述原始图像进行叠加融合以对所述原始
图像进行数据增强,得到目标图像,包括:
[0101]
根据所述热力图对所述原始图像中所述虚假特征对应区域进行掩蔽处理以对所述原始图像进行数据增强,得到目标图像。
[0102]
在一个实施例中,所述根据所述分类数据计算所述特征的互信息,包括:
[0103]
获取所述分类数据中的错误标签数据;
[0104]
计算每个所述特征与所述错误标签数据之间的互信息,得到每个所述特征的互信息。
[0105]
在一个实施例中,所述将所述原始图像输入预设的残差网络,获取所述原始图像的特征,包括:
[0106]
将所述原始图像输入预设的残差网络,所述残差网络包括resnet50网络;
[0107]
基于所述resnet50网络的最后一个卷积层提取所述原始图像的高维特征。
[0108]
在一个实施例中,所述根据所述目标图像生成增强训练数据集,以基于所述增强训练数据集对所述分类模型进行重新训练之后,还包括:
[0109]
获取所述增强训练数据集的正确率;
[0110]
当所述增强训练数据集的正确率大于预设值时,输出所述分类模型。
[0111]
参照图9,本技术实施例中还提供一种计算机设备,该计算机设备可以是移动终端,其内部结构可以如图9所示。该计算机设备包括通过系统总线连接的处理器、存储器、网络接口和显示装置及输入装置。其中,该计算机设备的网络接口用于与外部的终端通过网络连接通信。该计算机设备的显示装置用于显示离线应用。该计算机设备的输入装置用于接收用户在离线应用的输入。该计算机设计的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质。该非易失性存储介质存储有操作系统、计算机程序和数据库。该计算机设备的数据库用于存放原始数据。该计算机程序被处理器执行时以实现一种训练数据的增强方法。
[0112]
上述处理器执行上述的训练数据的增强方法,所述方法包括:获取训练数据集;将所述训练数据集输入至分类模型进行预训练,获取分类错误的原始图像及分类数据;将所述原始图像输入预设的残差网络,获取所述原始图像的特征;根据所述分类数据计算所述特征的互信息,并根据所述互信息筛选所述特征中影响度大于预设值的虚假特征;根据所述虚假特征对所述原始图像进行数据增强,得到目标图像;根据所述目标图像生成增强训练数据集,以基于所述增强训练数据集对所述分类模型进行重新训练。
[0113]
所述计算机设备提供了一种基于错误归因对分类模型的训练数据进行增强方法的方法,首先获取训练数据集,然后将未处理的训练数据集输入至分类模型进行预训练,通过预训练修正分类模型中的各个参数,使得分类模型能够以最大准确率识别出所述训练数据集中的图像对应的分类,而在预训练的过程中,分类模型对于训练数据集中的某些图像的分类产生错误,获取分类错误的原始图像以及获取所述原始图像的分类数据,将所述原始图像输入预设的残差网络,残差网络能够有效地识别不同的图像,并且提取图像中的能够为图像分类提供更多信息的特征,从而获取所述原始图像的特征,然后通过互信息表征特征对于原始图像分类错误的影响程度,根据所述分类数据计算所述特征的互信息,并根据所述互信息筛选所述特征中影响度大于预设值的虚假特征,根据所述虚假特征对所述原始图像进行数据增强,得到目标图像,即根据所述目标图像以及该原始图像的正确分类生
成增强训练数据集,然后将所述增强训练数据集再重新输入至分类模型中进行训练,以基于所述增强训练数据集对所述分类模型进行重新训练,分类模型根据增强后的增强训练数据集能够更加准确地调整分类模型的参数,从而准确地对图像进行分类,通过对分类模型出现识别错误的结果进行溯源,寻找与预测错误的结果具有相关性的虚假特征,并基于所述虚假特征对原始图像进行增强,从而能够得到特征混淆数据集即增强训练数据集,再基于增强训练数据集进行分类模型的训练,提高分类模型的训练数据的多样性,减少因环境或姿势改变而带来的预测结果的变化,从而提高模型的鲁棒性。
[0114]
本技术一实施例还提供一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被所述处理器执行时实现一种训练数据的增强方法,包括步骤:获取训练数据集;将所述训练数据集输入至分类模型进行预训练,获取分类错误的原始图像及分类数据;将所述原始图像输入预设的残差网络,获取所述原始图像的特征;根据所述分类数据计算所述特征的互信息,并根据所述互信息筛选所述特征中影响度大于预设值的虚假特征;根据所述虚假特征对所述原始图像进行数据增强,得到目标图像;根据所述目标图像生成增强训练数据集,以基于所述增强训练数据集对所述分类模型进行重新训练。
[0115]
所述计算机可读存储介质提供了一种基于错误归因对分类模型的训练数据进行增强方法的方法,首先获取训练数据集,然后将未处理的训练数据集输入至分类模型进行预训练,通过预训练修正分类模型中的各个参数,使得分类模型能够以最大准确率识别出所述训练数据集中的图像对应的分类,而在预训练的过程中,分类模型对于训练数据集中的某些图像的分类产生错误,获取分类错误的原始图像以及获取所述原始图像的分类数据,将所述原始图像输入预设的残差网络,残差网络能够有效地识别不同的图像,并且提取图像中的能够为图像分类提供更多信息的特征,从而获取所述原始图像的特征,然后通过互信息表征特征对于原始图像分类错误的影响程度,根据所述分类数据计算所述特征的互信息,并根据所述互信息筛选所述特征中影响度大于预设值的虚假特征,根据所述虚假特征对所述原始图像进行数据增强,得到目标图像,即根据所述目标图像以及该原始图像的正确分类生成增强训练数据集,然后将所述增强训练数据集再重新输入至分类模型中进行训练,以基于所述增强训练数据集对所述分类模型进行重新训练,分类模型根据增强后的增强训练数据集能够更加准确地调整分类模型的参数,从而准确地对图像进行分类,通过对分类模型出现识别错误的结果进行溯源,寻找与预测错误的结果具有相关性的虚假特征,并基于所述虚假特征对原始图像进行增强,从而能够得到特征混淆数据集即增强训练数据集,再基于增强训练数据集进行分类模型的训练,提高分类模型的训练数据的多样性,减少因环境或姿势改变而带来的预测结果的变化,从而提高模型的鲁棒性。
[0116]
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本技术所提供的和实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和/或易失性存储器。非易失性存储器可以包括只读存储器(rom)、可编程rom(prom)、电可编程rom(eprom)、电可擦除可编程rom(eeprom)或闪存。易失性存储器可包括随机存取存储器(ram)或者外部高速缓冲存储器。作为说明而非局限,ram以多种形式可得,诸如静态ram(sram)、动态ram(dram)、同步dram(sdram)、双速据率sdram(ssrsdram)、增强
型sdram(esdram)、同步链路(synchlink)dram(sldram)、存储器总线(rambus)直接ram(rdram)、直接存储器总线动态ram(drdram)、以及存储器总线动态ram(rdram)等。
[0117]
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、装置、物品或者方法不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、装置、物品或者方法所固有的要素。在没有更多限制的情况下,由语句“包括一个
……”
限定的要素,并不排除在包括该要素的过程、装置、物品或者方法中还存在另外的相同要素。
[0118]
以上所述仅为本技术的优选实施例,并非因此限制本技术的专利范围,凡是利用本技术说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本技术的专利保护范围内。
当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1