本发明涉及元学习算法,具体涉及一种改进的元学习算法mg-reptile。
背景技术:
1、reptile算法可以使用一阶信息在每个任务上执行梯度下降,以解决基于元学习的模型更新问题,而不必像模型不可知元学习(model agnostic meta-learning,maml)算法那样计算两次微分。即,reptile算法可以使用更少的内存来实现类似于maml的性能。
2、现有的基于元学习的模型虽然可以在小样本下很好地学习新任务,但它也需要足够的训练样本来构建模型,而使用少量训练样本构建模型可能会导致元过度拟合。此外,现有基于元学习的工作通常需要在训练用户和新用户的样本分布之间具有一定的相似性,其泛化能力较弱,即,当预测样本的分布差异较大时,利用现有的方法构建的模型的性能会有所下降。
技术实现思路
1、本发明的目的是提供一种改进的元学习算法mg-reptile,以解决基于少量训练样本的基础模型泛化性能低的问题,而且能够有效预测分布差异较大的新样本。
2、为实现上述目的,本发明提供了如下技术方案:
3、本发明提供了一种改进的元学习算法mg-reptile,包括:
4、构建基础模型;
5、输入训练样本和新样本;
6、引入分布式测量策略dms,对所述基础模型进行优化;
7、使用生成性对抗网络gan对所述基础模型进行优化。
8、优选的,所述引入分布式测量策略dms,对所述基础模型进行优化,具体包括:
9、初始化网络参数和外循环次数;
10、在内循环中更新所述网络参数;
11、将每个用户的输入数据划分为k个小批量,进行k次迭代,计算最大平均偏差mmd;
12、增加第(k+1)次迭代;
13、计算所有迭代的参数的平均值并更新外环中的网络初始化参数。
14、优选的,所述计算最大平均偏差mmd的公式为:
15、式中,xi和yi表示两个长度分别为m的不同小批量上的两个实例。
16、优选的,所述更新外环中的网络初始化参数的公式为:
17、式中,表示第i次迭代中的更新参数,ε表示学习率。
18、优选的,所述使用生成性对抗网络gan对所述基础模型进行优化,具体包括:
19、修改鉴别器,其中,分类部分和鉴别部分共享部分网络,所述鉴别部分用于估计新用户和训练用户之间的样本分布距离;
20、利用所述鉴别器的分类信息来约束模型优化方向,进一步计算类别层次上的分布距离;
21、根据所述类别层次上的分布距离,预测输入样本的类别;
22、根据所述新用户和训练用户之间的样本分布距离和所述计算类别层次上的分布距离,优化参数。
23、优选的,所述新用户和训练用户之间的样本分布距离的计算公式为:
24、式中,ldw是鉴别器的判别部分fdw的距离测量函数,fg表示特征提取,xs是训练样本,xt是新用户样本,ns是训练样本的数目,nt是新用户样本的数目,lgrad是梯度惩罚损失函数。
25、优选的,所述类别层次上的分布距离的计算公式为:
26、式中,ldc是鉴别器分类部分的距离度量函数,是从第i个训练样本的真实标签转换而来的独热编码向量,是鉴别器对第i个训练样本的预测输出,和是从真实标签和新用户第i个样本的鉴别器的预测输出转换而来的独热编码向量。
27、优选的,所述预测输入样本的类别的计算公式为:
28、式中,lc是总体分类函数,是第i个训练样本的预测输出,是第i个新用户样本的预测输出。
29、优选的,所述优化参数的计算公式为:
30、式中,θc是整体分类器的参数,θg是特征提取器的参数,θdc和θdw是鉴别器的参数,α和γ是权重参数,ldw和ldc是鉴别器的函数,lgrad是梯度惩罚损失函数。
31、因此,采用本发明提供的技术方案,通过引入分布式测量策略dms和使用生成性对抗网络gan对基础模型进行优化,以解决现有的基础元学习模型存在的基于少量训练样本的基础模型泛化性能低的问题,进而能够有效预测分布差异较大的新样本。
1.一种改进的元学习算法mg-reptile,其特征在于,包括:
2.根据权利要求1所述的改进的元学习算法mg-reptile,所述引入分布式测量策略dms,对所述基础模型进行优化,具体包括:
3.根据权利要求2所述的改进的元学习算法mg-reptile,所述计算最大平均偏差mmd的公式为:
4.根据权利要求2所述的改进的元学习算法mg-reptile,其特征在于,所述更新外环中的网络初始化参数的公式为:
5.根据权利要求1所述的改进的元学习算法mg-reptile,所述使用生成性对抗网络gan对所述基础模型进行优化,具体包括:
6.根据权利要求5所述的改进的元学习算法mg-reptile,所述新用户和训练用户之间的样本分布距离的计算公式为:
7.根据权利要求5所述的改进的元学习算法mg-reptile,所述类别层次上的分布距离的计算公式为:
8.根据权利要求5所述的改进的元学习算法mg-reptile,所述预测输入样本的类别的计算公式为:
9.根据权利要求5所述的改进的元学习算法mg-reptile,所述优化参数的计算公式为: