面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”

面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”

人脑显然是人工智能追求的最高标准。

毕竟人脑使得人类拥有了连续学习的能力以及情境依赖学习的能力。

这种可以在新的环境中不断吸收新的知识和根据不同的环境灵活调整自己的行为的能力,也正是深度学习系统与人脑相差甚远的重要原因。

想让传统深度学习系统获得连续学习能力,最重要的是克服人工神经网络会出现的“灾难性遗忘”问题,即一旦使用新的数据集去训练已有的模型,该模型将会失去对原数据集识别的能力。

换句话说就是:让神经网络在学习新知识的同时保留旧知识。

前段时间,来自苏黎世联邦理工学院以及苏黎世大学的研究团队发表了一篇名为《超网络的连续学习》(Continual learning with hypernetworks)的研究。提出了任务条件化的超网络(基于任务属性生成目标模型权重的网络)。该方法能够有效克服灾难性的遗忘问题。

面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”

具体来说,该方法能够帮助在针对多个任务训练网络时,有效处理灾难性的遗忘问题。除了在标准持续学习基准测试中获得最先进的性能外,长期的附加实验任务序列显示,任务条件超网络(task-conditioned hypernetworks )表现出非常大的保留先前记忆的能力。

hypernetworks

在苏黎世联邦理工学院以及苏黎世大学的这项工作中,最重要的是对超网络(hypernetworks)的应用,在介绍超网络的连续学习之前,雷锋网(公众号:雷锋网) AI科技评论先对超网络做一下介绍。hyperNetwork是一个非常有名的网络,简单说就是用一个网络来生成另外一个网络的参数。

工作原理是:用一个hypernetwork输入训练集数据,然后输出对应模型的参数,最好的输出是这些参数能够使得在测试数据集上取得好的效果。简单来说hypernetwork其实就是一个meta network。雷锋网 AI科技评认为传统的做法是用训练集直接训练这个模型,但是如果使用hypernetwork则不用训练,抛弃反向传播与梯度下降,直接输出参数,这等价于hypernetwork学会了如何学习图像识别。

面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”

论文下载见文末

在《hypernetwork》这篇论文中,作者使用 hyperNetwork 生成 RNN 的权重,发现能为 LSTM 生成非共享权重,并在字符级语言建模、手写字符生成和神经机器翻译等序列建模任务上实现最先进的结果。超网络采用一组包含有关权重结构的信息的输入,并生成该层的权重,如下图所示。

面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”

超网络生成前馈网络的权重:黑色连接和参数与主网络相关联,而橙色连接和参数与超网络相关联。

超网络的连续学习模型

在整个工作中,雷锋网 AI科技评发现作者首先假设输入的数据面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”,......面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”是可以被储存的,并能够使用输入的数据计算面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”。另外,可以将未使用的数据和已经使用过数据进行混合来避免遗忘。假设F(X,Θ)是模型,那么混合后的数据集为{(面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”),。。。,(面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”),(面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”)},其中其中Yˆ(T)是由模型f(.,面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”)生成的一组合成目标。然而存储数据显然违背了连续学习的原则,所以在在论文中,作者提出了一种新的元模型fh(面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”,面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”)做为解决方案,新的解决方案能够将关注点从单个的数据输入输出转向参数集{面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”},并实现非储存的要求。这个元模型称为任务条件超网络,主要思想是建立任务面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”和权重面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”的映射关系,能够降维处理数据集的存储,大大节省内存。

在《超网络的连续学习》这篇论文中,模型部分主要有3个部分,第一部分是任务条件超网络。首先,超网络会将目标模型参数化,即不是直接学习特定模型的参数,而是学习元模型的参数,从而元模型会输出超网络的权重,也就是说超网络只是权重生成器。

面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”

图a:正则化后的超网络生成目标网络权重参数;图b:迭代地使用较小的组块超网络产生目标网络权重。

然后利用带有超网络的连续学习输出正则化。在论文中,作者使用两步优化过程来引入记忆保持型超网络输出约束。首先,计算∆Θh(∆Θh的计算原则基于优化器的选择,本文中作者使用Adam),即找到能够最小化损失函数的参数。损失函数表达式如下图所示:

 面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”

注:Θ∗ h是模型学习之前的超网络的参数;∆Θh为外生变量;βoutput是用来控制正则化强度的参数。

然后考虑模型的面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”,它就像面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”一样。在算法的每一个学习步骤中,需要及时更新,并使损失函数最小化。在学习任务之后,保存最终面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”e并将其添加到集合{面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”}。

模型的第二部分是用分块的超网络进行模型压缩。超网络产生目标神经网络的整个权重集。然而,超网络可以迭代调用,在每一步只需分块填充目标模型中的一部分。这表明允许应用较小的可重复使用的超网络。有趣的是,利用分块超网络可以在压缩状态下解决任务,其中学习参数(超网络的那些)的数量实际上小于目标网络参数的数量。

为了避免在目标网络的各个分区之间引入权重共享,作者引入块嵌入的集合{面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”} 作为超网络的附加输入。因此,目标网络参数的全集Θ_trgt=[面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”,,,面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”]是通过在面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”上迭代而产生的,在这过程中保持面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”不变。这样,超网络可以每个块上产生截然不同的权重。另外,为了简化训练过程,作者对所有任务使用一组共享的块嵌入。

模型的第三部分:上下文无关推理:未知任务标识(context-free inference: unknown task identity)。从输入数据的角度确定要解决的任务。超网络需要任务嵌入输入来生成目标模型权重。在某些连续学习的应用中,由于任务标识是明确的,或者可以容易地从上下文线索中推断,因此可以立即选择合适的嵌入。在其他情况下,选择合适的嵌入则不是那么容易。

作者在论文中讨论了连续学习中利用任务条件超网络的两种不同策略。

策略一:依赖于任务的预测不确定性。神经网络模型在处理分布外的数据方面越来越可靠。对于分类目标分布,理想情况下为不可见数据产生平坦的高熵输出,反之,为分布内数据产生峰值的低熵响应。这提出了第一种简单的任务推理方法(HNET+ENT),即给定任务标识未知的输入模式,选择预测不确定性最小的任务嵌入,并用输出分布熵量化。

策略二:当生成模型可用时,可以通过将当前任务数据与过去合成的数据混合来规避灾难性遗忘。除了保护生成模型本身,合成数据还可以保护另一模型。这种策略实际上往往是连续学习中最优的解决方案。受这些成功经验的启发,作者探索用回放网络(replay network)来增强深度学习系统。

合成回放(Synthetic replay)是一种强大但并不完美的连续学习机制,因为生成模式容易漂移,错误往往会随着时间的推移而积累和放大。作者在一系列关键观察的基础上决定:就像目标网络一样,重放模型可以由超网络指定,并允许使用输出正则化公式。而不是使用模型自己的回放数据。因此,在这种结合的方法中,合成重放和任务条件元建模同时起作用,避免灾难性遗忘。

基准测试

作者使用MNIST、CIFAR10和CIFAR-100公共数据集对论文中的方法进行了评估。评估主要在两个方面:(1)研究任务条件超网络在三种连续学习环境下的记忆保持能力,(2)研究顺序学习任务之间的信息传递。具体的在评估实验中,作者根据任务标识是否明确出了三种连续学习场景:CL1,任务标识明确;CL2,任务标识不明确,并不需明确推断;CL3,任务标识可以明确推断出来。另外作者在MNIST数据集上构建了一个全连通的网络,其中超参的设定参考了van de Ven & Tolias (2019)论文中的方法。在CIFAR实验中选择了ResNet-32作为目标神经网络。

van de Ven & Tolias (2019):

Gido M. van de Ven and Andreas S. Tolias. Three scenarios for continual learning. arXiv preprint arXiv:1904.07734, 2019.

为了进一步说明论文中的方法,作者考虑了四个连续学习分类问题中的基准测试:非线性回归,PermutedMNIST,Split-MNIST,Split CIFAR-10/100。

非线性回归的结果如下:

面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”

注:图a:有输出正则化的任务条件超网络可以很容易地对递增次数的多项式序列建模,同时能够达到连续学习的效果。图b:和多任务直接训练的目标网络找到的解决方案类似。图c:循序渐进地学习会导致遗忘。

在PermutedMNIST中,作者并对输入的图像数据的像素进行随机排列。发现在CL1中,任务条件超网络在长度为T=10的任务序列中表现最佳。在PermutedMNIST上任务条件超网络的表现非常好,对比来看突触智能(Synaptic Intelligence) ,online EWC,以及深度生成回放( deep generative replay)方法有差别,具体来说突触智能和DGR+distill会发生退化,online EWC不会达到非常高的精度,如下图a所示。综合考虑压缩比率与任务平均测试集准确性,超网络允许的压缩模型,即使目标网络的参数数量超过超网络模型的参数数量,精度依然保持恒定,如下图b所示。

面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”

Split-MNIST作为另一个比较流行的连续学习的基准测试,在Split-MNIST中将各个数字有序配对,并形成五个二进制分类任务,结果发现任务条件超网络整体性能表现最好。另外在split MNIST问题上任务重叠,能够跨任务传递信息,并发现该算法收敛到可以产生同时解决旧任务和新任务的目标模型参数的超网络配置。如下图所示

面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”

图a:即使在低维度空间下仍然有着高分类性能,同时没有发生遗忘。图b:即使最后一个任务占据着高性能区域,并在远离嵌入向量的情况下退化情况仍然可接受,其性能仍然较高。

在CIFAR实验中,作者选择了ResNet-32作为目标神经网络,在实验过程中,作者发现运用任务条件超网络基本完全消除了遗忘,另外还会发生前向信息反馈,这也就是说与从初始条件单独学习每个任务相比,来自以前任务的知识可以让网络表现更好。

综上,在论文中作者提出了一种新的连续学习的神经网络应用模型--任务条件超网络,该方法具有可灵活性和通用性,作为独立的连续学习方法可以和生成式回放结合使用。该方法能够实现较长的记忆寿命,并能将信息传输到未来的任务,能够满足连续学习的两个基本特性。

参考文献:

HYPERNETWORKS:

https://arxiv.org/pdf/1609.09106.pdf

CONTINUAL LEARNING WITH HYPERNETWORKS

https://arxiv.org/pdf/1906.00695.pdf

https://mp.weixin.qq.com/s/hZcVRraZUe9xA63CaV54Yg

雷锋网原创文章,未经授权禁止转载。详情见转载须知

面向超网络的连续学习:新算法让人工智能不再“灾难性遗忘”

(完)