本公开涉及强化学习领域,具体涉及一种基于视觉表征的单智能体强化学习模型的训练方法、装置、电子设备、存储介质和计算机程序产品。
背景技术:
1、随着计算能力的显著提升和算法创新的不断涌现,强化学习(reinforcementlearning,rl)在众多领域取得了显著进展,包括游戏、机器人、自动驾驶等。尽管已经取得辉煌的成就,rl在实际应用中仍面临一个核心挑战,即如何提高样本效率。
2、智能体需要具有从有限的交互中快速学习到有效策略的能力,尤其是在处理高维视觉图像并执行复杂连续控制任务时,这一挑战显得尤为突出。视觉信号的高维度、冗余性以及多样性给基于视觉的强化学习的应用带来了一系列挑战,例如视觉强化学习普遍面临样本效率低、控制性能差的问题。
技术实现思路
1、本公开示例性实施例提供的基于视觉表征的单智能体强化学习模型的训练方法、装置、电子设备、存储介质和计算机程序产品,可以至少解决上述技术问题和上文未提及的其它技术问题。
2、根据本公开的一个方面,提供一种基于视觉表征的单智能体强化学习模型的训练方法,所述基于视觉表征的单智能体强化学习模型包括在线状态编码器、动作编码器、强化学习网络和辅助任务网络,所述辅助任务网络包括状态预测模型,所述基于视觉表征的单智能体强化学习模型的训练方法包括:获取目标智能体当前时间段的状态信息、动作信息和奖赏信息,其中,所述当前时间段由包含当前时刻在内的预设多个连续的时刻组成,所述状态信息和所述动作信息是基于针对所述目标智能体的观测图像而得到的;将所述状态信息输入到所述在线状态编码器,得到状态特征;将所述动作信息输入到所述动作编码器,得到动作特征;将所述状态特征、所述动作特征和所述奖赏信息输入到所述状态预测模型,得到所述目标智能体下一时间段的状态预测特征,其中,所述下一时间段由包含下一时刻在内的预设多个连续的时刻组成;基于所述状态预测特征和对应真实值之间的差异,计算状态预测损失;将所述状态特征和所述动作特征输入到所述强化学习网络,以计算强化学习损失;基于所述强化学习损失和所述状态预测损失,对所述基于视觉表征的单智能体强化学习模型进行训练。
3、可选的,所述将所述状态特征、所述动作特征和所述奖赏信息输入到所述状态预测模型,得到所述目标智能体下一时间段的状态预测特征,包括:将所述状态特征、所述动作特征和所述奖赏信息输入到所述状态预测模型,并对所述状态预测模型的输出求平均值,以得到所述目标智能体下一时间段的初始状态预测特征;将所述初始状态预测特征输入到在线投影网络,得到所述状态预测特征。
4、可选的,所述在线投影网络包括在线投影头和在线预测头;其中,所述将所述初始状态预测特征输入到在线投影网络,得到所述状态预测特征,包括:将所述初始状态预测特征输入到所述在线投影头,得到第一投影数据;将所述第一投影数据输入到所述在线预测头,得到所述状态预测特征;其中,所述基于所述状态预测特征和对应真实值之间的差异,计算状态预测损失,包括:获取所述目标智能体下一时间段的状态信息;将所述下一时间段的状态信息输入到目标状态编码器,得到第二状态预测特征,其中,所述目标状态编码器的参数基于所述在线状态编码器当前的参数以及预设衰减率通过指数移动平均而得到;将所述第二状态预测特征输入到目标投影网络,得到所述状态预测特征的对应真实值,其中,所述目标投影网络包括目标投影头,所述目标投影头的参数基于所述在线投影头当前的参数通过指数移动平均而得到;基于所述状态预测特征和对应真实值之间的差异,计算所述状态预测损失。
5、可选的,所述辅助任务网络还包括动作预测模型;其中,所述基于视觉表征的单智能体强化学习模型的训练方法还包括:获取所述目标智能体当前时刻的状态信息和下一时刻的状态信息;将所述当前时刻的状态信息和所述下一时刻的状态信息分别输入到所述在线状态编码器,得到当前时刻的状态特征和下一时刻的状态特征;将所述当前时刻的状态特征和所述下一时刻的状态特征输入到所述动作预测模型,得到当前时刻的动作预测特征;基于所述动作预测特征和对应真实值之间的差异,计算动作预测损失;其中,所述基于所述强化学习损失和所述状态预测损失,对所述基于视觉表征的单智能体强化学习模型进行训练,包括:基于所述强化学习损失、所述状态预测损失和所述动作预测损失,对所述基于视觉表征的单智能体强化学习模型进行训练。
6、可选的,所述辅助任务网络还包括奖赏预测模型;其中,所述基于视觉表征的单智能体强化学习模型的训练方法还包括:获取所述目标智能体当前时刻的状态信息和当前时间段的动作信息;将所述当前时间段的动作信息输入到动作编码器,得到当前时间段的动作特征;将所述当前时刻的状态特征和所述当前时间段的动作特征输入到奖赏预测模型,得到所述当前时间段中最后一个时刻的奖赏预测特征;基于所述奖赏预测特征和对应真实值之间的差异,计算奖赏预测损失;其中,所述基于所述强化学习损失、所述状态预测损失和所述动作预测损失,对所述基于视觉表征的单智能体强化学习模型进行训练,包括:基于所述强化学习损失、所述状态预测损失、所述动作预测损失和所述奖赏预测损失,对所述基于视觉表征的单智能体强化学习模型进行训练。
7、可选的,所述将所述状态特征、所述动作特征和所述奖赏信息输入到所述状态预测模型,包括:将所述状态特征、所述动作特征、所述奖赏信息和空间位置信息输入到所述状态预测模型,其中,同一时刻的所述状态特征、所述动作特征和所述奖赏信息共享相同的所述空间位置信息。
8、可选的,所述状态预测模型为预设多个相同块构成的transformer模型,每个块包含一个多头自注意力层和一个由多层感知机构成的前馈网络层,每个层之前包括层归一化组件,每个层之后包括残差连接。
9、可选的,所述方法还包括:在所述获取目标智能体当前时间段的状态信息之后,对所述状态信息中预设比例的数据执行随机掩码操作,以得到更新后的状态信息。
10、根据本公开的另一方面,还提供一种基于视觉表征的单智能体强化学习模型的训练装置,所述基于视觉表征的单智能体强化学习模型包括在线状态编码器、动作编码器、强化学习网络和辅助任务网络,所述辅助任务网络包括状态预测模型,所述基于视觉表征的单智能体强化学习模型的训练装置包括:
11、信息获取单元,被配置为:获取目标智能体当前时间段的状态信息、动作信息和奖赏信息,其中,所述当前时间段由包含当前时刻在内的预设多个连续的时刻组成,所述状态信息和所述动作信息是基于针对所述目标智能体的观测图像而得到的;状态表征单元,被配置为:将所述状态信息输入到所述在线状态编码器,得到状态特征;动作表征单元,被配置为:将所述动作信息输入到所述动作编码器,得到动作特征;状态预测单元,被配置为:将所述状态特征、所述动作特征和所述奖赏信息输入到所述状态预测模型,得到所述目标智能体下一时间段的状态预测特征,其中,所述下一时间段由包含下一时刻在内的预设多个连续的时刻组成;状态损失计算单元,被配置为:基于所述状态预测特征和对应真实值之间的差异,计算状态预测损失;强化学习损失计算单元,被配置为:将所述状态特征和所述动作特征输入到所述强化学习网络,以计算强化学习损失;模型训练单元,被配置为:基于所述强化学习损失和所述状态预测损失,对所述基于视觉表征的单智能体强化学习模型进行训练。
12、根据本公开实施例的另一方面,还提供一种电子设备,包括:至少一个处理器;至少一个存储计算机可执行指令的存储器,其中,所述计算机可执行指令在被所述至少一个处理器运行时,促使所述至少一个处理器执行如上任一所述的基于视觉表征的单智能体强化学习模型的训练方法。
13、根据本公开实施例的另一方面,还提供一种存储指令的计算机可读存储介质,当所述指令被至少一个处理器运行时,促使所述至少一个处理器执行如上任一所述的基于视觉表征的单智能体强化学习模型的训练方法。
14、根据本公开实施例的另一方面,还提供一种包括至少一个计算装置和至少一个存储指令的存储装置的系统,其中,所述指令在被所述至少一个计算装置运行时,促使所述至少一个计算装置执行如上任一所述的基于视觉表征的单智能体强化学习模型的训练方法。
15、根据本公开实施例的另一方面,还提供一种计算机程序产品,包括计算机程序/指令,所述计算机程序/指令被处理器执行时实现如上任意一项所述的基于视觉表征的单智能体强化学习模型的训练方法。
16、本公开实施例提供的技术方案至少带来以下有益效果:
17、根据本公开的基于视觉表征的单智能体强化学习模型的训练方法、装置、电子设备、存储介质和计算机程序产品,针对目标智能体,能够通过辅助任务网络从视觉表征的角度出发,学习目标智能体的状态表征和动作表征,通过强化学习网络为目标智能体选择最佳决策动作,并且,充分利用强化学习中时间段的时序信息,可以实现单智能体在具有挑战的以图像作为状态输入的复杂连续控制任务中的性能和样本效率提升。
18、另外,采用一种非对称的投影网络架构,可以避免模型在自监督学习过程中的崩溃问题。
19、另外,引入动作预测学习作为额外的学习约束,可以增强状态表征在未来动作预测中的贡献。
20、另外,额外添加奖赏预测学习用于约束状态和动作表征,可以促进智能体更好地理解其行为可能产生的后果。
21、另外,transformer架构使得可以同时处理目标智能体一段时间序列的信息,以充分利用强化学习中时间段的时序信息;每个层的前面均配备层归一化组件,可以提高模型训练过程中的稳定性和加快收敛速度;每个层的输出都加上残差连接,可以促进更深层次网络的有效训练,且有助于防止梯度消失或爆炸问题,保证信息在网络中的顺畅流动。