基于元学习优化方法的连续学习方法及装置与流程

文档序号:20757011发布日期:2020-05-15 17:30阅读:264来源:国知局
基于元学习优化方法的连续学习方法及装置与流程

本发明涉及连续学习/元学习/优化算法技术领域,特别涉及一种基于元学习优化方法的连续学习方法及装置。



背景技术:

在统计机器学习系统和深度学习系统投入使用时,要求模型学习到的输入数据和输出数据的分布和要进行测试的分布是一致的,同时也要考虑输入数据、输出数据的分布随时间的变化。在单一任务的机器学习场景下,统计机器学习模型或深度神经网络模型的可行性和有效性严重取决于一个假设:输入和输出数据的分布不会随时间发生显著变化;否则会使得模型原本学习到的模式和复杂关系表现不佳或甚至完全不可用。但在真实世界中,这样的假设条件很少能够成立。以人类的认知学习为例,人类大脑在学习过程中,能够应对输入数据随时间变化的显著性和多样性,这是目前针对单一任务的机器学习模型不能够适应的。因此,神队神经网络的连续学习问题得到了越来越多的关注。

连续学习场景不同于传统单一任务的机器学习场景,也不同于多任务同时学习的场景。在连续学习的场景中,任务按照时间顺序到达,当前任务的训练过程结束后,继续在同一模型上训练下一个任务,已经结束的任务的数据是不能获得的,即我们不能获得已经结束的任务的数据分布。在测试阶段,连续学习场景下的模型要完成所有任务的测试,并保证在所有任务上的表现比较好。

解决连续学习问题的难点在于学习新任务的同时尽可能少地遗忘旧任务。现有技术主要通过约束模型参数变化程度、增加记忆储存单元等方法解决遗忘旧任务的问题。现有技术的缺点时通过约束模型参数变化程度一方面会影响新任务的学习效果,另一方面解决遗忘问题的效果不好;通过增加记忆储存单元的方法会增加额外的储存空间,当储存空间有限时,模型解决遗忘问题的效果不好。



技术实现要素:

本发明旨在至少在一定程度上解决相关技术中的技术问题之一。

为此,本发明的一个目的在于提出一种基于元学习优化方法的连续学习方法,该方法可以有效解决目前连续学习场景下,机器学习模型对旧任务的遗忘,以及多任务随时间先后到达时,模型在旧任务上学习的知识如何迁移到新任务学习的问题。

本发明的另一个目的在于提出一种基于元学习优化方法的连续学习装置。

为达到上述目的,本发明一方面实施例提出了一种基于元学习优化方法的连续学习方法,包括以下步骤:建立初始深度学习模型;对于第t个新任务,利用元学习的优化方法训练所述初始深度学习模型;在每次训练任务结束后,更新记忆储存单元,储存部分训练数据以用于后续学习,且在模型训练结束后,使得最终深度学习模型根据任意输入数据执行不同的任务。

本发明实施例的基于元学习优化方法的连续学习方法,利用元学习的优化方法解决连续学习场景下的机器学习问题,且元学习优化方法能够在样本量少的情况下达到较好的学习效果,结合少量记忆储存单元,可以有效解决连续学习中的遗忘问题;此外,元学习优化方法能够充分学习不同任务之间的关系,提高多任务之间的迁移效率,从而提高模型的学习能力。

另外,根据本发明上述实施例的基于元学习优化方法的连续学习方法还可以具有以下附加的技术特征:

进一步地,在本发明的一个实施例中,所述利用元学习的优化方法训练所述初始深度学习模型,包括:利用输入数据得到当前模型的梯度,在当前梯度方向下降预设步;获取梯度下降后的模型参数的新梯度,并根据所述新梯度更新原来的模型参数。

进一步地,在本发明的一个实施例中,所述获取梯度下降后的模型参数的新梯度,并根据所述新梯度更新原来的模型参数,进一步包括:将所述输入数据分为训练数据和验证数据;使用所述训练数据在当前模型参数上进行预设步梯度下降,得到所述梯度下降后的模型参数的新梯度;根据所述验证数据得到所述新梯度的梯度下降方向,使用所述新梯度的梯度下降方向,原来的模型参数上进行一步梯度下降,以更新原来的模型参数。

进一步地,在本发明的一个实施例中,所述初始深度学习模型使用卷积神经网络模型,其中,所述卷积神经网络模型有m层,每层有预设数目的卷积核,以及归一化操作、非线性激活函数,所述卷积神经网络模型的输入数据是图片,输出是图片类别。

进一步地,在本发明的一个实施例中,在每次训练任务结束后,更新记忆储存单元,储存部分训练数据以用于后续学习,包括:判断所述记忆储存单元的当前容量是否小于存储第t个新任务的数据所需的存储容量;如果小于,则将记忆储存单元的容量平均分给t个任务,随机删除每个任务中多余的数据,直到储存第t个新任务的数据;否则,随机选择所述第t个新任务的任意部分数据,储存在记忆储存单元。

为达到上述目的,本发明另一方面实施例提出了一种基于元学习优化方法的连续学习装置,包括:建立模块,用于建立初始深度学习模型;训练模块,用于对于第t个新任务,利用元学习的优化方法训练所述初始深度学习模型;更新模块,用于在每次训练任务结束后,更新记忆储存单元,储存部分训练数据以用于后续学习,且在模型训练结束后,使得最终深度学习模型根据任意输入数据执行不同的任务。

本发明实施例的基于元学习优化方法的连续学习装置,利用元学习的优化方法解决连续学习场景下的机器学习问题,且元学习优化方法能够在样本量少的情况下达到较好的学习效果,结合少量记忆储存单元,可以有效解决连续学习中的遗忘问题;此外,元学习优化方法能够充分学习不同任务之间的关系,提高多任务之间的迁移效率,从而提高模型的学习能力。

另外,根据本发明上述实施例的基于元学习优化方法的连续学习装置还可以具有以下附加的技术特征:

进一步地,在本发明的一个实施例中,所述训练模块进一步用于利用输入数据得到当前模型的梯度,在当前梯度方向下降预设步;获取梯度下降后的模型参数的新梯度,并根据所述新梯度更新原来的模型参数。

进一步地,在本发明的一个实施例中,所述训练模块具体用于将所述输入数据分为训练数据和验证数据;使用所述训练数据在当前模型参数上进行预设步梯度下降,得到所述梯度下降后的模型参数的新梯度;根据所述验证数据得到所述新梯度的梯度下降方向,使用所述新梯度的梯度下降方向,原来的模型参数上进行一步梯度下降,以更新原来的模型参数。

进一步地,在本发明的一个实施例中,所述初始深度学习模型使用卷积神经网络模型,其中,所述卷积神经网络模型有m层,每层有预设数目的卷积核,以及归一化操作、非线性激活函数,所述卷积神经网络模型的输入数据是图片,输出是图片类别。

进一步地,在本发明的一个实施例中,所述更新模块进一步用于判断所述记忆储存单元的当前容量是否小于存储第t个新任务的数据所需的存储容量;如果小于,则将记忆储存单元的容量平均分给t个任务,随机删除每个任务中多余的数据,直到储存第t个新任务的数据;否则,随机选择所述第t个新任务的任意部分数据,储存在记忆储存单元。

本发明附加的方面和优点将在下面的描述中部分给出,部分将从下面的描述中变得明显,或通过本发明的实践了解到。

附图说明

本发明上述的和/或附加的方面和优点从下面结合附图对实施例的描述中将变得明显和容易理解,其中:

图1为根据本发明实施例的基于元学习优化方法的连续学习方法的流程图;

图2为根据本发明一个实施例的基于元学习优化方法的连续学习方法的流程图;

图3为根据本发明实施例的基于元学习优化方法的连续学习装置的结构示意图。

具体实施方式

下面详细描述本发明的实施例,所述实施例的示例在附图中示出,其中自始至终相同或类似的标号表示相同或类似的元件或具有相同或类似功能的元件。下面通过参考附图描述的实施例是示例性的,旨在用于解释本发明,而不能理解为对本发明的限制。

下面参照附图描述根据本发明实施例提出的基于元学习优化方法的连续学习方法及装置,首先将参照附图描述根据本发明实施例提出的基于元学习优化方法的连续学习方法。

图1是本发明一个实施例的基于元学习优化方法的连续学习方法的流程图。

如图1所示,该基于元学习优化方法的连续学习方法包括以下步骤:

在步骤s101中,建立初始深度学习模型。

可以理解的是,首先建立深度学习模型,可以选择卷积神经网络、全连接网络等模型,输入数据一般是图片、视频等形式,输出根据任务的不同有所变化。

具体而言,深度学习模型使用卷积神经网络模型,卷积神经网络模型有m层,每层有一定数目的卷积核,以及归一化操作、非线性激活函数。深度网络的输入数据是图片,输出是图片类别。

在步骤s102中,对于t个新任务,利用元学习的优化方法训练初始深度学习模型。

其中,如图2所示,元学习优化方法的具体训练步骤是:利用输入数据得到当前模型的梯度,在当前梯度方向下降预设步(本领域技术人员可以根据实际情况进行设置,在此不做具体限定),再计算梯度下降后的模型参数的梯度,使用新梯度更新原来的模型参数。

可以理解的是,对于第一个任务,利用元学习的优化方法训练模型,使得模型在当前任务上效果好。同样的,对于第二个以及之后的新任务,当第t个新任务的数据到达后,先从储存单元中随机选择n个任务的数据(n<t),即旧任务的数据。使用n个旧任务的数据以及新任务的数据更新模型。具体更新方法是,对于每一个任务,根据输入数据在当前梯度方向下降预设步,得到n+1个更新后的新参数,再根据n+1个参数更新原来的模型参数。

进一步地,在本发明的一个实施例中,获取梯度下降后的模型参数的新梯度,并根据新梯度更新原来的模型参数,进一步包括:将输入数据分为训练数据和验证数据;使用训练数据在当前模型参数上进行预设步梯度下降,得到梯度下降后的模型参数的新梯度;根据验证数据得到新梯度的梯度下降方向,使用新梯度的梯度下降方向,原来的模型参数上进行一步梯度下降,以更新原来的模型参数。

可以理解的是,元学习的优化方法目标是学习模型训练的过程,本发明实施例利用元学习的优化方法进行模型训练,不同于传统随机梯度下降的模型训练方法,元学习优化方法的具体内容为训练模型时,将输入数据分为训练数据和验证数据,首先使用训练数据在当前模型参数上进行少量步梯度下降,得到新的参数;再根据验证数据,求出新的梯度下降方向;最后使用新的梯度下降方向,在更新前的模型参数上进行一步梯度下降,完成一次模型更新。

在步骤s103中,在每次训练任务结束后,更新记忆储存单元,储存部分训练数据以用于后续学习,且在模型训练结束后,使得最终深度学习模型根据任意输入数据执行不同的任务。

可以理解的是,每次训练任务结束后,都更新记忆储存单元,储存一部分训练数据用于后续学习;模型训练结束后,能够根据任意输入数据执行不同的任务。

进一步地,在本发明的一个实施例中,在每次训练任务结束后,更新记忆储存单元,储存部分训练数据以用于后续学习,包括:判断记忆储存单元的当前容量是否小于存储第t个新任务的数据所需的存储容量;如果小于,则将记忆储存单元的容量平均分给t个任务,随机删除每个任务中多余的数据,直到储存第t个新任务的数据;否则,随机选择第t个新任务的任意部分数据,储存在记忆储存单元。

可以理解的是,在第t个任务训练过程完成后,根据储存单元的容量,随机选择第t个任务的任意部分数据,储存在储存单元,用于之后的训练。具体更新记忆储存单元的方法是当记忆储存单元还有容量时,随机选择第t个任务的任意部分数据,储存在储存单元;如果记忆储存单元没有容量,则将记忆储存单元的容量平均分给前t个任务,随机删除每个任务中多余的数据,保证记忆储存单元足够储存第t个任务的新数据。

下面将通过具体示例对基于元学习优化方法的连续学习方法进行阐述。

在用户与机器的多任务交互场景中,用户对于任务的选择时随机的、多样的,例如,用户会输入来自多种场景、多样化内容的图片,图片类型的不同则表示数据对应不同的任务,机器要能够对多种任务进行响应。除了数据的多样化,任务目标也是多样化的,例如,用户输入一张图片,用户希望获得的输出是多样化的,机器可以输出图片的类别,可以输出对于图片的文字描述,也可以输出图片经过变换后的新图片,这就是任务目标的多样性。基于元学习优化方法的连续学习方法能够在学习旧任务的知识之后,更好地学习新任务,并且保证不忘记旧任务的知识,即模型的能力是不断增强的。在所有任务训练完成后,机器(模型)能够根据用户的输入自动判断用户选择的是什么任务,为用户提供相应的反馈。

综上,本发明实施例提出建立元学习方法和记忆储存单元相结合的框架,应用在连续学习的场景下,符合人类对于真实世界的认知过程;提出利用元学习优化方法,结合少量的记忆储存单元,解决连续学习中的灾难性遗忘问题;提出利用元学习优化方法,提高模型从旧任务到新任务的迁移能力。

根据本发明实施例提出的基于元学习优化方法的连续学习方法,利用元学习的优化方法解决连续学习场景下的机器学习问题,且元学习优化方法能够在样本量少的情况下达到较好的学习效果,结合少量记忆储存单元,可以有效解决连续学习中的遗忘问题;此外,元学习优化方法能够充分学习不同任务之间的关系,提高多任务之间的迁移效率,从而提高模型的学习能力。

其次参照附图描述根据本发明实施例提出的基于元学习优化方法的连续学习装置。

图3是本发明一个实施例的基于元学习优化方法的连续学习装置的结构示意图。

如图3所示,该基于元学习优化方法的连续学习装置10包括:建立模块100、训练模块200和更新模块300。

其中,建立模块100用于建立初始深度学习模型;训练模块200用于对于第t个新任务,利用元学习的优化方法训练初始深度学习模型;更新模块300用于在每次训练任务结束后,更新记忆储存单元,储存部分训练数据以用于后续学习,且在模型训练结束后,使得最终深度学习模型根据任意输入数据执行不同的任务。本发明实施例的装置10利用元学习的优化方法解决连续学习场景下的机器学习问题,且元学习优化方法能够充分学习不同任务之间的关系,提高多任务之间的迁移效率,从而提高模型的学习能力。

进一步地,在本发明的一个实施例中,训练模块200进一步用于利用输入数据得到当前模型的梯度,在当前梯度方向下降预设步;获取梯度下降后的模型参数的新梯度,并根据新梯度更新原来的模型参数。

进一步地,在本发明的一个实施例中,训练模块200具体用于将输入数据分为训练数据和验证数据;使用训练数据在当前模型参数上进行预设步梯度下降,得到梯度下降后的模型参数的新梯度;根据验证数据得到新梯度的梯度下降方向,使用新梯度的梯度下降方向,原来的模型参数上进行一步梯度下降,以更新原来的模型参数。

进一步地,在本发明的一个实施例中,初始深度学习模型使用卷积神经网络模型,其中,卷积神经网络模型有m层,每层有预设数目的卷积核,以及归一化操作、非线性激活函数,卷积神经网络模型的输入数据是图片,输出是图片类别。

进一步地,在本发明的一个实施例中,更新模块300进一步用于判断记忆储存单元的当前容量是否小于存储第t个新任务的数据所需的存储容量;如果小于,则将记忆储存单元的容量平均分给t个任务,随机删除每个任务中多余的数据,直到储存第t个新任务的数据;否则,随机选择第t个新任务的任意部分数据,储存在记忆储存单元。

需要说明的是,前述对基于元学习优化方法的连续学习方法实施例的解释说明也适用于该实施例的基于元学习优化方法的连续学习装置,此处不再赘述。

根据本发明实施例提出的基于元学习优化方法的连续学习装置,利用元学习的优化方法解决连续学习场景下的机器学习问题,且元学习优化方法能够在样本量少的情况下达到较好的学习效果,结合少量记忆储存单元,可以有效解决连续学习中的遗忘问题;此外,元学习优化方法能够充分学习不同任务之间的关系,提高多任务之间的迁移效率,从而提高模型的学习能力。

在本说明书的描述中,参考术语“一个实施例”、“一些实施例”、“示例”、“具体示例”、或“一些示例”等的描述意指结合该实施例或示例描述的具体特征、结构、材料或者特点包含于本发明的至少一个实施例或示例中。在本说明书中,对上述术语的示意性表述不必须针对的是相同的实施例或示例。而且,描述的具体特征、结构、材料或者特点可以在任一个或n个实施例或示例中以合适的方式结合。此外,在不相互矛盾的情况下,本领域的技术人员可以将本说明书中描述的不同实施例或示例以及不同实施例或示例的特征进行结合和组合。

此外,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括至少一个该特征。在本发明的描述中,“n个”的含义是至少两个,例如两个,三个等,除非另有明确具体的限定。

流程图中或在此以其他方式描述的任何过程或方法描述可以被理解为,表示包括一个或更n个用于实现定制逻辑功能或过程的步骤的可执行指令的代码的模块、片段或部分,并且本发明的优选实施方式的范围包括另外的实现,其中可以不按所示出或讨论的顺序,包括根据所涉及的功能按基本同时的方式或按相反的顺序,来执行功能,这应被本发明的实施例所属技术领域的技术人员所理解。

在流程图中表示或在此以其他方式描述的逻辑和/或步骤,例如,可以被认为是用于实现逻辑功能的可执行指令的定序列表,可以具体实现在任何计算机可读介质中,以供指令执行系统、装置或设备(如基于计算机的系统、包括处理器的系统或其他可以从指令执行系统、装置或设备取指令并执行指令的系统)使用,或结合这些指令执行系统、装置或设备而使用。就本说明书而言,"计算机可读介质"可以是任何可以包含、存储、通信、传播或传输程序以供指令执行系统、装置或设备或结合这些指令执行系统、装置或设备而使用的装置。计算机可读介质的更具体的示例(非穷尽性列表)包括以下:具有一个或n个布线的电连接部(电子装置),便携式计算机盘盒(磁装置),随机存取存储器(ram),只读存储器(rom),可擦除可编辑只读存储器(eprom或闪速存储器),光纤装置,以及便携式光盘只读存储器(cdrom)。另外,计算机可读介质甚至可以是可在其上打印所述程序的纸或其他合适的介质,因为可以例如通过对纸或其他介质进行光学扫描,接着进行编辑、解译或必要时以其他合适方式进行处理来以电子方式获得所述程序,然后将其存储在计算机存储器中。

应当理解,本发明的各部分可以用硬件、软件、固件或它们的组合来实现。在上述实施方式中,n个步骤或方法可以用存储在存储器中且由合适的指令执行系统执行的软件或固件来实现。如,如果用硬件来实现和在另一实施方式中一样,可用本领域公知的下列技术中的任一项或他们的组合来实现:具有用于对数据信号实现逻辑功能的逻辑门电路的离散逻辑电路,具有合适的组合逻辑门电路的专用集成电路,可编程门阵列(pga),现场可编程门阵列(fpga)等。

本技术领域的普通技术人员可以理解实现上述实施例方法携带的全部或部分步骤是可以通过程序来指令相关的硬件完成,所述的程序可以存储于一种计算机可读存储介质中,该程序在执行时,包括方法实施例的步骤之一或其组合。

此外,在本发明各个实施例中的各功能单元可以集成在一个处理模块中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个模块中。上述集成的模块既可以采用硬件的形式实现,也可以采用软件功能模块的形式实现。所述集成的模块如果以软件功能模块的形式实现并作为独立的产品销售或使用时,也可以存储在一个计算机可读取存储介质中。

上述提到的存储介质可以是只读存储器,磁盘或光盘等。尽管上面已经示出和描述了本发明的实施例,可以理解的是,上述实施例是示例性的,不能理解为对本发明的限制,本领域的普通技术人员在本发明的范围内可以对上述实施例进行变化、修改、替换和变型。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1