本技术涉及机器学习,尤其涉及一种基于知识蒸馏的模型训练方法、装置及电子设备。
背景技术:
1、目标检测的任务是找出图像中感兴趣的目标,确定它们的位置和类别。实现目标检测的途径通常是训练出目标检测模型并通过训练好的目标检测模型实现目标检测功能。对于训练得到的目标检测模型而言,模型的复杂度通常会随着模型精度的提高而越加复杂,进而对于部署资源的要求也就会越高。知识蒸馏是一种通过引入教师模型来诱导学生模型训练的模型压缩方法,可以实现从教师模型对学生模型的知识迁移。故而,通过知识蒸馏的方法,可以将训练得到的复杂模型作为教师模型,并将教师模型的知识迁移至结构更为简单的学生模型,使得学生模型在保留教师模型精度的同时,解决模型部署资源不足问题。
2、上述基于知识蒸馏所提及的教师模型与学生模型可以为同构模型,也可以为异构模型。在本技术中,目标检测模型训练过程所使用的两阶段检测网络的教师模型与单阶段检测网络的学生模型,其中学生模型与教师模型为异构模型。而异构模型所带来的预测框尺寸不匹配、类别得分意义不匹配及特征尺寸不匹配等问题目前尚未有完整的解决方法。因此,本技术针提出了一种基于知识蒸馏的模型训练方法、装置、电子设备及存储介质,以解决上述提及的一系列异构蒸馏问题。
技术实现思路
1、本技术提供一种基于知识蒸馏的模型训练方法、装置及电子设备,可以解决采用知识蒸馏的方法训练目标检测模型时,对于单阶段检测网络的学生模型以及两阶段检测网络的教师模型,也即学生模型与教师模型为异构模型的情况下,预测框尺寸不匹配、类别得分意义不匹配及特征尺寸不匹配等一系列异构蒸馏问题。
2、为达到上述目的,本技术实施例采用如下技术方案:
3、第一方面,本技术提供了一种基于知识蒸馏的模型训练方法,该方法中,知识蒸馏中所涉及的教师模型为两阶段检测网络、学生模型为单阶段检测网络,也即教师模型与学生模型为异构模型,将教师模型中的知识蒸馏迁移至学生模型,经过知识蒸馏过程的学生模型为训练得到的模型,用于对输入图像进行目标检测,该方法包括:获取已经训练好的教师模型;将学生模型特征层的特征尺度与教师模型特征层的特征尺度对齐,并确定特征蒸馏损失,特征蒸馏损失表示特征尺度对齐后,教师模型与学生模型特在特征层的差异程度;将学生模型输出层的预测框及对应的概率分布与教师模型输出层的预测框及对应的概率分布对齐,并确定输出蒸馏损失,输出蒸馏损失表示预测框及对应的概率分布对齐后,教师模型与学生模型在输出层的差异程度,概率分布为预测框的输出的概率分布;依据特征蒸馏损失及输出蒸馏损失改进学生模型的损失函数;基于改进后的损失函数训练学生模型,得到训练完成的模型。
4、本技术实施例提供的技术方案至少带来以下有益效果:
5、本技术采用知识蒸馏的方式训练得到用于目标检测的学生模型,且知识蒸馏中所涉及的教师模型为两阶段检测网络、学生模型为单阶段检测网络,也即教师模型与学生模型为异构模型。在训练时,本技术通过将学生模型输出层的预测框及对应的概率分布与教师模型输出层的预测框及对应的概率分布对齐,解决了学生模型与教师模型输出信息不对齐的问题;通过将学生模型特征层的特征尺度与教师模型特征层的特征尺度对齐,解决了学生模型与教师模型特征信息不对齐的问题;通过确定特征蒸馏损失和输出蒸馏损失对学生模型的损失函数进行改进,并通过改进后的损失函数指导学生模型的训练,可以使学生模型学习到教师模型的知识,在不改变学生模型结构的前提下,使学生模型拥有接近教师模型的泛化能力和功能。综上所述,本技术一方面提出了完整的异构蒸馏方案,在所涉及的教师模型为两阶段检测网络、学生模型为单阶段检测网络,也即教师模型与学生模型为异构模型的情况下,采用知识蒸馏的方式实现了目标检测模型的训练。另一方面,通过利用结构更复杂,功能更完善的教师模型对学生模型进行知识蒸馏,可以在不改变学生模型结构的情况下,令学生模型具有与教师模型相近的能力。同时,由于学生模型结构简单,对于部署资源性能的要求较低,训练得到的用于目标检测的学生模型的应用场景也会更广阔。
6、在一种可能的实现方式中,上述将学生模型特征层的特征尺度与教师模型特征层的特征尺度对齐,包括:获取教师模型的教师特征集合;其中,教师特征集合中的教师特征是教师模型的输入的第一图像中目标区域的特征,目标区域为基于锚框确定的区域,目标区域的特征包括目标区域的前景区域特征及背景区域特征;在学生模型的各层特征中获取的学生特征集合;其中,学生特征集合中的学生特征是目标区域在学生模型中的特征;将教师特征集合与学生特征集合中的特征尺度转换至同一维度。
7、在该种可能的实现方式中,提出了一种解决知识蒸馏过程中教师模型与学生模型特征尺度不匹配问题的方法,将学生模型及教师模型的特征信息转换至同一维度,有助于提升方案的可实施性。
8、在一种可能的实现方式中,上述在学生模型的各层特征中获取的学生特征集合,包括:基于输入学生模型的第一图像中的目标区域提取各层的学生区域特征,得到学生特征集合。
9、在一种可能的实现方式中,上述确定在特征层的特征蒸馏损失,包括:基于注意力机制或余弦相似度算法计算教师特征集合中的各教师特征与学生特征集合中的各学生特征之间的特征相似度,并依据特征相似度确定特征蒸馏损失。
10、在该种可能的实现方式中,可以得到教师特征集合中的各特征与学生集合中的各特征两两之间的特征相似度,据此可以进一步计算出特征对齐过程中的特征蒸馏损失,以便后续基于该特征蒸馏损失改进学生模型的损失函数,使训练得到的学生模型具有更好的目标检测性能。
11、在一种可能的实现方式中,上述将学生模型输出层的预测框及对应的概率分布与教师模型输出层的预测框及对应的概率分布对齐,包括:计算学生模型输出层的预测框与教师模式输出层的预测框的交并比,并依据交并比在多个学生模型输出层的预测框中选择与对应的教师模式输出层的预测框匹配的目标预测框;获取目标预测框对应的目标概率分布,目标概率分布为目标预测框的输出的概率分布;将目标概率分布与教师模式输出层的预测框对应的概率分布对齐。
12、在该种可能的实现方式中,通过计算各教师预测框与学生预测框的交并比,选出与教师预测框最匹配的学生预测框,并将选出的学生预测框的概率分布与对应的教师预测框的概率分布进行含义对齐,解决了异构蒸馏中预测框尺寸不一致以及类别得分意义不对齐的问题。
13、在一种可能的实现方式中,上述计算学生模型输出层的预测框与教师模式输出层的预测框的交并比,并依据交并比在多个学生模型输出层的预测框中选择与对应的教师模式输出层的预测框匹配的目标预测框,包括:将输入图像对应的特征图划分为多个网格,并计算各教师预测框的中心点在特征图中的中心点所属网格;在每一个教师预测框对应的中心点所属网格中,计算该中心点所属网格中各学生预测框与教师预测框的交并比;选择数值最大的所述交并比对应的学生预测框作为目标预测框。
14、在该种可能的实现方式中,确定每个教师预测框的中心点所属的网格,计算各教师预测框与所有在其中心点所属网格中的各学生预测框的交并比,并根据计算出的交并比选择出最匹配的学生预测框,提出了方案的可实施性。此外,本方式将匹配范围限制在了同一特征网格内,可以应对部分学生模型训练过程中的偏移值有区间限制的情况。
15、在一种可能的实现方式中,上述计计算学生模型输出层的预测框与教师模式输出层的预测框的交并比,并依据交并比在多个学生模型输出层的预测框中选择与对应的教师模式输出层的预测框匹配的目标预测框,包括:对于每一个教师预测框,计算该教师预测框与特征图中各学生预测框的交并比;选择数值最大的交并比对应的学生预测框作为目标预测框。
16、在该种可能的实现方式中,通过计算各教师预测框与特征图中所有学生预测框的交并比来为各教师预测框选择最匹配的学生预测框,可以实现学生模型预测框没有区间限制情况下的预测框对齐,有助于提升方案的可实施性。
17、在一种可能的实现方式中,上述确定输出蒸馏损失,包括:基于学生模型的输出概率分布与前景置信度的乘积以及教师模型输出概率分布,计算输出蒸馏损失。
18、在该种可能的实现方式中,提出了一种确定输出蒸馏损失的具体实现方法,提升了方案的可实施性。此外,通过计算输出蒸馏损失,可在后续基于该输出蒸馏损失对学生模型的损失函数进行改进,以使训练得到的学生模型具有更好的目标检测性能。
19、在一种可能的实现方式中,上述确定在输出层的输出蒸馏损失,包括:对所述目标概率分布对应的所述学生模型的输出概率与所述教师模型的输出概率归一化,并基于归一化的结果确定所述输出蒸馏损失。
20、在该种可能的实现方式中,提出了另一种确定输出蒸馏损失的具体实现方法,提升了方案的可实施性。此外,通过计算输出蒸馏损失,可在后续基于该输出蒸馏损失对学生模型的损失函数进行改进,以使训练得到的学生模型具有更好的目标检测性能。
21、在一种可能的实现方式中,上述依据特征蒸馏损失及输出蒸馏损失改进学生模型的损失函数,包括:依据学生模型的改进前的检测损失、输出蒸馏损失以及特征蒸馏损失,算学生模型的改进后的损失函数。
22、在该种可能的实现方式中,可以实现学生模型改进损失函数的计算,以便基于改进后的损失函数指导学生模型的训练,以使学生模型具有更好的目标检测功能,有助于提升方案的可实施性。
23、第二方面,本技术实施例提供一种基于知识蒸馏的模型训练装置,该装置具有实现上述第一方面中任一项的基于知识蒸馏的模型训练方法的功能。该功能可以通过硬件实现,也可以通过硬件执行相应的软件实现。该硬件或软件包括一个或多个与上述功能相对应的模块。
24、第三方面,本技术提供一种电子设备,该电子设备包括存储器和处理器。上述存储器和处理器耦合。该存储器用于存储计算机程序代码,该计算机程序代码包括计算机指令。当处理器执行该计算机指令时,使得电子设备执行如第一方面及其任一种可能的设计方式所述的基于知识蒸馏的模型训练方法。
25、第四方面,本技术提供一种计算机可读存储介质,该计算机可读存储介质存储有计算机指令,当所述计算机指令在电子设备上运行时,使得电子设备执行如第一方面及其任一种可能的设计方式所述的基于知识蒸馏的模型训练方法。
26、第五方面,本技术提供一种计算机程序产品,该计算机程序产品包括计算机指令,当计算机指令在电子设备上运行时,使得电子设备执行如第一方面及其任一种可能的设计方式所述的基于知识蒸馏的模型训练方法。
27、本技术中第二方面到第五方面及其各种实现方式的具体描述,可以参考第一方面及其各种实现方式中的详细描述;并且,第二方面到第五方面及其各种实现方式的有益效果,可以参考第一方面及其各种实现方式中的有益效果分析,此处不再赘述。
28、本技术的这些方面或其他方面在以下的描述中会更加简明易懂。