本技术涉及人工智能(artificial intelligence,ai),尤其涉及一种模型训练方法以及训练设备。
背景技术:
1、在ai领域中,参数较多且复杂度较高的大型模型对计算资源和存储空间的需求较高。该大型模型模型难以部署在计算资源和存储空间都有限的终端设备。参数相对少且复杂度相对低的小型模型却达不到大前述大型模型的性能。因此,在小型模型的训练过程中,可以通过小型模型学习大型模型的先验知识更新小型模型的参数,达到提升小型模型的性能的目的。
2、在相关技术中,在小型模型的训练过程中,将大型模型对输入数据进行推理得到的输出数据和输入数据的标签数据共同作为小型模型的标签数据。但是,大型模型一般通过输入数据、以及输入数据对应的标签数据预先训练获得,大型模型对输入数据推理得到的输出数据、与输入数据的标签数据高度相似,导致大型模型的先验知识无法有效传递给小型模型,小型模型的性能提升有限。
技术实现思路
1、本技术提供了一种模型训练方法以及训练设备,可以使第一模型充分学习第二模型的先验知识,解决第一模型性能提升有限的问题。
2、第一方面,本技术提供了一种模型训练方法。该方法包括:将第二输入数据输入第一模型,确定第一输出数据;将所述第二输入数据输入第二模型,确定第二输出数据,其中,所述第二模型至少基于第一输入数据训练获得,所述第二输入数据基于所述第一输入数据进行处理得到;根据所述第一输出数据和所述第二输出数据确定所述第一模型的第一蒸馏损失值;根据所述第一蒸馏损失值更新所述第一模型的参数。需要说明的是,本技术对利用第一模型得到第一输出数据和利用第二模型得到第二输出数据的顺序不做具体限制。即可以先利用第一模型得到第一输出数据后利用第二模型得到第二输出数据,也可以先利用第二模型得到第二输出数据后利用第一模型得到第一输出数据。
3、上述方案,在第一模型的训练过程中,对第一输入数据进行域内再生,得到无对应标签数据的第二输入数据。然后,将第二输入数据分别作为第一模型和第二模型的输入,使用第一模型根据对第二输入数据进行推理得到的第一输出数据、和第二模型对第二输入数据进行推理得到的第二输出数据更新第一模型的参数,可以使第一模型充分学习第二模型的先验知识,从而将第二模型的先验知识有效传递给第一模型,充分提高第一模型的性能。其中,第二模型可以包括大型模型,第一模型可以包括该大型模型对应的小型模型,大型模型和小型模型用于实现相同的功能。
4、在第一方面的一种可能的实施方式中,所述第一输入数据包括第一图像,所述第一图像的标签数据包括第二图像,所述第二图像的图像特征与所述第一图像的图像特征相同,所述第一图像的分辨率小于所述第二图像的分辨率,所述第一图像的大小小于所述第二图像的大小,所述方法还包括:对所述第二图像进行裁剪,得到所述第二输入数据。
5、上述方案中,通过裁剪操作,可以从第一图像的标签数据中得到其衍生的第二输入数据,可以避免引入额外的数据而增加训练成本。
6、在第一方面的一种可能的实施方式中,所述第二输入数据包括所述第二图像中的目标图像块,所述目标图像块的大小与所述第一图像的大小相同,所述目标图像块与所述第一图像的损失值等于或大于目标阈值。
7、上述方案中,根据第一图像与第二图像中的各个图像块的损失值以及目标阈值确定第二输入数据,可以实现对第二图像的自适应裁剪。
8、在第一方面的一种可能的实施方式中,所述第一输入数据包括第一图像,所述方法还包括:基于所述第一模型的超分倍数对所述第一图像进行下采样得到所述第二输入数据。
9、上述方案中,通过下采样,可以从第一图像中得到其衍生的第二输入数据,可以避免引入额外的数据而增加训练成本。
10、在第一方面的一种可能的实施方式中,所述第一输入数据包括第一图像,所述第一图像的标签数据包括第二图像,所述第二图像的图像特征与所述第一图像的图像特征相同,所述第一图像的分辨率小于所述第二图像的分辨率,所述第一图像的大小与所述第二图像的大小相同,所述第二输入数据包括所述第二图像。
11、上述方案中,在第二图像(第一图像的标签数据)与第一图像的大小相同的情况下,可以将第二图像作为其衍生的第二输入数据,可以避免引入额外的数据而增加训练成本。
12、在第一方面的一种可能的实施方式中,在将所述第二输入数据输入第一模型之前,所述方法还包括:按照目标操作对所述第二输入数据进行数据增强。在根据所述第一输出数据和所述第二输出数据确定所述第一模型的第一蒸馏损失值之前,所述方法还包括:按照所述目标操作对应的逆操作对所述第一输出数据逆增强。
13、上述方案中,在第一模型对第二输入数据进行推理前,对第二输入数据进行数据增强,以及在第一模型对经数据增强的第二输入数据进行推理后,对第一模型输出的第一输出数据进行数据逆增强,这样可以通过对第一模型实施标签一致性正则约束,从而提高第一模型的鲁棒性和泛化性。
14、在第一方面的一种可能的实施方式中,所述目标操作包括图像反色、水平翻转、垂直翻转、顺时针旋转、和/或逆时针旋转中的一种或多种。
15、在第一方面的一种可能的实施方式中,所述第一输入数据包括带噪音频,所述带噪音频的标签数据包括所述带噪音频对应的原音音频,所述确定第一模型的输入数据对应的第二输入数据包括:从所述原音音频中截取部分音频作为所述第二输入数据;或者,将所述原音音频作为所述第二输入数据。其中,原音音频是指不包含带噪音频中的噪声的音频。
16、在第一方面的一种可能的实施方式中,所述根据所述第一蒸馏损失值更新所述第一模型的参数,包括:根据所述第一蒸馏损失值、所述第一模型的第二蒸馏损失值、和所述第一模型的重建损失值,确定所述第一模型的总损失值,其中,所述第二蒸馏损失值根据所述第一模型的第三输出数据和所述第二模型的第四输出数据确定,所述重建损失值根据所述第三输出数据和所述第一输入数据的标签数据确定,所述第三输出数据通过将所述第一输入数据输入所述第一模型获得,所述第四输出数据通过将所述第一输入数据输入所述第二模型获得;根据所述总损失值更新所述第一模型的参数。
17、在第一方面的一种可能的实施方式中,所述第一模型包括学生模型,所述第二模型包括所述学生模型对应的教师模型。
18、第二方面,本技术提供一种模型训练装置。该装置包括:确定模块和更新模块。
19、其中,确定模块用于将第二输入数据输入第一模型,确定第一输出数据,以及将所述第二输入数据输入第二模型,确定第二输出数据。其中,所述第二模型至少基于第一输入数据训练获得,所述第二输入数据基于所述第一输入数据进行处理得到。
20、其中,更新模块用于根据所述第一输出数据和所述第二输出数据确定所述第一模型的第一蒸馏损失值,以及根据所述第一蒸馏损失值更新所述第一模型的参数。
21、在第二方面的一种可能的实施方式中,所述第一输入数据包括第一图像,所述第一图像的标签数据包括第二图像,所述第二图像的图像特征与所述第一图像的图像特征相同,所述第一图像的分辨率小于所述第二图像的分辨率,所述第一图像的大小小于所述第二图像的大小,所述确定模块还用于:对所述第二图像进行裁剪,得到所述第二输入数据。
22、在第二方面的一种可能的实施方式中,所述第二输入数据包括所述第二图像中的目标图像块,所述目标图像块的大小与所述第一图像的大小相同,所述目标图像块与所述第一图像的损失值等于或大于目标阈值。
23、在第二方面的一种可能的实施方式中,所述第一输入数据包括第一图像,所述确定模块还用于:基于所述第一模型的超分倍数对所述第一图像进行下采样得到所述第二输入数据。
24、在第二方面的一种可能的实施方式中,所述第一输入数据包括第一图像,所述第一图像的标签数据包括第二图像,所述第二图像的图像特征与所述第一图像的图像特征相同,所述第一图像的分辨率小于所述第二图像的分辨率,所述第一图像的大小与所述第二图像的大小相同,所述第二输入数据包括所述第二图像。
25、在第二方面的一种可能的实施方式中,在将所述第二输入数据输入第一模型之前,所述确定模块还用于:按照目标操作对所述第二输入数据进行数据增强;在根据所述第一输出数据和所述第二输出数据确定所述第一模型的第一蒸馏损失值之前,所述确定模块还用于:按照所述目标操作对应的逆操作对所述第一输出数据逆增强。
26、在第二方面的一种可能的实施方式中,所述目标操作包括图像反色、水平翻转、垂直翻转、顺时针旋转、和/或逆时针旋转中的一种或多种。
27、在第二方面的一种可能的实施方式中,所述第一输入数据包括带噪音频,所述第一输入数据的标签数据包括所述带噪音频对应的原音音频,所述确定模块还用于:从所述原音音频中截取部分音频作为所述第二输入数据;或者,将所述原音音频作为所述第二输入数据。
28、在第二方面的一种可能的实施方式中,所述更新模块还用于:根据所述第一蒸馏损失值、所述第一模型的第二蒸馏损失值、和所述第一模型的重建损失值,确定所述第一模型的总损失值,其中,所述第二蒸馏损失值根据所述第一模型的第三输出数据和所述第二模型的第四输出数据确定,所述重建损失值根据所述第三输出数据和所述第一输入数据的标签数据确定,所述第三输出数据通过将所述第一输入数据输入所述第一模型获得,所述第四输出数据通过将所述第一输入数据输入所述第二模型获得:根据所述总损失值更新所述第一模型的参数。
29、在第二方面的一种可能的实施方式中,所述第一模型包括学生模型,所述第二模型包括所述学生模型对应的教师模型。
30、第三方面,本技术提供一种训练设备。训练设备包括:处理器和存储器,所述处理器用于执行存储于所述存储器内的计算机程序以实现上述第一方面或第一方面的任意一种可能的实现方式所提供的模型训练方法。
31、第四方面,本技术提供一种计算机可读存储介质,包括指令,当所述指令在计算机上运行时,使得所述计算机实现上述第一方面或第一方面的任意一种可能的实现方式所提供的模型训练方法。
32、第五方面,本技术提供一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机实现上述第一方面或第一方面的任意一种可能的实现方式所提供的模型训练方法。
33、上述提供的任一种装置或计算机存储介质或计算机程序产品,均用于执行上文所提供的方法,因此,其所能达到的有益效果可参考上文提供的对应方法中的对应方案的有益效果,此处不再赘述。