本技术涉及人工智能领域,尤其涉及一种模型训练方法、图像识别方法、装置、设备及介质。
背景技术:
1、人工智能(artificial intelligence,ai)涵盖计算机视觉(computer vision,cv),cv技术也是ai的主要研究领域。cv技术可应用于图像分类任务,图像识别任务以及图像检索任务等,其中,在图像识别任务中,可利用训练好的分类模型判定它所属类别。
2、在图像识别任务中,增强模型识别鲁棒性的方案包括如下两种:方案1、在训练过程中对源域图像进行随机图像增强,分类模型同时从源域和对抗域图像学习识别能力;方案2、对该分类模型采用知识蒸馏的方案,即从具有更好图像识别能力的大模型中学习得到相应的图像识别能力,从而提高模型的识别鲁棒性。
3、但是上述两种方法均存在其对应的缺点:方案1会显著影响分类模型在源域的识别能力;方案2中大模型的训练成本较大。因此目前急需一种更合适的图像识别模型。
技术实现思路
1、本技术实施例提供了一种模型训练方法、图像识别方法、装置、设备及介质,用于降低训练成本,并提升图像识别模型的识别鲁棒性。
2、有鉴于此,本技术一方面提供一种模型训练方法,包括:获取初始学生模型、初始教师模型以及源域训练图像,其中,该初始学生模型与该初始教师模型具有相同的网络结构;基于该源域训练图像对该初始学生模型进行训练得到学生模型;基于对抗域训练图像对该初始教师模型进行迭代训练得到教师模型,并利用指数滑动平均根据该教师模型的网络参数迭代更新该学生模型的网络参数,以得到目标学生模型,该对抗域训练图像为该源域训练图像进行图像增强处理得到;在该教师模型的训练损失满足收敛条件时,输出该目标学生模型。
3、本技术另一方面提供一种模型训练装置,包括:获取模块,用于获取初始学生模型、初始教师模型以及源域训练图像,其中,该初始学生模型与该初始教师模型具有相同的网络结构;
4、处理模块,用于基于该源域训练图像对该初始学生模型进行训练得到学生模型;基于对抗域训练图像对该初始教师模型进行迭代训练得到教师模型,并利用指数滑动平均根据该教师模型的网络参数迭代更新该学生模型的网络参数,以得到目标学生模型,该对抗域训练图像为该源域训练图像进行图像增强处理得到;
5、输出模块,用于在该教师模型的训练损失满足收敛条件时,输出该目标学生模型。
6、在一种可能的设计中,在本技术实施例的另一方面的另一种实现方式中,该处理模块,用于基于第一训练子集对该初始教师模型进行训练得到第一教师模型,该第一教师模型具有第二网络参数,该第一训练子集包含于该对抗域训练图像;
7、利用指数滑动平均根据该第一网络参数更新该学生模型的第二网络参数,以得到该学生模型的第三网络参数,该第二网络参数为基于该源域训练图像训练该初始学生模型得到;
8、基于第二训练子集对该第一教师模型进行训练得到第二教师模型,该第二教师模型具有第四网络参数,该第二训练子集包含于该对抗域训练图像;
9、利用指数滑动平均根据该第四网络参数更新该第三网络参数得到该学生模型的第五网络参数;
10、重复上述操作,在训练损失满足收敛条件,得到该教师模型和该目标学生模型。
11、在一种可能的设计中,在本技术实施例的另一方面的另一种实现方式中,处理模块,用于获取该第一训练子集以及第一类中心矩阵,该第一训练子集包括第一样本图像和该第一样本图像对应的第一图像标签,该第一类中心矩阵用于指示该源域训练图像中对应的各个类别的特征中心;
12、调用该初始教师模型对该第一训练子集进行图像识别,以得到第一图像特征以及第一预测图像标签;
13、根据该第一预测图像标签与该第一图像标签进行损失计算得到第一损失值,并根据该第一图像特征与第一类中心向量进行距离度量得到第二损失值,该第一类中心向量为该第一预测图像标签所处类别的类中心向量,该第一类中心向量包含于该第一类中心矩阵;
14、根据该第一损失值和该第二损失值反向梯度传播更新该初始教师模型的网络参数,以得到该第一教师模型。
15、在一种可能的设计中,在本技术实施例的另一方面的另一种实现方式中,该处理模块,用于根据该第一图像特征和该第一类中心向量利用指数滑动平均更新该第一类中心矩阵,以得到第二类中心矩阵。
16、在一种可能的设计中,在本技术实施例的另一方面的另一种实现方式中,处理模块,用于获取该第二训练子集以及该第二类中心矩阵,该第二训练子集包括第二样本图像和该第二样本图像对应的第二图像标签;
17、调用该第一教师模型对该第二训练子集进行图像识别,以得到第二图像特征以及第二预测图像标签;
18、根据该第二预测图像标签与该第二图像标签进行损失计算得到第三损失值,并根据该第二图像特征与第二类中心向量进行距离度量得到第四损失值,该第二类中心向量为该第二预测图像标签所处类别的类中心向量,该第二类中心向量包含于该第二类中心矩阵;
19、根据该第三损失值和该第四损失值反向梯度传播更新该第一教师模型的网络参数,以得到该第二教师模型。
20、在一种可能的设计中,在本技术实施例的另一方面的另一种实现方式中,处理模块,用于对该源域训练图像进行图像增强处理生成该对抗域训练图像;
21、从该对抗域训练图像中进行采样得到该第一训练子集;
22、或者,
23、处理模块,用于从该源域训练图像中进行采样得到第一源域训练子集;
24、对该第一源域训练子集进行图像增强处理得到该第一训练子集。
25、在一种可能的设计中,在本技术实施例的另一方面的另一种实现方式中,处理模块,用于利用该学生模型对该源域训练图像进行前向计算,以得到该源域训练图像对应的图像特征;
26、根据该图像特征计算分布概率,以得到该源域训练图像的n个类别,该n为正整数;
27、获取该n个类别的n个特征中心向量;
28、根据该n个特征中心向量生成该第一类中心矩阵。
29、在一种可能的设计中,在本技术实施例的另一方面的另一种实现方式中,处理模块,用于利用该学生模型对该源域训练图像进行前向计算,以得到该源域训练图像对应的图像特征;
30、对该图像特征进行聚类计算,以得到该源域训练图像的n个类别,该n为正整数;
31、获取该n个类别的n个特征中心向量;
32、将该n个特征中心向量作为该第一类中心矩阵。
33、在一种可能的设计中,在本技术实施例的另一方面的另一种实现方式中,该处理模块,用于根据该第一预测图像标签与该第一图像标签进行交叉熵分类损失计算得到第一损失值;
34、或者,
35、根据该第一预测图像标签与该第一图像标签进行交叉熵分类损失计算得到第一损失值;
36、或者,
37、根据该第一预测图像标签与该第一图像标签进行逻辑回归损失计算得到第一损失值。
38、在一种可能的设计中,在本技术实施例的另一方面的另一种实现方式中,该处理模块,用于根据该第一图像特征与第一类中心向量进行均方误差mse损失计算得到第二损失值;
39、或者,
40、根据该第一图像特征与第一类中心向量进行平均绝对值误差l1损失计算得到第二损失值;
41、或者,
42、根据该第一图像特征与第一类中心向量进行l1-smooth损失计算得到第二损失值。
43、在一种可能的设计中,在本技术实施例的另一方面的另一种实现方式中,该处理模块,用于基于源域训练图像对该初始学生模型进行全监督训练得到该学生模型;
44、或者,
45、基于源域训练图像对该初始学生模型进行半监督训练得到该学生模型;
46、或者,
47、基于源域训练图像对该初始学生模型进行弱监督训练得到该学生模型;
48、或者,
49、基于源域训练图像对该初始学生模型进行无监督训练得到该学生模型。
50、在一种可能的设计中,在本技术实施例的另一方面的另一种实现方式中,该初始学生模型与该初始教师模型的网络结构为残差神经网络resnet、resnest、resnext、regnet、vgg、alexnet、transformer或者vit。
51、在一种可能的设计中,在本技术实施例的另一方面的另一种实现方式中,该教师模型与该学生模型采用相同的训练方式。
52、本技术另一方面提供一种图像识别方法,包括:获取待处理图像;
53、调用图像识别模型对该待处理图像进行识别处理,以得到该待处理图像的图像类别,该图像识别模型为采用上述任一项该的方法训练得到的目标学生模型;
54、输出该待处理图像的图像类别。
55、本技术的另一方面提供一种图像识别装置,包括:获取模块,用于获取待处理图像;
56、处理模块,用于调用图像识别模型对该待处理图像进行识别处理,以得到该待处理图像的图像类别,该图像识别模型为上述任一项该的目标学生模型;
57、输出模块,用于输出该待处理图像的图像类别。
58、本技术另一方面提供一种计算机设备,包括:存储器、处理器以及总线系统;
59、其中,存储器用于存储程序;
60、处理器用于执行存储器中的程序,处理器用于根据程序代码中的指令执行上述各方面的方法;
61、总线系统用于连接存储器以及处理器,以使存储器以及处理器进行通信。
62、本技术的另一方面提供了一种计算机可读存储介质,计算机可读存储介质中存储有指令,当其在计算机上运行时,使得计算机执行上述各方面的方法。
63、本技术的另一个方面,提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述各方面所提供的方法。
64、从以上技术方案可以看出,本技术实施例具有以下优点:提供一组网络结构相同的学生模型和教师模型,其中,学生模型只在源域进行训练,从而获取较好的源域识别能力;而教师模型只在对抗域训练,从而获取对抗识别能力;然后根据教师模型的网络参数通过指数平滑平均的方式更新该学生模型的网络参数,使得学生模型可以不断积累对抗识别能力,同时保留了源域的识别能力,最终使得学生模型具有较高的识别鲁棒性。同时,该学生模型与该教师模型采用相同的网络结构,不需要大模型训练和知识蒸馏过程,减少了模型训练复杂度,从而降低训练成本。