【知识蒸馏论文解读】Dataset Distillation 创新性十足的数据集蒸馏


写在前面

最近对数据集蒸馏比较感兴趣,抽时间看了下这篇经典的数据蒸馏论文《Dataset Distillation》,它是属于knowledge distillation领域的工作。

一、动机

本文提出的方法是数据集蒸馏(Dataset Distillation)

  1. 从大的训练数据中蒸馏知识到小的数据集
  2. 小的数据集不需要与原始的大的训练数据分布相同
  3. 只要在小的数据集上训练几步梯度下降就能达到和原始数据相近的模型效果

    This paper presented dataset distillation for compressing the knowledge of entire training data into a few synthetic training images. We can train a network to reach high performance with a small number of distilled images and several gradient descent steps.

模型蒸馏(model层面)的目标是从一个复杂的模型中蒸馏知识到小的模型上。

本文考虑的是数据集上的蒸馏(dataset层面),具体来说,我们会固定住模型,然后尝试从较大的训练数据集中蒸馏知识到小的数据集上。

核心目的是将原始的大数据集压缩成一个小的数据集(不需要来自训练集的分布),并且在这个小数据集上训练模型的效果和原始较大数据集上的训练效果是接近的。

例如,可以用60000张MNIST训练集蒸馏出10张distilled的图片(每个类一张),在这10张图片只需再训练几轮,就能到达和原始效果接近的效果。

实验效果图:

(a)表明只用distilled images也能将网络模型训练的很好(image classification)
(b)用SVHN数据集训练好的模型迁移到MNIST上效果只有54%,但是当我们用distilled images做下微调,这个模型在MNIST上的准确率就可以达到85%(image classification)
(c)表明用distilled images可以用于攻击已经训练好的网络模型 (poisoning attack)。

二、背景知识

2015 Hinton等人提出了network distillation(model compression),本文我们不蒸馏模型,我们蒸馏数据集。

In this paper, we are considering a related but orthogonal task: rather than distilling the model, we propose to distill the dataset. Unlike network distillation, we keep the model fixed but encapsulate the knowledge of the entire training dataset, which typically contains thousands to millions of images, into a small number of synthetic training images

通常来说如果你小数据的分布和真正测试集的分布不同,是很难训练出一个好的模型的,但是本文的工作表明,这完全是可能的。

本文提出了一种新的优化方法,尽管现在只有很小的数据集,但他不仅能够抓住原始大数据集的信息,而且只要几步梯度下降就能训练好模型,并且在真正的测试集上效果良好。

相关工作:简单列一下相关的方向

  • Knowledge distillation
  • Dataset pruning,core-set construction,and instance selection
  • Gradient-based hyperparameter optimization
  • Understanding datasets

三、方法

3.1 训练$\tilde{\mathbf{x}}$和$\tilde \eta$的过程

传统的模型训练会使用随机梯度下降进行参数优化,假设现在进行第t次参数更新,使用的minibatch的训练集为$\mathbf{x}{t}=\left{x{t, j}\right}_{j=1}^{n}$

$$\theta_{t+1}=\theta_{t}-\eta \nabla_{\theta_{t}} \ell\left(\mathbf{x}{t}, \theta{t}\right)$$

通常来说这种训练方式,需要更新上万次参数才能收敛。

本文的目标是学习到一小部分的合成的distilled的训练集$\tilde{\mathbf{x}}=\left{\tilde{x}{i}\right}{i=1}^{M}$,其中M远小于总的训练集数量N,以及学习对应的学习率$\tilde \eta$,使得只要一次参数更新就能得到一个在真实测试集上效果很好的模型参数。(为什么要学习$\tilde \eta$?因为作者想要通过少量的梯度下降便得到一个比较好的模型,因此学习率既不能太大,也不能太小,因而需要学习获得)

$$\theta_{1}=\theta_{0}-\color{red}\tilde{\eta} \color{black} \nabla_{\theta_{0}} \ell\left(\color{red} \tilde{\mathbf{x}}\color{black} , \theta_{0}\right)$$

那么如何学习$\tilde{\mathbf{x}}$和$\tilde \eta$呢?

很简单,我们希望通过$\tilde{\mathbf{x}}$和$\tilde \eta$得到的$\theta_1$,能够使$\ell\left(\mathbf{x}, \theta_{1}\right)$最小,其中$x$是原始的大训练集。

因此对应的优化目标如下:

$$
\begin{aligned}
\tilde{\mathbf{x}}^{}, \tilde{\eta}^{} &= \underset{\tilde{\mathbf{x}}, \tilde{\eta}}{\arg \min } \mathcal{L}\left(\tilde{\mathbf{x}}, \tilde{\eta} ; \theta_{0}\right)\ &=\underset{\tilde{\mathbf{x}}, \tilde{\eta}}{\arg \min } \ell\left(\mathbf{x}, \theta_{1}\right) \
&=\underset{\tilde{\mathbf{x}}, \tilde{\eta}}{\operatorname{\arg \min}} \ell\left(\mathbf{x}, \theta_{0}-\color{red}\tilde{\eta} \color{black} \nabla_{\theta_{0}} \ell\left(\color{red} \tilde{\mathbf{x}}\color{black}, \theta_{0}\right)\right)
\end{aligned}
$$

说明1:在这种优化方式下,我们学习的参数只有$\tilde{\mathbf{x}}$和$\tilde \eta$,通过随机梯度下降求解$\tilde{\mathbf{x}}$和$\tilde \eta$的过程和普通的优化没有什么区别,也是要进行上万次更新。

说明2:对于数据$\tilde{x}$的其它部分,例如标签,只把它固定而不进行学习。


3.2 在合成数据集$\tilde{\mathbf{x}}$上训练模型(fix init)

当我们训练好得到合成数据集$\tilde{\mathbf{x}}$和对应的学习率$\tilde \eta$后,我们就可以在这个合成数据集$\tilde{\mathbf{x}}$上训练模型了。

那么这个模型的初始化参数应该是什么呢?

作者发现,这时初始化的参数需要和得到合成数据$\tilde x$的初始化参数$\theta_0$一样(fixed initialization),也就是说通过上述方法得到的合成数据并不能很好地泛化到其他初始参数,这也很好理解,因为优化合成数据的时候不仅用了原始数据,还用了固定的初始化参数$\theta_0$。

3.3 训练$\tilde{\mathbf{x}}$和$\tilde \eta$时用随机初始化参数$\theta_0$(random init)

为了解决上面提到的问题,作者提出,训练$\tilde{\mathbf{x}}$和$\tilde \eta$的时候不用固定的$\theta_0$,而是每次从分布$p(\theta_0)$中随机采样。

此时的优化目标如下:
$$
\tilde{\mathbf{x}}^{}, \tilde{\eta}^{}=\underset{\tilde{\mathbf{x}}, \tilde{\eta}}{\arg \min } \mathbb{E}{\theta{0} \sim p\left(\theta_{0}\right)} \mathcal{L}\left(\tilde{\mathbf{x}}, \tilde{\eta} ; \theta_{0}\right)
$$

算法伪代码:

最重要的是理解第6,7行是先在distilled images $\tilde x$上得到参数$\theta_1$(学习率为$\tilde \eta$),然后用这个网络参数$\theta_1$在真实的数据$x$上去算loss。第9行使用标准的梯度下降对$\tilde x$和$\tilde \eta$进行参数更新

实验表明,在随机初始化参数$\theta_0$的条件下得到的合成数据集$\tilde x$后,我们在合成数据集$\tilde x$训练模型时可以随机初始化参数,而且效果也不错(不过还是没有固定$\theta_0$的效果好),另外在random initialization条件下得到distilled images通常包含有一定的信息,因为合成数据编码了每个类别的判别特征,如实验部分Figure3所示。

3.4 多步参数更新

前面介绍的从$\theta_0$到$\theta_1$我们只进行了signle GD step,实际上这一部分可以改成多步。只需要将Algorithm1中第6行改成多步即可。
$$
\theta_{i+1}=\theta_{i}-\tilde{\eta}{i} \nabla{\theta_{i}} \ell\left(\tilde{\mathbf{x}}{i}, \theta{i}\right)
$$

每一步使用不同的distilled data $\tilde x_i$和$\tilde \eta_i$

文章中还用了优化算法来加快梯度回传的过程。

3.5 不同的初始化参数$\theta_0$的方式

除了固定初始化模型参数和随机初始化模型参数外,文章还提出了使用其他任务中预训练好的模型参数来模型的参数,所以共有以下4种方式构建初始化参数,其中最后一种方式的效果是最好的。

四、实验结果

数据集:MNIST、CIFAR10

Fixed initialization得到的distilled images:


固定初始化模型参数条件下得到的最终模型在真实测试集上效果更好,但合成的图片比较模糊(论文在3.2给出了解释:主要是因为它除了编码了原始训练集$x$,还编码了一个固定的参数$\theta_0$,这个参数$\theta_0$就像给图片加了random noise,所以看起来模糊)。

Random initialization得到的distilled images:


随机初始化模型参数条件下得到的最终模型在真实测试集上效果还不错,有更好的泛化能力(对没见过的初始化参数也有不错的效果),合成的图片也包含更多的特征,如数字比较清楚。

五、本文总结

In this paper, we have presented dataset distillation for compressing the knowledge of entire training data into a few synthetic training images. We can train a network to reach high performance with a small number of distilled images and several gradient descent steps

未来工作:

  1. 将数据集蒸馏应用到大规模的图片数据集(ImageNet)以及其他类型的数据上(如语音、文本)
  2. 我们的方法对初始化的分布比较敏感,我们会研究其他的初始化策略。

相关资料

写在最后

这篇数据集蒸馏的文章我觉得非常有意思,刷新了我之前的对深度学习的认知,原论文值得好好阅读。


Author: SHWEI
Reprint policy: All articles in this blog are used except for special statements CC BY 4.0 reprint polocy. If reproduced, please indicate source SHWEI !
评论
 Previous
重温机器学习概念:偏差(Bias)、方差(Variance)、欠拟合(Underfitting)、过拟合(Overfitting) 重温机器学习概念:偏差(Bias)、方差(Variance)、欠拟合(Underfitting)、过拟合(Overfitting)
写在前面 最近放寒假了,除了看论文,我还打算抽空复习一些机器学习的基础知识。今天主要复习了机器学习中偏差、方差、欠拟合、过拟合这几个概念,能不能讲清楚偏差方差,经常被用来考察面试者的理论基础,我之前对有些地方是一知半解的,比如那个射靶图是什
2022-01-02
Next 
【GAN论文解读系列】NeurIPS 2016 InfoGAN 使用InfoGAN解耦出可解释的特征 【GAN论文解读系列】NeurIPS 2016 InfoGAN 使用InfoGAN解耦出可解释的特征
论文题目:InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets 论文地址:https:/
2021-12-17
  TOC