本技术涉及人工智能(artificial intelligence,ai),尤其涉及一种多模态模型的训练方法、训练装置以及训练设备。
背景技术:
1、多模态模型是一种可以用于处理多种模态的数据的ai模型。多种模态的数据是指多种不同类型的数据,例如可以包括文本、图像、声音或者视频等类型的数据。
2、在相关技术中,通过将多种模态的数据输入多模态模型中各个模态的网络进行对比训练。并且,基于各个模态的网络的输出层输出的特征更新各个模态的网络的参数。但由于多种模态的数据并非严格一一对应,存在一定的噪声,这会导致多模态模型的泛化性能较差,准确率较低。
3、以图像数据和文本数据两种模态为例,当多个图像数据对应的文本数据之间的相似度较大、或者多个文本数据对应的图像数据之间的相似度较大时,图像数据和图像数据对应的文本数据不是严格的一一对应。如果基于图像网络的输出层特征和文本网络的输出层特征更新图像网络的参数和文本网络的参数,多模态模型可能无法准确的找到图像和文本之间的最佳匹配,多模态模型的性能较差。
技术实现思路
1、本技术提供了一种多模态模型的训练方法、训练装置以及训练设备,根据多模态模型中各个模态的网络的中间层输出的特征更新各个模态的网络的参数,可以提高模型的性能。
2、第一方面,本技术提供一种多模态模型的训练方法。该多模态模型包括第一网络和第二网络,该第一网络分别用于处理不同模态的数据,该第一网络包括多个第一中间层,该第二网络包括多个第二中间层。
3、该方法包括:将第一数据和第二数据分别输入该第一网络和该第二网络,该第一数据和该第二数据为不同模态的数据;获取该多个第一中间层输出的一个或多个第一特征和该多个第二中间层输出的一个或多个第二特征;将该一个或多个第一特征和该一个或多个第二特征进行融合,得到融合特征;至少基于该融合特征确定第三损失值;根据该第三损失值对该多模态模型的参数进行更新。其中,该第三损失值用于拉近该多模态模型的预测值和真实值之间的偏差。
4、上述方案在训练多模态模型时,基于融合两个网络的中间层输出的特征得到的融合特征,确定用于更新多模态模型的参数的损失值,可以解决因各个模态的数据非严格地一一对应,导致基于输出层输出的特征训练的多模态模型的性能较差的问题。
5、在第一方面的一种可能的实施方式中,该至少基于该融合特征确定第三损失值包括:至少基于该多个第二中间层输出的至少一个第二特征和该融合特征确定该第三损失值。具体地,可以根据融合特征与图像数据的误差、以及至少一个第二特征与文本数据的误差确定该第三损失值;或者,还可以根据融合特征与文本数据的误差、以及至少一个第二特征与文本数据的误差确定该第三损失值。
6、在第一方面的一种可能的实施方式中,该至少一个第二特征不包含于用于融合的该一个或多个第二特征中。
7、在第一方面的一种可能的实施方式中,该方法还包括:基于该多个第一特征中的至少一个第一特征和该多个第二特征中的至少一个第二特征确定第二损失值,该第二损失值用于拉近该至少一个第一特征和该至少一个第二特征之间的相似度;根据该第二损失值对该多模态模型的参数进行更新。其中,根据第三损失值更新参数和根据第二损失值更新参数可以同时执行,也可以依次执行。
8、上述方案在训练过程中综合考虑两个网络的中间层输出的特征的相似度,可以进一步提升模型的性能。
9、在第一方面的一种可能的实施方式中,该方法还包括:获取该第一网络的输出层输出的第一输出结果和该第二网络的输出层输出的第二输出结果;根据该第一输出结果和该第二输出结果确定第一损失值;根据该第一损失值对该多模态模型的参数进行更新。该第一损失值用于拉近该第一输出结果和该第二输出结果之间的相似度。根据第一损失值更新参数、根据第二损失值更新参数、以及根据第三损失值更新参数可以同时执行,也可以依次执行。
10、上述方案中,在训练过程中,综合考虑两个网络的输出层输出的特征的相似度,可以进一步提升模型的性能。
11、在第一方面的一种可能的实施方式中,在该将第一数据和第二数据分别输入该第一网络和该第二网络之前,该方法还包括:对该第一数据和/或该第二数据进行掩码处理。
12、上述方案中,进行掩码处理可以最大程度挖掘第一网络或第二网络的潜力,可以提升模型的泛化性能,提高模型性能。
13、在第一方面的一种可能的实施方式中,该第一特征包括多个第一单位特征,该第二特征包括多个第二单位特征,该基于该多个第一特征中的至少一个第一特征和该多个第二特征中的至少一个第二特征确定第二损失值包括:根据该至少一个第一特征对应的多个第一单位特征与该至少一个第二特征对应的多个第二单位特征确定该至少一个第一特征与该至少一个第二特征之间的第一相似度;根据该第一相似度与匹配标签值的差值确定该第二损失值,该匹配标签值表征该第一数据与该第二数据之间的相似度。
14、上述方案中,根据两个网络的中间层输出的特征中的单位特征计算两个网络的中间层输出的特征之间的相似度,可以细粒度地对齐两种模态数据,提高两种模态数据的对齐精度,从而提升模型的性能。
15、在第一方面的一种可能的实施方式中,该根据该第一输出结果和该第二输出结果确定第一损失值包括:确定该第一输出结果和该第二输出结果之间的第二相似度;调整匹配标签值,该匹配标签值表征该第一数据与该第二数据之间的相似度;根据调整后的该匹配标签值与该第二相似度的差值确定该第一损失值。
16、上述方案中,通过调整两种模态数据的匹配标签值,可以有效缓解两种模态数据之间非严格一一对应的噪声影响,提升模型性能。
17、在第一方面的一种可能的实施方式中,所述第一数据包括图像数据,所述第二数据包括文本数据,所述第一网络包括图像编码器,所述第二网络包括文本编码器。
18、第二方面,本技术提供一种多模态模型的训练装置。该多模态模型包括第一网络和第二网络,该第一网络分别用于处理不同模态的数据,该第一网络包括多个第一中间层,该第二网络包括多个第二中间层,该装置包括:获取模块和更新模块。
19、其中,获取模块用于将第一数据和第二数据分别输入该第一网络和该第二网络,获取该多个第一中间层输出的一个或多个第一特征和该多个第二中间层输出的一个或多个第二特征,该第一数据和该第二数据为不同模态的数据。
20、其中,更新模块用于将该一个或多个第一特征和该一个或多个第二特征进行融合,得到融合特征,以及更新模块用于至少基于该融合特征确定第三损失值,根据该第三损失值对该多模态模型的参数进行更新,该第三损失值用于拉近该多模态模型的预测值和真实值之间的偏差。
21、在第二方面的一种可能的实施方式中,该更新模块具体用于:至少基于该多个第二中间层输出的至少一个第二特征和该融合特征确定该第三损失值。
22、在第二方面的一种可能的实施方式中,该至少一个第二特征不包含于用于该融合的该一个或多个第二特征中。
23、在第二方面的一种可能的实施方式中,该更新模块还用于:基于该多个第一特征中的至少一个第一特征和该多个第二特征中的至少一个第二特征确定第二损失值,该第二损失值用于拉近该至少一个第一特征和该至少一个第二特征之间的相似度;根据该第二损失值对该多模态模型的参数进行更新。
24、在第二方面的一种可能的实施方式中,该更新模块还用于:根据该第一网络的输出层输出的第一输出结果和该第二网络的输出层输出的第二输出结果确定第一损失值;根据该第一损失值对该多模态模型的参数进行更新。该第一损失值用于拉近该第一输出结果和该第二输出结果之间的相似度。
25、在第二方面的一种可能的实施方式中,在该将第一数据和第二数据分别输入该第一网络和该第二网络之前,该获取模块还用于:对该第一数据和/或该第二数据进行掩码处理。
26、在第二方面的一种可能的实施方式中,该第一特征包括多个第一单位特征,该第二特征包括多个第二单位特征,该更新模块具体用于:根据该至少一个第一特征对应的多个第一单位特征与该至少一个第二特征对应的多个第二单位特征确定该至少一个第一特征与该至少一个第二特征之间的第一相似度;根据该第一相似度与匹配标签值的差值确定该第二损失值,该匹配标签值表征该第一数据与该第二数据之间的相似度。
27、在第二方面的一种可能的实施方式中,该更新模块还用于:确定该第一输出结果和该第二输出结果之间的第二相似度;调整匹配标签值,该匹配标签值表征该第一数据与该第二数据之间的相似度;根据调整后的该匹配标签值与该第二相似度的差值确定该第一损失值。
28、在第二方面的一种可能的实施方式中,该第一数据包括图像数据,该第二数据包括文本数据,该第一网络包括图像编码器,该第二网络包括文本编码器。
29、第三方面,本技术提供一种训练设备。训练设备包括处理器和存储器。处理器用于执行存储于存储器内的计算机程序以实现前述第一方面或第一方面的任意一种可能的实现方式所提供的训练方法。
30、第四方面,本技术提供一种计算机可读存储介质,计算机可读存储介质内存储有指令,当其在计算机上运行时,使得计算机执行上述第一方面或第一方面的任意一种可能的实现方式所提供的训练方法。
31、第五方面,本技术提供一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行上述第一方面或第一方面的任意一种可能的实现方式所提供的训练方法。
32、上述提供的任一种装置或训练设备或计算机存储介质或计算机程序产品,均用于执行上文所提供的方法,因此,其所能达到的有益效果可参考上文提供的对应方法中的对应方案的有益效果,此处不再赘述。