本申请涉及自动驾驶,特别涉及一种任务模型的训练方法、装置、电子设备及存储介质。
背景技术:
1、预训练指将大量收集的训练数据放在一起,经过某种预训方法去学习其中的共性,然后将其中的共性“移植”到特定任务的模型中去,再使用相关特定领域的少量标注数据进行“微调”,从而会加速训练,使得模型收敛更快。
2、然而,相关技术中通过其他开源训练模型对模型进行训练或者利用大量的标注数据进行模型训练,导致训练方法复杂,训练效果较差。
技术实现思路
1、本申请提供一种任务模型的训练方法、装置、电子设备及存储介质,以解决相关技术中在训练模型中没有很好利用大规模的无监督数据,训练方法复杂,训练效果较差等问题。
2、本申请第一方面实施例提供一种任务模型的训练方法,包括以下步骤:获取待训练的目标任务模型;识别预先训练得到的预训练模型与所述目标任务模型的不一致模型参数,并丢弃所述不一致模型参数中不满足预设条件的模型参数;将所述预训练模型中剩余模型参数导入所述目标任务模型进行初始化,并利用预设标注数据集对所述目标任务模型进行训练,直到满足训练结束条件,得到训练完成的目标任务模型。
3、根据上述技术手段,本申请实施例可以通过将预训练模型中与目标任务模型中不一致的参数去掉,将合适的参数导入目标任务模型进行初始化,使用预先标注好的数据集对目标任务模型进行训练,提升训练效果。
4、进一步地,所述预训练模型基于无标签数据集训练得到,包括:获取所述无标签数据集;将所述无标签数据集输入预先构建的解码器,得到所述无标签数据集的特征向量,利用所述特征向量对预先构建的栈式自编码器模型进行训练;利用所述解码器对所述特征向量进行重构得到重构数据集,根据所述无标签数据集与所述重构数据集计算训练损失值,并利用所述训练损失值对所述栈式自编码器模型进行训练,直到满足训练停止条件时,停止训练,并得到训练完成的编码器参数和解码器参数,基于所述训练完成的编码器参数和解码器参数得到所述预训练模型。
5、根据上述技术手段,本申请实施例可以利用无标签数据集较低成本的获得预训练的编码器模型,无需使用人工标注的标签信息,训练效率高,从而提升模型的整体表现。
6、进一步地,所述满足训练停止条件,包括:获取预设标签数据;从所述预设标签数据中随机选择数据作为锚点数据,并将所述预设标签数据输入所述预训练模型,得到锚点数据、与所述锚点数据的同类数据和异类数据的数据特征;根据所述数据特征分别计算所述锚点数据与所述同类数据之间的第一距离、以及所述锚点数据与所述异类数据之间的第二距离;若所述第一距离小于或等于第一阈值、且所述第二距离均大于第二阈值,则判定所述预训练模型的特征提取效果满足期望效果,并确定满足训练停止条件,否则判定不满足训练停止条件。
7、根据上述技术手段,本申请实施例可以使用小规模标签数据对模型的特征提取效果进行评估,确定是否满足训练停止条件,提高了训练效果。
8、进一步地,所述获取待训练的目标任务模型,包括:识别所述预训练模型的网络结构;根据目标任务的任务需求和所述网络结构建立所述目标任务模型。
9、根据上述技术手段,本申请实施例可以通过建立一个与预训练模型类似的特征提取网络的目标任务模型,从而对目标任务模型进行训练,获取最终的任务模型。
10、本申请第二方面实施例提供一种任务模型的训练装置,包括:获取模块,用于获取训练的目标任务模型;识别模块,用于识别预先训练得到的预训练模型与所述目标任务模型的不一致模型参数,并丢弃所述不一致模型参数中不满足预设条件的模型参数;训练模块,用于将所述预训练模型中剩余模型参数导入所述目标任务模型进行初始化,并利用预设标注数据集对所述目标任务模型进行训练,直到满足训练结束条件,得到训练完成的目标任务模型。
11、进一步地,还包括:预训练模块,用于获取所述无标签数据集;将所述无标签数据集输入预先构建的解码器,得到所述无标签数据集的特征向量,利用所述特征向量对预先构建的栈式自编码器模型进行训练;利用所述解码器对所述特征向量进行重构得到重构数据集,根据所述无标签数据集与所述重构数据集计算训练损失值,并利用所述训练损失值对所述栈式自编码器模型进行训练,直到满足训练停止条件时,停止训练,并得到训练完成的编码器参数和解码器参数,基于所述训练完成的编码器参数和解码器参数得到所述预训练模型。
12、进一步地,所述预训练模块进一步用于:获取预设标签数据;从所述预设标签数据中随机选择数据作为锚点数据,并将所述预设标签数据输入所述预训练模型,得到锚点数据、与所述锚点数据的同类数据和异类数据的数据特征;根据所述数据特征分别计算所述锚点数据与所述同类数据之间的第一距离、以及所述锚点数据与所述异类数据之间的第二距离;若所述第一距离小于或等于第一阈值、且所述第二距离均大于第二阈值,则判定所述预训练模型的特征提取效果满足期望效果,并确定满足训练停止条件,否则判定不满足训练停止条件。
13、进一步地,所述获取模块进一步用于:识别所述预训练模型的网络结构;根据目标任务的任务需求和所述网络结构建立所述目标任务模型。
14、本申请第三方面实施例提供一种电子设备,包括:存储器、处理器及存储在所述存储
15、器上并可在所述处理器上运行的计算机程序,所述处理器执行所述程序,以实现如上述实5施例所述的任务模型的训练方法。
16、本申请第四方面实施例提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行,以用于实现如上述实施例所述的任务模型的训练方法。
17、由此,本申请至少具有如下有益效果:
18、(1)本申请实施例可以通过将预训练模型中与目标任务模型中不一致的参数去掉,将0合适的参数导入目标任务模型进行初始化,使用预先标注好的数据集对目标任务模型进行
19、训练,提升训练效果。
20、(2)本申请实施例可以利用无标签数据集较低成本的获得预训练的编码器模型,无需使用人工标注的标签信息,训练效率高,从而提升模型的整体表现。
21、(3)本申请实施例可以使用小规模标签数据对模型的特征提取效果进行评估,确定是5否满足训练停止条件,提高了训练效果。
22、(4)本申请实施例可以通过建立一个与预训练模型类似的特征提取网络的目标任务模型,从而对目标任务模型进行训练,获取最终的任务模型。
23、本申请附加的方面和优点将在下面的描述中部分给出,部分将从下面的描述中变得明显,或通过本申请的实践了解到。
1.一种任务模型的训练方法,其特征在于,包括以下步骤:
2.根据权利要求1所述的方法,其特征在于,所述预训练模型基于无标签数据集训练得到,包括:
3.根据权利要求2所述的方法,其特征在于,所述满足训练停止条件,包括:
4.根据权利要求1所述的方法,其特征在于,所述获取待训练的目标任务模型,包括:
5.一种任务模型的训练装置,其特征在于,包括:
6.根据权利要求5所述的装置,其特征在于,还包括:
7.根据权利要求6所述的装置,其特征在于,所述预训练模块进一步用于:
8.根据权利要求5所述的装置,其特征在于,所述获取模块进一步用于:
9.一种电子设备,其特征在于,包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述程序,以实现如权利要求1-4任一项所述的任务模型的训练方法。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该程序被处理器执行,以用于实现如权利要求1-4任一项所述的任务模型的训练方法。