使用元学习来进行少样本图像分类

2020年12月4日 作者 火狐体育

首发:AI公园公众号
作者:Etienne
编译:ronghuaiyang

导读

你并不总是有足够的图像来训练一个深度神经网络。下面是教你如何通过几个样本让模型快速学习的方法。

你并不总是有足够的图像来训练一个深度神经网络。下面是教你如何通过几个样本让模型快速学习的方法。

我们为什么要关心少样本学习?

1980年,Kunihiko Fukushima开发了第一个卷积神经网络。自那以后,由于计算能力的不断增强和机器学习社区的巨大努力,深度学习算法从未停止提高其在与计算机视觉相关的任务上的性能。2015年,Kaiming He和他的微软的团队报告说,他们的模型在从ImageNet分类图像时比人类表现更好。那时,我们可以说计算机在处理数十亿张图像来解决特定任务方面比我们做得更好。

但是,如果你不是谷歌或Facebook,你不可能总是能够构建具有那么多图像的数据集。当你在计算机视觉领域工作时,你有时不得不对图像进行分类,每个标签只有一个或两个样本。在这场比赛中,人类还是要被打败的。只要给婴儿看一张大象的照片,他们从此以后就可以认出大象了。如果你用Resnet50做同样的事情,你可能会对结果感到失望。这种从少量样本中学习的问题叫做少样本学习。

近年来,少样本学习问题在研究界引起了极大的关注,并且已经开发出了许多优雅的解决方案。目前最流行的解决方案是使用元学习,或者用三个词来概括:learning to learn。如果你想知道元学习是什么以及它是如何工作的,请继续阅读。

少样本图像分类任务

首先,我们需要定义N-way K-shot图像分类任务。给定:

  1. 一个由N个标签组成的支持集,每个标签对应K个有标签的图像
  2. 由Q个查询图像组成的查询集

任务是对查询图像进行分类。当K很小(通常是K<10)时,我们讨论的是少样本图像分类(在K=1的情况下,是单样本图像分类)。

一个少样本分类任务的例子:对于支持集中N=3个类中的每个类,给定K=2个样本,我们希望将查询集中的Q=4只狗标记为Labrador, saint bernard或Pug。即使你从没见过Labrador、saint bernard或Pug,这对你来说也很容易。但要用人工智能解决这个问题,我们需要元学习。

元学习范式

1998年,Thrun & Pratt说,要解决一个任务,一个算法学习“如果性能可以随着经验提升”,同时,给定一族需要解决的问题,一个算法学习”性能随着经验和任务数量提升”。我们将后者称为元学习算法。它不是去学习如何解决一个特定的任务。它可以学会解决许多任务。每学习一项新任务,它就能更好地学习新任务:它学会去学习。

正式的描述一下,如果我们想要解决一个任务T,元学习算法训练一批任务{Tᵢ}。算法通过尝试解决这些任务来得到学习的经验,最终去解决终极任务T。

例如,考虑上图中显示的任务_T_。它包括有标签图像,如Labrador,Saint-Bernard或Pug,使用3×2=6个有标签图像。一个训练任务Tᵢ可能是利用6个有标签图像把图像标记为Boxer, Labradoodle或者Rottweiler。meta-training过程是一连串的这些任务Tᵢ,每一次都是不同品种的狗。我们期望元学习模型“随着经验和任务数量的增加”变得更好。最后,我们在_T_上对模型进行评价。

我们评估了Labradors、Saint-Bernards和Pugs的元学习模型,但我们只在其他品种上训练。

怎么做呢?假设你想要解决这个任务(Labrador,Saint-Bernard 和Pug)。你需要一个元训练数据集,里面有很多不同品种的狗。例如,你可以使用Stanford Dogs数据集,其中包含从ImageNet提取的超过20k只狗。我们将此数据集称为_D_。注意,_D_不需要包含任何Labrador,Saint-Bernard或Pug。

我们从_D_中抽取batch组成episodes。每个episodes 对应于一个N-way K-shot分类任务Tᵢ,(通常我们使用相同的N和K)。模型解决了batch中的所有的episodes后(对查询集中的所有图像打标签),它的参数被更新。这通常是通过对查询集上的分类不准确造成的损失进行反向传播来实现的。

这样,模型跨任务学习,以准确地解决一个新的,看不见的少样本分类任务。标准学习分类算法学习一个映射_图像→标签_,元学习算法学习一个映射:support-set→c(.),其中c是一个映射:query→label。

元学习算法

既然我们知道了算法元训练意味着什么,一个谜仍然存在:元学习模型是如何解决一个少样本的分类任务的?当然,解决方案不止一种。我们聚焦在最流行的方案上。

元学习

度量学习的基本思想是学习数据点(如图像)之间的距离函数。它已经被证明对于解决较少样本的分类任务是非常有用的:度量学习算法不需要对支持集(少量标记图像)进行微调,而是通过与标记图像进行比较来对查询图像进行分类。

查询(右侧)与支持集的每个图像进行比较。它的标签取决于哪些图像最接近。

当然,你不能逐像素地比较图像,所以你要做的是在相关的特征空间中比较图像。为了更清楚一些,让我们详细说明度量学习算法如何解决一个少样本的分类任务(上面定义为一个标签样本的支持集,和一个我们想要分类的图像的查询集):

1、我们从支持和查询集的所有图像中提取嵌入(通常使用卷积神经网络)。现在,我们在少样本分类任务中必须考虑的每一幅图像都可以用一维向量表示。

2、每个查询根据其支持图像集的距离进行分类。距离函数和分类策略都有很多可能的设计选择。一个例子就是欧氏距离和k近邻。

3、在元训练期间,在episode结束时,通过反向传播查询集上分类错误造成的损失(通常是交叉熵损失)来更新CNN的参数。

每年都会发布几种度量学习算法来解决少样本图像分类的两个原因是:

1、它们在经验上很有效;

2、唯一的限制是你的想象力。有许多方法可以提取特征,甚至有更多的方法可以比较这些特征。现在我们将回顾一些现有的解决方案。

匹配网络的算法。特征提取器对于支持集图像(左侧)和查询图像(底部)是不同的。使用余弦相似度将查询的嵌入与支持集中的每幅图像进行比较。然后用softmax对其进行分类。

匹配网络(见上图)是第一个使用元学习的度量学习算法。在这种方法中,我们不以同样的方式提取支持图像和查询图像的特征。来自谷歌DeepMind的Oriol Vinyals和他的团队提出了使用LSTM networks在特征提取期间使所有图像进行交互的想法。称为全上下文嵌入,因为你允许网络找到最合适的嵌入,不仅知道需要嵌入的图像,而且还知道支持集中的所有其他图像。这让他们的模型表现的更好,因为所有的图像都通过了这个简单的CNN,但它也需要更多的时间和更大的GPU。

在最近的研究中,我们没有将查询图像与支持集中的每一张图像进行比较。多伦多大学的研究人员提出了Prototypical Networks。在它们的度量学习算法中,在从图像中提取特征后,我们计算每个类的原型。为此,他们使用类中每个图像嵌入的平均值。(但是你可以想象数以千计的方法来计算这些嵌入。为了反向传播,函数只需是可微的即可)一旦原型被计算出来,查询将使用到原型的欧式距离进行分类(见下图)。

在原型网络中,我们将查询X标记为最接近原型的标签。

尽管简单,原型网络仍然能产生最先进的结果。更复杂的度量学习架构后来被开发出来,比如一个神经网络来表示距离函数(而不是欧氏距离)。这略微提高了精确度,但我相信时至今日,原型的想法是在用于少样本图像分类的度量学习算法领域中最有价值的想法(如果你不同意,请留下愤怒的评论)。

模型无关元学习

我们将以模型无关元学习 (MAML)来结束这次回顾,这是目前最优雅、最有前途的元学习算法之一。它基本上是最纯粹的元学习,通过神经网络有两个层次的反向传播。

该算法的核心思想是训练神经网络的参数,可以适应快速和较少的例子,以新的分类任务。下面我将为你提供一个关于MAML如何在一个episode中进行元训练的可视化例子(例如,在从D中采样得到的Tᵢ上进行few-shot分类任务)。假设你有一个神经网络M参数为𝚯:

MAML模型的meta-training步骤,参数为𝚯

1、创建一个副本M(这里叫_f_),和初始化为𝚯(图中,𝜽₀=𝚯)。

2、在支持集上快速微调_f_(只进行几次梯度下降)。

3、对查询集上应用调优后的_f_。

4、对整个过程中的分类损失进行反向传播,更新𝚯。

然后,在下一个episode中,我们创建更新后的模型_M_的副本,在一个新的少样本分类任务上运行流程。

在元训练过程中,MAML学习了初始化参数,允许模型快速有效地适应一个新的少样本的任务,新的,看不见的类。

公平地说,目前在流行的少样本图像分类的基准上,MAML并不像度量学习算法那样有效。它很难训练因为有两个层次的训练,所以超参数搜索要复杂得多。另外,元反向传播意味着需要计算梯度的梯度,所以你必须使用近似才能在标准的gpu上训练它。由于这些原因,你可能更愿意使用度量学习的算法。

但模型无关元学习之所以如此令人兴奋,是因为它是模型无关的。这意味着它实际上可以应用于任何神经网络,任何任务。掌握MAML意味着能够训练任何神经网络以快速适应新的任务,而且需要的样本很少。MAML的作者Chelsea Finn和Sergey Levine将其应用于监督少样本分类、监督回归和强化学习。但是有了想象力和努力工作,你可以使用它将任何神经网络转换为一个少样本高效的神经网络!

—END—

英文原文:https://www.sicara.ai/blog/20…

推荐阅读

  • TensorPipe:支持最先进增增强和底层优化的Tensorflow的高性能数据Pipeline
  • OCR in the Wild:文本检测和识别的SOTA
  • 知识蒸馏:如何用一个神经网络训练另一个神经网络

关注图像处理,自然语言处理,机器学习等人工智能领域,请点击关注AI公园专栏