元学习模型MAML和Reptile详解


写在前面

记得研究生一年级的时候,每次开组会讲论文,实验室的师兄师姐经常提到元学习以及MAML这些概念。由于我当时比较懒,也觉得我研究方向不是这个,就没有细想,一知半解,只是知道有这个概念。后来我发现很多知识是相联系的,比如了解NLP的常用模型,对你理解其他领域的模型也是很有帮助的。所以深入了解元学习和MAML这些知识也是有助于我自己的科研工作的。于是我花了点时间好好学习了下元学习,并写下来这篇博客。

什么是元学习

对于一个分类任务,我们构建好深度学习模型后,学习的就是模型的参数,学习的目的就是使得最终的参数能够在训练集上达到最佳的精度,损失最小。

但是元学习(Meta-Learning)面向的是学习的过程,并不是学习的结果,也就是元学习不需要学出来最终的模型参数,学习的更像是学习技巧这种东西(这就是为什么叫做learning to learn)。他不是为了解决具体某项具体的任务,而是研究如何提升模型解决一系列任务的能力。

下面这个对比传统机器学习和元学习的例子,来自博客:Meta Learning 入门:MAML 和 Reptile,我觉得讲得很不错,值得阅读。

  • 如果把训练算法类比成学生在学校的学习,那么传统的机器学习任务对应的是不同科目,例如数学、语文、英语,每个科目上训练一个模型。而 Meta Learning 则是要提升一个学生整体的学习能力,让学生学会学习(就是所谓的 learn to learn)。就像所有的学生都上一样的课,做一样的作业,可偏偏有的学生各科成绩都好,有的学生偏科,而有的学生各科成绩都差。
  • 各科成绩都好的学生,说明他大脑 Meta Learning 的能力强,可以迅速适应不同科目的学习任务。
  • 而对于偏科的学生,他们大脑的 Meta Learning 能力就相对弱一些,只能学习某项具体的任务,换个任务就不 work 了。对这种学生,老师的建议一般是:“在弱科上多花一点时间”,可这么做是有风险的,最糟糕的一种情况是:弱势科目没学好,强势科目成绩反而下降了。可以看到,现如今大多数深度神经网络都是“偏科生”,且不说分类、回归这样差别较大的任务对应的网络模型完全不同,即使同样是分类任务,把人脸识别网络架构用在分类 ImageNet 数据上,就未必能达到很高的准确率。
  • 至于各科成绩都差的学生,说明他们不但 Meta Learning 能力弱,在任何科目上的学习能力都弱,需要被老师重点关照……

元学习的方法有很多,有些是针对不同的训练任务,输出不同的模型结构和超参数,例如AutoML,这些算法比较复杂,本文将介绍元学习中两种常用的模型:MAML和Reptile,它们不需要改变模型的结构,只改变模型的初始化参数

一、MAML

MAML:Model Agnostic Meta-Learning for Fast Adaptation of Deep Networks, ICML 2017, Paper, Code

摘要:The goal of meta-learning is to train a model on a variety of learning tasks, such that it can solve new learning tasks using only a small number of training samples. In our approach, the parameters of the model are explicitly trained such that a small number of gradient steps with a small amount of training data from a new task will produce good generalization performance on that task.

1.1 MAML的目标

MAML,全称呼叫做Model-Agnostic Meta-Learning ,意思是模型无关的元学习。所以MAML并不是一个深度学习模型(如CNN、RNN等),而更像是一种训练技巧。

模型参数初始化

(1)通常来说,我们的深度学习模型参数初始化方法是随机初始化(从高斯分布中采样),Xavier初始化,He初始化等,这样的初始化方法一般很难直接找到一个好的初始化参数。

(2)我们还可以用预训练来初始化模型的参数。用预训练初始化参数的神经网络本身就有很强的特征提取能力,能够提取很多有含义的特征,例如耳朵,鼻子,眼睛,毛发。分辨猫狗,只需要知道这些特征是如何组合的就好了,这比从头开始学习如何提取耳朵、鼻子等特征要高效得多。

利用预训练的网络进行参数初始化,相当于赋予了网络很多先验知识。类比我们人类,让一个小学没毕业的人去听高等数学,显然他是无法听懂的;而让一个高考数学满分的高中毕业生去听,他可能要学得轻松得多。如果忽略智商因素,我们人类的大脑从结构上说都是大同小异,为啥表现差别那么大呢?因为它们积累的知识量不同,后者积累的知识更多,也就是常说的“基础扎实”,换成神经网络的术语,就是后者的网络只需要 fine tune 一下就好了,而前者需要 train from scratch ,要补很多课才行。

(3)现在MAML要做的事情是学习一个“好”的初始化参数。以前我们是训练一个模型,然后让这个模型的参数$\theta$最优,而现在我们训练MAML是希望初始化参数$\phi$最优,这样就可以实现“快速学习”(使用来自新任务的少量数据就能解决学习任务,而且只需要几步梯度下降就能得到好的泛化效果)。

MAML积累的知识是元知识,也就是学习技巧,这使他比随机初始化、预训练更高级!你可以把学习一个“好”的初始化参数的过程理解成掌握一个好的学习技巧。有了这个学习技巧后,你就可以快速地解决新的任务。


MAML是与模型无关的,即该方法既可以用在CNN上,也可以用在RNN上,甚至可以用到强化学习上。但是MAML在用的时候是固定模型的,也就是说不同task $\hat \theta^n$对应的模型是相同的!我们是希望通过Meta-Learning的方式学习出这个模型的一个“好”的初始化参数$\phi$,有了这个初始化参数$\phi$后,我们只需要少量的样本就可以快速在这个模型上进行收敛。

MAML是learning to learn,所以他的输入不再是一条条单纯的数据了,而是一个个的任务(task)。好比人们在区分物体之前,已经看过了很多不同的区分任务(task),如猫狗分类、自行车和汽车分类、苹果和橘子分类等,这些都是一个个的任务(task),你可以把他们看作训练MAML的一个个样本。

MAML的损失函数$L(\phi) = \sum_{n=1}^Nl^n(\hat \theta^n)$,其中$l^n(\hat \theta^n)$是task n(经过训练后参数为$\hat \theta^n$)在test data(Query Set)上的损失值


在MAML中这个F就是初始化参数$\phi$,$f^1$就是$\hat \theta^1$

1.2 元学习的data

机器学习的数据

  • Train Data(一条条数据)
  • Test Data(一条条数据)

元学习的数据

  • Train Data(一个个任务)
    • 每个任务(task)有自己的训练集(Support Set)和测试集(Query Set)
  • Test Data(一个个任务)
    • 每个任务(task)有自己的训练集(Support Set)和测试集(Query Set)

1.3 MAML详解

在MAML的实际应用中,每次采样一个任务(MAML的一个样本),其参数为$\hat \theta^n$,从$\phi$到$\hat \theta^n$的训练过程只会做一个参数更新。尽管在传统的模型训练时,我们的参数会更新成千上万次。但是在MAML中我们假设这个过程参数只会被更新一次。不过我们的$\phi$是会更新很多次的。

1.3.1 MAML训练阶段

Task n的参数更新过程:$\phi_t \rightarrow \hat{\theta}^{n}$

$\hat{\theta}^{n}= \phi-\epsilon \nabla_{\phi} l^n(\phi)\textit{, loss on support set}$

MAML的参数更新过程:$\phi_t \rightarrow \phi_{t+1}$

$\begin{aligned}
\phi & \leftarrow \phi-\eta \nabla_{\phi} L(\phi) \
&=\phi-\eta \nabla_{\phi} \sum_{n=1}^{N} l^{n}\left(\hat{\theta}^{n}\right)\textit{, loss on query set}\
&\approx \phi-\eta \nabla_{\hat{\theta}^{n}} \sum_{n=1}^{N} l^{n}\left(\hat{\theta}^{n}\right)\textit{, loss on query set}
\end{aligned}$

理论上,MAML是用第二行进行参数更新的,但实际上,做MAML实验时,为了实现的方便,MAML用了第三行的一阶近似。

每次更新的过程:N个Task n的参数更新和一次MAML的参数更新

1.3.2 伪代码: MAML for Few-Shot Supervised Learning

1.3.3 MAML推理阶段

1.4 MAML数据集Omniglot

https://github.com/dragen1860/MAML-Pytorch

Omniglot是元学习中常用的数据集,在MAML的实验中也用了这个数据集

Omniglot数据集有1623个类别,每个类别有20个样本。https://github.com/brendenlake/omniglot

Omniglot的用法是这样的,从其中采样 N 个类,每个类有 K 个训练两本,组成一个训练任务(task),称为 N-ways K-shot classification。然后再从剩下的类中,继续重复上一步的采样,构建第二个 task,最终构建了 m 个 task。把这 m 个 task 分成训练 task 和测试 task,在训练 task 上训练 Meta Learning 的算法,然后再用测试 task 评估 Meta Learning 得到的算法的学习能力。

N-ways K-shot classification: In each training and test tasks, there are N classes, each has K examples.

二、模型预训练

请添加图片描述

2.1 重点:MAML和Pre-Train的异同点

相同点: MAML和模型预训练都是在找一个好的初始化参数$\phi$

不同点: MAML和模型预训练和评判“好的初始化参数”的标准不一样

举个例子:

  • MAML:好比读博士,比较看潜力,可能现在没什么钱,但是读完博士后工资会很高
  • 模型预训练:好比现在就找工作,比较看当前,我现在就想去赚钱,虽然工资上限可能没有博士高

我们始终要牢记,Meta Learning最终的目的是要让模型获得一个良好的初始化参数。这个初始化参数$\phi$在训练 task 上表现或许并不出色(因为$\phi$没有直接按照梯度的方向走),但以这个参数$\phi$为起点,去学习新的 task 时,学得会又快又好。而模型预训练,则是着眼于解决当前的 task,不会考虑如何面对新的 task。

分析:Model Pre-training认为$\phi$拿去做task1和task2的表现要很强,但是并不保证$\phi$用task1的数据和task2的数据拿去做训练后,可以变得很强。


MAML不关心这些task在初始化参数上的表现,这不是重点,我们不在乎$\phi$现在的表现,而是在乎$\phi$经过训练后的表现。

分析:MAML认为虽然$\phi$本身拿去做task1和task2的表现可能都不是很强,但是$\phi$用task1的数据和task2的数据拿去做训练后,可以变得很强,那他就是一个好的初始化参数$\phi$

三、Reptile

Reptile:On First-Order Meta-Learning Algorithms, arXiv, 2018, Paper

四、MAML、Pre-Train和Reptile对比图


下图来自:https://www.cnblogs.com/kailugaji/p/15156806.html

五、参考资料


Author: SHWEI
Reprint policy: All articles in this blog are used except for special statements CC BY 4.0 reprint polocy. If reproduced, please indicate source SHWEI !
评论
 Previous
Ubuntu 20.04通过docker安装微信和QQ Ubuntu 20.04通过docker安装微信和QQ
前言(๑•̀ㅂ•́)و✧ Ubuntu上的微信和QQ一直很难装,我之前尝试了很多方法(有些是基于网页版登录微信的,有些是用wine的),但我试了都不太行,坑点很多,搞不好就把系统搞崩了。今天发现用docker安装微信和QQ非常简单,所以想分
2022-03-04
Next 
重温机器学习概念:偏差(Bias)、方差(Variance)、欠拟合(Underfitting)、过拟合(Overfitting) 重温机器学习概念:偏差(Bias)、方差(Variance)、欠拟合(Underfitting)、过拟合(Overfitting)
写在前面 最近放寒假了,除了看论文,我还打算抽空复习一些机器学习的基础知识。今天主要复习了机器学习中偏差、方差、欠拟合、过拟合这几个概念,能不能讲清楚偏差方差,经常被用来考察面试者的理论基础,我之前对有些地方是一知半解的,比如那个射靶图是什
2022-01-02
  TOC