DeepMind

摘要

We propose Algorithm Distillation (AD), a method for distilling reinforcement learning (RL) algorithms into neural networks by modeling their training histories with a causal sequence model. Algorithm Distillation treats learning to reinforcement learn as an across-episode sequential prediction problem. A dataset of learning histories is generated by a source RL algorithm, and then a causal transformer is trained by autoregressively predicting actions given their preceding learning histories as context. Unlike sequential policy prediction architectures that distill post-learning or expert sequences, AD is able to improve its policy entirely in-context without updating its network parameters. We demonstrate that AD can reinforcement learn in-context in a variety of environments with sparse rewards, combinatorial task structure, and pixel-based observations, and find that AD learns a more data-efficient RL algorithm than the one that generated the source data.

我们提出了算法蒸馏(AD),这是一种将强化学习(RL)算法蒸馏成神经网络的方法,通过用因果序列模型对其训练历史进行建模。算法蒸馏将强化学习视为一个跨剧情的顺序预测问题。一个学习历史的数据集由源RL算法生成,然后通过自回归预测行动来训练因果转化器,并将其之前的学习历史作为背景。与蒸馏后学习或专家序列的顺序策略预测架构不同,AD能够在不更新网络参数的情况下完全在上下文中改进其策略。我们证明了AD能够在各种具有稀疏奖励、组合任务结构和基于像素的观察的环境中进行语境强化学习,并发现AD学习的RL算法比产生源数据的算法更具有数据效率。

研究动机

  • Gato和MGDT这类offline(模仿学习)策略蒸馏方法无法trail-and-error的提升策略,此类方法通过策略蒸馏学习策略,但并非强化学习方法(这里的策略蒸馏指从离线数据蒸馏,而非教师模型)
  • 其原因主要是训练数据并不能表示学习的过程

主要贡献

  • 是第一个通过序列建模离线数据的in-context强化学习方法

方法描述

POMDP:部分可观测马尔可夫决策过程 -> 需要历史信息辅助观测

In-Context Learning:从上下文推断任务的能力,通常指通过prompt推断任务

AD整体结构

  • 历史经验history:$\mathcal{H} \ni h_t:=\left(o_0, a_0, r_0, \ldots, o_{t-1}, a_{t-1}, r_{t-1}, o_t, a_t, r_t\right)=\left(o \leq t, r \leq t, a_{\leq t}\right)$
  • algorithm (long history-conditioned policy):$P: \mathcal{H} \cup \mathcal{O} \rightarrow \Delta(\mathcal{A})$

    • Long需要足够跨越学习更新,例如across episodes,把训练过程的episode拼接在一起
    • $\Delta(\mathcal{A})$表示动作空间的概率分布空间
  • 算法$P$生成任务$\mathcal{M}$的历史表示为:$\left(O_0, A_0, R_0, \ldots, O_T, A_T, R_T\right) \sim P_{\mathcal{M}}$
  • 通过行为克隆将long history-conditioned policies蒸馏到一个神经网络中
  • 数据集:$\mathcal{D}:=\left\{\left(o_0^{(n)}, a_0^{(n)}, r_0^{(n)}, \ldots, o_T^{(n)}, a_T^{(n)}, r_T^{(n)}\right) \sim P_{\mathcal{M}_n}^{\text {source }}\right\}_{n=1}^N$

    • $P^{\text {source}}$表示生成数据集的源算法
  • 损失函数(negative log likelihood):$\mathcal{L}(\theta):=-\sum_{n=1}^N \sum_{t=1}^{T-1} \log P_\theta\left(A=a_t^{(n)} \mid h_{t-1}^{(n)}, o_t^{(n)}\right)$
  • 直觉上,该序列模型可以学到源算法的包括探索、时间信用分配等复杂行为
  • 数据集生成:训练N个独立的单任务RL算法,为防止过拟合N个任务是随机从任务分布中采样的,保存所有历史

    • 可能存在有重复任务?
  • 训练序列模型:可以使用任意序列模型进行训练(包括RNN),每次采样长度为$c<T$的子序列进行训练

算法

理论分析

实验验证

  • 实验环境

    • Adversarial Bandit:10 arms的多臂老虎机

      • out of distribution:训练集:95%的收益在奇数臂;测试集:95%的收益在偶数臂
    • Dark Room:二维离散的POMDP,要求智能体到达指定目标位置

      • 智能体只知道自身位置,必须通过奖励推测目标位置
      • 简单:9*9地图,上下左右不动五个动作,20步,出生在中心点,每次到达目标点都获得奖励
      • 困难:17*17地图,只有第一次到达目标时获得奖励
    • Dark Key-to-Door:与Dark Room类似,要求先找钥匙再开门

      • agent只知道自己位置,钥匙和门未知,9*9地图,出生在任意位置,获取钥匙和开门获得奖励,步长50
    • DMLab Watermaze:三维视觉DMLab环境,基于经典的Morris Watermaze

      • 在水下迷宫找出口,迷宫的墙上有利于记住位置的彩色方块
      • 观测:72*96*3图像;动作:8个,前后左右移动、视角左右移动、前进并视角左右移动
      • 奖励:稀疏奖励,仅到达获取奖励
      • 步长50,出生在地图中央
  • Baselines

    • Expert Distillation (ED):数据只包含专家数据
    • Source Algorithm:产生数据的源RL算法from scratch训练,作为data-efficiency对比
    • RL^2:online元强化学习算法,作为AD的上界
  • 评测标准 adapt in-context

    • 主要测试Fast-Adaptation:将pre-train的model在新环境在不断更新自己的context
  • 实验效果

Adversarial Bandit实验效果

  • In-context reinforcement learning:

    • Credit-assignment:对应DarkRoom每次到达目标都获得奖励的情况,尽管模型输入只有单步奖励,但是智能体能学会如何累计奖励
    • Exploration:对应DarkRoom只有第一次到达目标获得奖励的情况,可以从历史episode中推测目标位置
    • Generalization:Dark Key-to-Door一共包含6.5k个任务,少于2k在训练中见过,其他任务基本可以泛化到near-optimal
    • Data-Efficient

数据效率验证

  • prompting with demonstrations (好像是图片显示有点问题)

prompting

  • context size

context size

其他思考

  • 本文的方法很简单,主要是将RL^2 和Transformer结合在一起,RL^2是将Episode串在一起训练RNN,AD是将Episode串在一起训练Transformer。
  • 文章把该方法说的很高级,不变动网络参数即可实现新环境的Adaptation,从context size的对比试验可以直观理解为如果一旦智能体探索到奖励,并在context的历史信息中有保存,则智能体在新的Episode中可以尽可能模仿之前的有效经验,并在此基础上再进行探索。
  • 但是本文只在一些很简单的环境中应用AD取得了较好的效果,单局时长也很短,所以历史探索经验相对简单,且context size可以轻松包含1个或多个episode,应用于复杂环境可能效果不太理想。
  • 本文提供的思路其实是来源于RL^2,即想办法利用智能体的历史探索经验,从而实现Learn to learn,但是做法过于暴力,还可以进一步研究。

原文链接:https://arxiv.org/abs/2210.14215

Last modification:December 25, 2022
如果觉得我的文章对你有用,请随意赞赏