本发明属于深度学习的数据分析工作流领域,尤其涉及一种基于ray的云边协同数据分析工作流优化方法及系统。
背景技术:
1、本部分的陈述仅仅是提供了与本发明相关的背景技术信息,不必然构成在先技术。
2、随着数字经济的快速发展,边缘业务规模化落地,行业数据出现爆炸性增长,云边协同的数据计算模式成为智能化转型的突破口,在教育、医疗、交通、海洋等领域广泛应用。目前,模型发展趋于集成模型、大模型,需要依赖于云端分布式训练,同时训练后的大模型通常难以直接应用于下游任务。
3、首先在模型研发阶段,将终端设备设计和测试好的模型置于云端分布式集群中进行训练时,面临的主要挑战是需要重写代码,以便支持分布式存储或结构。通常通过本地编辑器实现代码编辑,训练阶段远程连接云端服务器,实现分布式模型训练。但算法人员需要了解分布式程序的编写方式,无疑增加了算法工程师的工作负担;其次,在服务部署阶段,云端训练后的模型通常模型规模较大,难以直接应用部署于资源有限的边缘端及下游任务。模型压缩在促进机器学习应用程序的高效、部署方面的优势日益突显。其中知识蒸馏(knowledge distillation)作为模型压缩方法的一种在计算机视觉和自然语言处理任务中应用广泛,其工作原理通常是利用参数规模较大,精度更高的大模型作为成“教师”网络,以结构更加轻量,模型参数量较小的模型作为学生模型,实现知识传递。
4、目前在知识蒸馏领域通常存在以下痛点:1)教师模型难以动态掌握学生模型的特征提取能力。通常传统知识蒸馏方法中,学生模型只能被动地从教师模型接受知识,而不考虑学生模型的学习能力和表现。2)教师模型在训练过程中通常参数冻结,无法更新。通常在训练时,教师模型的参数处于冻结状态,难以根据学生模型的状态,做出调整及更新。
技术实现思路
1、为了解决上述背景技术中存在的模型研发阶段从终端设备设计到云端分布式训练的操作复杂问题以及模型服务部署阶段模型压缩问题,本发明提供一种基于ray的云边协同数据分析工作流优化方法及系统,其首先在模型研发阶段,采用分布式应用程序框架ray实现终端设备到云端的无缝扩展,实现模型分布式训练的快速迭代;其次在服务部署阶段,采用知识蒸馏方法实现模型压缩,降低模型参数体积,更便于服务下游任务。
2、为了实现上述目的,本发明采用如下技术方案:
3、本发明的第一个方面提供一种基于ray的云边协同数据分析工作流优化方法。
4、基于ray的云边协同数据分析工作流优化方法,包括:
5、获取数据集,基于ray,采用数据集在终端设备上训练第一学生模型和第一教师模型,得到蒸馏训练方法;
6、将数据集和蒸馏训练方法上传至云端,设置并行工作器的数量和超参数,采用数据集和蒸馏训练方法在云端上训练第二学生模型和第二教师模型,得到训练好的第二学生模型,并将训练好的第二学生模型部署到终端;
7、所述蒸馏训练方法包括:在训练的过程中,将学生模型复制为次学生模型,引入解耦损失函数,更新次学生模型;在测试过程中,以次学生模型的损失结果为反馈信号,对教师模型进行元更新;将经过元更新的教师模型在相同训练批次对学生模型进行参数传递,经过迭代知识传递后,得到蒸馏后的学生模型。
8、进一步地,所述解耦损失函数为:
9、decoupled loss=αtckd+pnckd
10、其中,tckd为目标类解耦损失函数,nckd为非目标类解耦损失函数,所述α、β均为超参数。
11、进一步地,通过计算二阶导数和执行梯度更新操作对所述教师模型进行元更新。
12、进一步地,所述基于ray,采用数据集在终端设备上训练第一学生模型和第一教师模型的过程包括:
13、将数据集分为训练集和验证集,并加载到ray数据集接口中;
14、根据第一学生模型和第一教师模型,在终端设备配置ray形式的模型训练方法,使用ray train库中的torchtrainer组件,通过trainer来管理训练过程,在torchtrainer中定义第一学生模型和第一教师模型的训练函数以及设置参数,将并行工作器的数量设置为1后,采用训练集在终端设备上对第一学生模型和第一教师模型进行训练操作;
15、采用验证集对第一学生模型进行验证。
16、进一步地,所述将数据集和蒸馏训练方法上传至云端,设置并行工作器的数量和超参数,采用数据集和蒸馏训练方法在云端上训练第二学生模型和第二教师模型的过程包括:
17、将述将数据集和蒸馏训练方法上传至云端,将数据集分为训练集和验证集,并加载到ray数据集接口中;
18、根据第二学生模型和第二教师模型,在云端配置ray形式的模型训练方法,使用ray train库中的torchtrainer组件,通过trainer来管理训练过程,在torchtrainer中定义第二学生模型和第二教师模型的训练函数以及设置参数,将torchtrainer中并行工作器的数量进行修改,采用训练集在终端设备上对第二学生模型和第二教师模型进行训练操作;
19、采用验证集对第二学生模型进行验证。
20、本发明的第二个方面提供一种基于ray的云边协同数据分析工作流优化系统。
21、基于ray的云边协同数据分析工作流优化系统,包括:终端设备、云端和终端,所述云端均与终端设备和终端连接;
22、所述终端设备,用于基于ray,采用数据集训练第一学生模型和第一教师模型,得到蒸馏训练方法,并将数据集和蒸馏训练方法上传至云端;
23、所述云端,用于设置并行工作器的数量和超参数,采用数据集和蒸馏训练方法训练第二学生模型和第二教师模型,得到训练好的第二学生模型,并将训练好的第二学生模型部署到终端;
24、其中,所述蒸馏训练方法包括:在训练的过程中,将学生模型复制为次学生模型,引入解耦损失函数,更新次学生模型;在测试过程中,以次学生模型的损失结果为反馈信号,对教师模型进行元更新;将经过元更新的教师模型在相同训练批次对学生模型进行参数传递,经过迭代知识传递后,得到蒸馏后的学生模型。
25、进一步地,所述解耦损失函数为:
26、decoupled loss=αtckd+pnckd
27、其中,tckd为目标类解耦损失函数,nckd为非目标类解耦损失函数,所述α、β均为超参数。
28、进一步地,通过计算二阶导数和执行梯度更新操作对所述教师模型进行元更新。
29、进一步地,所述基于ray,采用数据集在终端设备上训练第一学生模型和第一教师模型的过程包括:
30、将数据集分为训练集和验证集,并加载到ray数据集接口中;
31、根据第一学生模型和第一教师模型,在终端设备配置ray形式的模型训练方法,使用ray train库中的torchtrainer组件,通过trainer来管理训练过程,在torchtrainer中定义第一学生模型和第一教师模型的训练函数以及设置参数,将并行工作器的数量设置为1后,采用训练集在终端设备上对第一学生模型和第一教师模型进行训练操作;
32、采用验证集对第一学生模型进行验证。
33、进一步地,所述将数据集和蒸馏训练方法上传至云端,设置并行工作器的数量和超参数,采用数据集和蒸馏训练方法在云端上训练第二学生模型和第二教师模型的过程包括:
34、将述将数据集和蒸馏训练方法上传至云端,将数据集分为训练集和验证集,并加载到ray数据集接口中;
35、根据第二学生模型和第二教师模型,在云端配置ray形式的模型训练方法,使用ray train库中的torchtrainer组件,通过trainer来管理训练过程,在torchtrainer中定义第二学生模型和第二教师模型的训练函数以及设置参数,将torchtrainer中并行工作器的数量进行修改,采用训练集在终端设备上对第二学生模型和第二教师模型进行训练操作;
36、采用验证集对第二学生模型进行验证。
37、与现有技术相比,本发明的有益效果是:
38、为了解决目前将终端设备的模型训练代码扩展到云端,通常需要对代码进行较大调整,并需要掌握分布式训练方法,门槛高等问题,本发明利用ray集成分布式并行、内存共享等方法,实现从终端设备到云端的无缝扩展机制,优化适配器,实现终端设备的模型训练在几乎不需要修改代码的情况下扩展到云端,从而提高模型优化效率,降低技术门槛。
39、为了解决云端训练大模型无法直接部署到资源受限的边缘端的问题,本发明首先利用基于元学习的知识蒸馏,通过学生模型的反馈提高教师模型的“教学能力”,使得知识可以更好的从教师模型转移到学生模型。同时优化了蒸馏损失,将解耦损失的思想引入其中,将传统知识蒸馏损失函数解耦为目标类知识蒸馏(target class knowledgedistillation,tckd)和非目标类知识蒸馏(non-target class knowledge distillation,nckd)。其中tckd传递了关于训练样本“难度”的知识,用于描述识别每个训练样本的难易程度;nckd传递了非目标逻辑之间的知识,更突显非目标(暗目标)的知识。通常两者处于耦合状态是相互制约关系,经过实验发现nckd是蒸馏涨点的主要原因,但是它处于抑制状态,因此修改传统知识蒸馏损失函数引入两个超参数将两部分解耦从而实现知识蒸馏的真正涨点,提高了学习模型的精度以及准确率,以使蒸馏后的学生模型可以部署到边缘端。