本发明涉及机器学习领域,尤其涉及一种机器学习的训练数据选择方法。
背景技术:
近年来,机器学习,尤其是基于大规模深度神经网络的深度学习技术迅猛发展,已在生活的各个方面得到了应用。随着深度学习的日益流行,机器学习中的数据选择问题成为一个日益受关注的问题。如何自动地选择数据,提高深度学习模型的性能,成为目前的一个迫切的需求。
目前在机器学习数据选择的领域,已有了许多方法,例如将训练数据按照“难易程度”由低到高的所谓“课程”(curriculum)顺序训练,有利于模型的训练过程。此外,自步学习用数据的损失函数大小(lossvalue)作为“难易程度”的度量标准。在自步学习算法中,损失值大于一个特定阈值η的数据会被丢弃,而阈值η在训练过程中逐渐增长,直到最终所有数据都被选中。
然而,上述现有的数据选择策略属于人为定义的启发式策略,具有较大的特定性,由于不同的机器学习任务通常具有不同的数据分布和模型特点,这些规则在不同的机器学习任务上往往难以泛化。
技术实现要素:
基于现有技术所存在的问题,本发明的目的是提供一种机器学习的训练数据选择方法,能在机器学习的不同阶段根据当前训练状态动态地选择训练数据,进而提高机器学习模型的性能。
本发明的目的是通过以下技术方案实现的:
本发明实施方式提供一种机器学习的训练数据选择方法,包括以下步骤:
步骤1,选定待选择数据的机器学习模型,并获取该机器学习模型对应的训练数据集;
步骤2,从所述训练数据集中随机选出一个数据子集作为策略训练数据集,通过深度强化学习对所述策略训练数据集应用于所述机器学习模型进行若干轮训练,根据训练结果确定与所述机器学习模型匹配的数据选择策略;
步骤3,通过确定的所述数据选择策略对所述机器学习模型待输入数据按批次进行选择,将选出的数据用于所述机器学习模型的训练。
由上述本发明提供的技术方案可以看出,本发明实施例提供的机器学习的训练数据选择方法,其有益效果为:
通过深度强化学习对所述策略训练数据集应用于所述机器学习模型进行若干轮训练,根据训练结果确定与所述机器学习模型匹配的数据选择策略的方式,能得出对当前机器学习模型最优的训练数据选择策略,提升机器学习模型性能,由于不需要主动遍历所有未训练过的数据以选出用于训练的批次,降低了计算开销;并且由于不是对每个任务使用简单的启发式策略,对不同的学习任务能自适应的确定选择策略,实现为不同的学习任务最优化的选择训练数据。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域的普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他附图。
图1为本发明实施例提供的机器学习的训练数据选择方法的流程图;
图2为本发明实施例提供的选择方法中实验的mlp在mnist数据集的一半训练数据上不同数据选择策略的测试集准确率曲线图;
图3为本发明实施例提供的选择方法中实验的ndf策略在每一轮中过滤掉的数据数量;
图4为本发明实施例提供的选择方法中实验的resnet32在cifar-10数据集的一半训练数据上不同数据选择策略的测试集准确率曲线图;
图5为本发明实施例提供的选择方法中实验的ndf策略在每一轮中过滤掉的数据数量;
图6为本发明实施例提供的选择方法中实验的rnn在imdb的一半训练数据上不同数据选择策略的测试集准确率曲线图;
图7为本发明实施例提供的选择方法中实验的ndf策略在每一轮中过滤掉的数据数量。
具体实施方式
下面结合本发明的具体内容,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明的保护范围。本发明实施例中未作详细描述的内容属于本领域专业技术人员公知的现有技术。
如图1所示,本发明实施例提供一种机器学习的数据选择方法,是一种能在机器学习的不同阶段根据当前训练状态动态地选择训练数据的方法,进而可提高机器学习模型的性能,包括以下步骤:
步骤1,选定待选择数据的机器学习模型,并获取该机器学习模型对应的训练数据集;
步骤2,从所述训练数据集中随机选出一个数据子集作为策略训练数据集,通过深度强化学习对所述策略训练数据集应用于所述机器学习模型进行若干轮训练,根据训练结果确定与所述机器学习模型匹配的数据选择策略;
步骤3,通过确定的所述数据选择策略对所述机器学习模型待输入数据按批次进行选择,将选出的数据用于所述机器学习模型的训练。
上述方法的步骤2中,通过深度强化学习对所述策略训练数据集应用于所述机器学习模型进行若干轮训练,根据训练结果确定与所述机器学习模型匹配的数据选择策略为:
步骤21,将策略训练数据集分为两个不相交的策略训练子集和策略验证子集;
步骤22,初始化深度强化学习模型的策略函数;
步骤23,以所述策略训练子集作为训练数据,通过所述深度强化学习模型重复进行若干轮深度强化学习训练;
步骤24,训练完成后得到与所述机器学习模型匹配的深度强化学习模型的策略函数,该策略函数能为所述机器学习模型选择训练数据。
上述方法的步骤23中,每轮深度强化学习训练包括:
步骤231,初始化所述机器学习模型;
步骤232,用所述策略训练子集训练所述机器学习模型,直到所述机器学习模型停止训练;在所述机器学习模型的每步训练过程中,对于每批次数据,根据深度强化学习模型的输出动作,选定该批次数据中的一部分数据作为所述机器学习模型的输入,并使用策略验证子集,计算出该次训练对应的奖励函数值;
具体的,上述步骤232中,深度强化学习模型(即教师模块)对机器学习模型(即学生模块)的训练数据的选择交互方式是:先从机器学习模型中抽取状态特征向量,将抽取的状态特征向量作为输入给深度强化学习模型的策略函数进行处理,处理后得到输出动作,输出动作是指:在深度强化学习模型的数据选择任务中,对于一个批次中的每个数据,是选择保留该数据还是丢弃该数据的动作。
步骤233,当所述机器学习模型一轮训练结束后,累计计算所述奖励函数值,更新策略函数。
上述步骤233具体是:从一轮训练中所有的奖励函数计算累计回报,然后从累计回报中计算出梯度值,使用梯度值更新策略函数。上述过程是已有的深度强化学习算法(reinforce算法)的标准步骤。
上述方法的步骤232中,深度强化学习模型的输出动作中,包含以下特征中的至少一种:
(1)包含数据的信息的数据特征;
(2)包含反映当前模型训练进度的信息的基本模型特征。
(3)包含当前到达的训练数据对于当前模型的重要性的信息的模型与数据结合的特征。
上述输出动作中,
所述数据特征中包含数据的信息为:数据的标签类别信息、文本数据句子的长度、文字片段的语法信息、图像数据梯度直方图中的至少一种;
所述基本模型特征中包含反映当前模型训练进度的信息为:当前已训练的批次数量、历史的损失函数的平均值和训练集历史准确率;
所述模型与数据结合的特征中包含当前到达的训练数据对于当前模型的重要性的信息为:模型输出的每一类别的概率、数据的损失函数值和训练数据的间隔值;所述训练数据(x,y)的间隔值的定义为:p(y|x)-maxy′≠yp(y′|x))。
下面对本发明实施例具体作进一步地详细描述。
本发明提供一种基于深度强化学习算法的机器学习模型训练数据的选择方法。在该方法中,深度强化学习作为教师模块,而基本的机器学习模型则为学生模块,教师模块需要为学生模块选择合适的训练数据,该方法包括以下步骤:
步骤1,选择作为学生模块的机器学习模型,并收集相应的训练数据集;
步骤2,策略训练:从训练数据集中随机选出一个数据子集,称为策略数据集,用于教师模块;在该子集上,使用如下步骤,优化神经网络数据选择器策略网络:
步骤21,将策略数据集分为两个不相交的子集,分别用于训练和验证;
步骤22,初始化教师模块的策略函数;
步骤23,重复若干轮训练,在每轮训练中:
步骤231,初始化学生模型;
步骤232,用策略训练数据集训练学生模型,直到学生模型达到停止训练标准。在每步训练过程中,对于每批次数据,需要根据教师模块的输出的状态特征向量,选定其中一部分数据作为学生模型的输入。使用策略验证子集,计算下该动作对应的奖励函数值(奖励函数值具体见下文说明);
步骤233,学生模型训练结束后,计算累计奖励函数值,更新策略函数;
步骤24,训练完成,输出策略函数;
步骤3,策略应用:使用如下步骤,将优化好的教师模块应用到学生模块的训练中:
步骤31,将输入数据按批次通过数据选择器,选出应该被保留的数据;
步骤32,将被选出的数据用于学生模型的训练,训练过程与学生模型的原始训练过程一致。
该步骤3中的学生模型不必和步骤2中训练用的学生模型一致,可以使用迁移学习的方法,将该选择策略应用到其他机器学习任务的训练过程中,其具有泛化能力。
上述方法中,实现教师模块的策略训练和策略应用的算法中,策略训练的过程为一个马尔科夫决策过程,称为sgd-mdp。本发明中的“深度强化学习”是指基于sgd-mdp的策略训练算法,该策略训练算法属于深度强化学习的一种。
具体的,sgd-mdp:与经典的mdp一样,sgd-mdp由一个四元组<s;a;p;r>组成:
其中,s表示环境的状态,与当前训练批次的数据和当前机器学习模型的状态相关;
a表示动作。在数据选择任务中,a表示对于一个批次中的每个数据,是选择保留该数据还是丢弃该数据;
r=r(s;a)为奖励函数,可以被设置为任何能够指示训练进度的值,如验证集准确率、当前训练批次在模型更新前后的损失函数之差等。在计算累计期望时,未来的期望值会被乘以衰减因子γ(γ∈[0,1]);
本发明中的教师模块的核心是策略函数,策略函数为a=pθ(a|s),能随机抽样出动作a,其中θ为待学习的参数,策略函数a可以是任意的二分类模型,如logistic回归和深度神经网络等。以logistic回归为例,策略函数为a(s,a;θ)=pθ(a|s)=aσ(θf(s)+b)+(1-a)(1-σ(θf(s)+b)),此处σ(·)为sigmoid函数,策略参数θ={θ,b},f(s)为表示状态s的特征向量(即状态特征向量)。策略函数的具体结构不做具体限制。
(1)状态特征向量f(s)是使其能够简洁而有效地表示sgd-mdp的状态。由于状态s包含了当前到达的数据和当前基本模型的状态,使用以下三类特征组合成f(s)。
(2)数据特征:这类特征包含数据的信息,例如数据的标签类别信息、(对于文本数据)句子的长度、文字片段的语法信息以及(对于图像数据)梯度直方图特征等。上述数据特征也常用于课程学习,课程学习指对训练数据进行重排以提高模型性能的算法。
(3)基本模型特征:这类特征包含反映当前模型训练进度的信息。在实验中,使用1)当前已训练的批次数量;2)历史的损失函数的平均值;以及3)训练集历史准确率三个特征。
(4)模型与数据结合的特征:这类特征包含当前到达的训练数据对于当前模型的重要性的信息。在实验中,使用1)模型输出的每一类别的概率;2)数据的损失函数值;以及3)间隔值(marginvalue)三个特征。(训练数据(x,y)的间隔值定义为:p(y|x)-maxy′≠yp(y′|x))。
本发明的方法中的学生模块是已有的独立的机器学习模型,视为黑箱,不属于本发明的范围。
本发明的方法中,神经网络数据选择算法的全部训练流程如下2.1所示:
算法2.1:使用神经网络数据选择器的批量梯度下降算法
上述算法2.1中策略训练的详细算法如下:
算法2.2:ndf策略训练算法
本发明提出的方法优势体现在以下几点:
(1)通过神经网络的深度强化学习确定的选择策略在随机到达的训练批次上选择训练数据,由于不是主动遍历所有未训练过的数据以选出用于训练的批次,降低了计算开销;
(2)根据训练过程中返回的特征向量值自动得到匹配的最优数据选择策略,由于不是对每个任务使用简单的启发式策略,对不同的学习任务能自适应的确定选择策略,实现为不同的学习任务最优化的选择训练数据。
实验结果表明(参见图2、3、4、5、6、7),本发明的选择方法可以对训练过程的收敛速度有较显著的提高,说明这种基于强化学习的自适应算法是有效的,并且对于多种不同的机器学习任务有较好的泛化能力。
图2:mlp在mnist数据集的一半训练数据上不同数据选择策略的测试集准确率曲线图。包含如下超参数设置:ndf策略(即本发明的选择方法)中,验证集准确率阈值τ分别为0.93,0.95,0.97;spl策略中,s分别为80,120,160;randdrop使用ndfτ=0.97输出的选择数据的比例。x轴记录有效的训练数据的数量。图3:ndf策略(即本发明的选择方法)在每一轮中过滤掉的数据数量。不同的曲线表示不同困难程度的数据在训练过程中被过滤掉的数量,数据的困难程度以该数据在其所在批次中损失函数值的排名代表。具体而言,将大小为20的训练批次中{1,2,…,20}的排名值分入5个桶(buckets)。1号桶代表每个批次中最难的数据,其损失函数值最大(排名第1到第4),而5号桶表示每个批次中最简单的数据,其损失函数值最小。
图4:resnet32在cifar-10数据集的一半训练数据上不同数据选择策略的测试集准确率曲线图。ndf策略(即本发明的选择方法)中的超参数τ∈{0.80,0.84,0.88};spl策略中的超参数分别为s∈{120,150,180};randdrop策略使用ndfτ=0.84输出的选择数据的比例。图5:ndf策略(即本发明的选择方法)在每一轮中过滤掉的数据数量。与图3类似,将数据按排名值{1,2,…,128}分为5个桶,表示不同困难程度的数据在训练过程中被过滤掉的数量。
图6:rnn在imdb的一半训练数据上不同数据选择策略的测试集准确率曲线图。ndf策略(即本发明的选择方法)中的超参数τ∈{0.78,0.81,0.84};spl策略中的超参数分别为s∈{80,100,120};randdrop策略使用ndfτ=0.78输出的选择数据的比例。图7:ndf策略(即本发明的选择方法)在每一轮中过滤掉的数据数量。与图3类似,将数据按排名值{1,2,…,16}分为5个桶,表示不同困难程度的数据在训练过程中被过滤掉的数量。
本领域普通技术人员可以理解:实现上述实施例方法中的全部或部分流程是可以通过程序来指令相关的硬件来完成,所述的程序可存储于一计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。其中,所述的存储介质可为磁碟、光盘、只读存储记忆体(read-onlymemory,rom)或随机存储记忆体(randomaccessmemory,ram)等。
以上所述,仅为本发明较佳的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明披露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应该以权利要求书的保护范围为准。