一种基于跨图像像素空间关系的扩散模型蒸馏方法

文档序号:39278085发布日期:2024-09-06 00:52阅读:12来源:国知局
一种基于跨图像像素空间关系的扩散模型蒸馏方法

本发明涉及计算机视觉领域,具体来说涉及用于生成图像的扩散模型领域,更具体地说,涉及一种基于跨图像像素空间关系的扩散模型蒸馏方法。


背景技术:

1、图像生成是一个视觉上的基础任务,其目的是训练一个模型,使其仅输入一个随机采样的噪声图像便可以生成出有意义的现实图像。图像生成在现实创作、图像编辑以及图像风格化领域具有巨大的应用价值。目前最先进的图像生成模型为扩散模型,其能够通过迭代去噪过程最终生成极高质量的图像。但其迭代采样的去噪生成过程导致其推理速度较慢,限制了其在现实世界场景和资源限制的边缘设备上的部署应用。而知识蒸馏作为一项模型压缩技术,可以利用教师教授学生的模式来提升学生网络的性能,教师模型通常是复杂度高但是性能优良的网络,学生模型通常是复杂度低但是性能不足的网络。示意性的,知识蒸馏的大致过程为:

2、1)教师模型的建立:需要一个经过充分训练的扩散模型作为教师模型。这个模型在生成图像时表现优异,但可能需要较多的迭代步数(假设步数为a)来逐步去噪,生成高质量的图像;

3、2)学生模型的建立:设计一个初始参数与教师模型的参数相同的学生模型,目标是让学生模型学习并模仿教师模型的生成过程,但以更少的迭代步数(假设部署为b)完成,b<a;

4、3)知识传递:在知识蒸馏过程中,通过设计合适的损失函数,学生模型被训练以模仿教师模型生成图像;

5、4)将经步骤1)-3)训练的学生模型作为新的教师模型,返回步骤1),直至学生模型的迭代步数降低至预设步数。

6、尽管知识蒸馏技术的应用为扩散模型的优化提供了新思路,但现有的知识蒸馏方法在在损失的设计上仍存在局限。

7、知识蒸馏算法的核心是定义一种有意义的知识形式,然后将该知识从教师模型传到学生模型,但现有的关于扩散模型知识蒸馏的算法往往关注如何设计更好的采样方式却忽略如何设计有意义的知识形式这一同样重要的问题。最早的扩散模型蒸馏算法在对教师和学生模型生成的图像进行对齐时,仅考虑从像素空间对生成的图像进行均方误差计算,并未意识到图像特征对齐相较于图像像素对齐的优越性。后来的蒸馏方法发现了特征对齐相较于像素对齐的优越性,有的方法直接在特征空间中计算均方误差损失,有的则直接利用图像特征获得特征分布以进行kl散度计算。这些基于特征的对齐方法在算法效果上大大超越了基于像素的对齐方法。

8、作为对比,发明人检索到的知识蒸馏技术的损失的设计包括以下3种:

9、设计1:采用两个模型生成的图像来构建损失,考虑最小化教师模型和学生模型生成的图像间的偏差,目标是对两者生成的图像进行对齐。示意性的设计示例可参见以下技术文献[1]、[2]。但该设计仅考虑在像素空间内计算均方误差损失,却忽略了图像特征这一更有价值的信息。

10、设计2:对两个模型生成的图像提取图像特征,考虑在特征空间内计算学生模型的图像特征与教师模型的图像特征之间偏差作为损失(如:均方误差损失)。示意性的设计示例可参见以下技术文献[3]。

11、设计3:对两个模型生成的图像提取图像特征,考虑利用两个图像特征的特征分布进行kl散度损失计算,从而在特征维度对图像信息进行对齐。示意性的设计示例可参见以下技术文献[4]。

12、相比于设计1,按设计2和设计3的技术蒸馏得到的学生模型的性能更优,但在蒸馏至较小迭代步数的时候,性能仍有待进一步提升。

13、需要说明的是:本背景技术仅用于介绍本发明的相关信息,以便于帮助理解本发明的技术方案,但并不意味着相关信息必然是现有技术。相关信息与本发明方案一同提交和公开,在没有证据表明相关信息已在本发明的申请日以前公开的情况下,相关信息不应被视为现有技术。

14、技术文献:

15、[1]曹巍瀚,张一帆.基于扩散模型的图像生成模型压缩和加速方法及系统[p].江苏省:cn202310823847.9,2023-09-01.

16、[2]salimans t,ho j.progressive distillation for fast sampling ofdiffusion models[j].arxiv preprint arxiv:2202.00512,2022.

17、[3]song y,dhariwal p,chen m,et al.consistency models[c]//international conference on machine learning.pmlr,2023:32211-32252.

18、[4]sun w,chen d,wang c,et al.accelerating diffusion sampling withclassifier-based feature distillation[c]//2023ieee international conferenceon multimedia and expo(icme).ieee,2023:810-815.

19、[5]krizhevsky a,hinton g.learning multiple layers of features fromtiny images[j].2009.

20、[6]song j,meng c,ermon s.denoising diffusion implicit models[j].arxivpreprint arxiv:2010.02502,2020.

21、[7]heusel m,ramsauer h,unterthiner t,et al.gans trained by a twotime-scale update rule converge to a local nash equilibrium[j].advances inneural information processing systems,2017,30.

22、[8]salimans t,goodfellow i,zaremba w,et al.improved techniques fortraining gans[j].advances in neural information processing systems,2016,29.

23、[9]huang g,liu z,van der maaten l,et al.densely connectedconvolutional networks[c]//proceedings of the ieee conference on computervision and pattern recognition.2017:4700-4708.


技术实现思路

1、因此,本发明的目的在于克服上述现有技术的缺陷,提供一种基于跨图像像素空间关系的扩散模型蒸馏方法。

2、本发明的目的是通过以下技术方案实现的:

3、根据本发明的第一方面,提供一种基于跨图像像素空间关系的扩散模型蒸馏方法,包括:获取教师模型和学生模型,两个模型均为生成式的扩散模型,教师模型是预先训练好的模型,学生模型所设置的迭代去噪的步数比教师模型更少;获取一个批次的样本数据,将每个样本数据分别输入两个模型以生成该样本数据对应的教师图像和学生图像;将当前批次生成的每张教师图像和学生图像分别输入预训练好的特征提取器,得到教师图像特征和学生图像特征;根据每两个教师图像特征计算两者像素间的第一空间关系矩阵,以及根据每两个学生图像特征计算两者像素间的第二空间关系矩阵;根据每对第一空间关系矩阵和第二空间关系矩阵间的kl散度求损失,根据损失更新学生模型的参数,其中,当一个第一空间关系矩阵和另一个第二空间关系矩阵所涉及的样本数据的编号一致时视两者为一对。

4、可选的,第一空间关系矩阵和第二空间关系矩阵中的任意空间关系矩阵是根据一个图像特征与另一图像特征的转置进行矩阵乘法得到的。

5、可选的,按照以下方式得到第一空间关系矩阵和第二空间关系矩阵:

6、

7、其中,表示涉及第i个样本数据和第j个样本数据的第一空间关系矩阵,表示根据第i个样本数据生成的教师图像所对应的教师图像特征,表示根据第j个样本数据生成的教师图像所对应的教师图像特征,表示涉及第i个样本数据和第j个样本数据的第二空间关系矩阵,表示根据第i个样本数据生成的学生图像所对应的学生图像特征,表示根据第j个样本数据生成的学生图像所对应的学生图像特征,t表示转置,表示矩阵的维度为a行a列。

8、可选的,kl散度按照以下方式确定:

9、

10、其中,表示涉及第i个样本数据和第j个样本数据的那对空间关系矩阵间的kl散度,a表示空间关系矩阵的总行数,a表示空间关系矩阵的第a行,表示第一空间关系矩阵的第a行向量,表示第二空间关系矩阵的第a行向量,σ(·)表示softmax函数,τ表示预设的温度系数,kl(,)表示kl散度函数。

11、可选的,损失按照以下方式确定:获得当前批次的各对第一空间关系矩阵和第二空间关系矩阵间的kl散度之和,并求均值,得到损失。

12、可选的,方法还包括:分阶段地降低学生模型的迭代去噪的步数,最终得到的学生模型所设置的迭代去噪的步数为1、2、3、4、5、6或者7。

13、可选的,更新学生模型的参数包括:根据损失和当前的学生模型的参数计算梯度;基于随机梯度下降算法,根据计算出的梯度来调整学生模型的可学习参数,以期达到损失最小化的目标。

14、模型的可学习参数,以期达到损失最小化的目标。

15、根据本发明的第二方面,提供一种图像生成方法,包括:获取用于生成图像的数据;获取按照第一方面所述方法更新参数所得到的经训练的学生模型;将用于生成图像的数据输入经训练的学生模型,以生成图像。

16、根据本发明的第三方面,提供一种电子设备,包括:一个或多个处理器;以及存储器,其中存储器用于存储可执行指令;所述一个或多个处理器被配置为经由执行所述可执行指令以实现第一和/或第二方面所述方法的步骤。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1