ICLR 2018

本文严格意义是MTL,而非MTRL

摘要

Multi-task learning (MTL) with neural networks leverages commonalities in tasks to improve performance, but often suffers from task interference which reduces the benefits of transfer. To address this issue we introduce the routing network paradigm, a novel neural network and training algorithm. A routing network is a kind of self-organizing neural network consisting of two components: a router and a set of one or more function blocks. A function block may be any neural net-work – for example a fully-connected or a convolutional layer. Given an input the router makes a routing decision, choosing a function block to apply and passing the output back to the router recursively, terminating when a fixed recursion depth is reached. In this way the routing network dynamically composes different func-tion blocks for each input. We employ a collaborative multi-agent reinforcement learning (MARL) approach to jointly train the router and function blocks. We evaluate our model against cross-stitch networks and shared-layer baselines on multi-task settings of the MNIST, mini-imagenet, and CIFAR-100 datasets. Our experiments demonstrate a significant improvement in accuracy, with sharper con-vergence. In addition, routing networks have nearly constant per-task training cost while cross-stitch networks scale linearly with the number of tasks. On CIFAR-100 (20 tasks) we obtain cross-stitch performance levels with an 85% reduction in training time.

使用神经网络的多任务学习(MTL)利用任务中的共性来提高性能,但往往会受到任务干扰,从而减少迁移的好处。为了解决这个问题,我们引入了路由网络范式,一种新型的神经网络和训练算法。路由网络是一种自组织的神经网络,由两部分组成:一个路由器和一组单个或多个功能块。一个功能块可以是任何神经网络——例如全连接层或卷积层。给定一个输入,路由器做出一个路由决定,选择一个功能块来应用,并将输出递归给路由器,当达到一个固定的递归深度时终止。通过这种方式,路由网络为每个输入动态地组建了不同的功能块。我们采用协作式多智能体强化学习(MARL)方法来联合训练路由器和功能块。我们在MNIST、mini-imagenet和CIFAR-100数据集的多任务设置上对我们的模型与cross-stitch网络和共享层基线进行评估。我们的实验表明,准确度有了明显的提高,一致性更强。此外,路由网络的每个任务训练成本几乎是恒定的,而cross-stitch网络则随着任务数量的增加而线性扩展。在CIFAR-100(20个任务)上,我们获得cross-stitch的性能水平,且训练时间减少了85%。

研究动机

  • 多任务模型需要利用任务的共性(正迁移),同时尽量减少干扰(负迁移)
  • 理想情况下,正迁移由共享的功能模块实现,而负迁移可以通过独立的模块避开

主要贡献

  • 提出一种新的架构:由一个路由器和一组功能模块组成,由路由器选择功能模块和先后顺序,从而使不同任务使用不同路径进行决策
  • 该架构非常通用,路由器有多种方式实现,功能模块可以为任意维度对齐的神经网络结构
  • 使用强化学习训练路由器,将每个任务的路由当作一个智能体的决策

方法描述

Routing 示例

Routing 算法

  • 整体算法

    • 路由器输入:$v$ 表示维度为 $d$ 的表征向量;$t$ 表示任务编码;$i$ 表示路由深度
    • 路由器决策函数 $\text{router}$:$\mathbb{R}^d \times \mathbb{Z}^+ \times \mathbb{Z}^+ \rightarrow \{1,2,\dots,k,\text{PASS} \}$

Router Trainer

训练示例

  • RL训练Router

    • 状态:$(v,t,i)$;动作:$\{1,2,\dots,k,\text{PASS} \}$
    • 奖励设计:

      • 即时奖励:router选择该模块的历史可能性,乘一个超参$\rho \in [0,1]$(鼓励使用尽量少的功能模块,希望训练完后可以最大程度减小模型规模)
      • 最终奖励:预测正确+1;预测错误-1(分类任务)
    • RL算法

      • 单智能体:如(a)所示,一个独立智能体进行路由决策,可使用任何RL算法(PG、Q-Learning)
      • 多智能体:如(b)所示,每个任务由一个独立智能体进行路由决策,使用算法Weighted Policy Learner (WPL)
      • 分层多智能体:如(c)所示,上层dispatching agent决定当前输入由哪个下层智能体进行路由决策
      • PG算法选择REINFORCE;Q算法选择vanilla Q-Learning
      • 由于路由器和功能模块共同训练,所以可以认为环境是非稳定的,WPL则可以在非稳定环境中收敛

        • simplex-projection定义为:$\text{clip}(\pi) / \sum(\text{clip}(\pi))$,其中$\text{clip}(x)=\max(0,\min(1,x))$
        • 仅适用于表格类算法

RL算法

WPL算法

理论分析

实验验证

  • 数据集:multi-task versions of MNIST (MNIST-MTL);Mini-Imagenet (MIN-MTL) ;CIFAR-100 (CIFAR-MTL)
  • 可视化routing选择策略:初始为随机策略(熵较大),逐渐收敛为确定性策略(以近100%概率选择指定模块)

Routing 策略

  • Routing可视化:将MNIST-MTL的routing map可视化,呈现7-4-5的梨形,而非一般认为的共享representation后再使用multi-head独立处理

Routing map

其他思考

  • 本文主要针对MTL,并非MTRL,所以未对实验结果进行详细解读,本文主要将RL用于了routing网络的训练。
  • routing可视化结果的随机性质疑,使用不同的随机种子获得的routing map都是7-4-5的结构吗?该结构也不具备任何的可解释性,作者仅给出可视化图示,也没有进行任何说明,所以此结构的意义还有待进一步探索。
  • 该方法无法端到端进行训练,所以相对而言训练效率较低,具体基于batch的算法实现未给出,例如初始阶段容易出现相同任务不同路径的情况,最简单的处理方式即使用mask阻断梯度进行batch训练。
  • 将其应用于MTRL:需考虑如何对状态进行编码,分类任务可以理解为一步决策任务,但对于RL环境而言,是序列决策任务,所以是以单步状态还是以整个episode进行routing。

原文链接:https://arxiv.org/abs/1711.01239

OpenReview链接:https://openreview.net/forum?id=ry8dvM-R-

参考资料:https://www.cnblogs.com/RyanXing/p/routing_networks.html

Last modification:August 11, 2022
如果觉得我的文章对你有用,请随意赞赏