Deepmind

image-20220718101417491

摘要

Inspired by progress in large-scale language modeling, we apply a similar approach towards building a single generalist agent beyond the realm of text outputs. The agent, which we refer to as Gato, works as a multi-modal, multi-task, multi-embodiment generalist policy. The same network with the same weights can play Atari, caption images, chat, stack blocks with a real robot arm and much more, deciding based on its context whether to output text, joint torques, button presses, or other tokens. In this report we describe the model and the data, and document the current capabilities of Gato.

受大规模语言建模进展的启发,我们采用类似的方法来建立一个超越文本输出领域的通用智能体。我们将该智能体称之为Gato,是一个多模式、多任务、多内容的通用策略。具有相同权重的同一网络可以玩雅达利游戏、给图像加标题、聊天、用真正的机器人手臂堆积积木等等,根据其上下文决定是否输出文本、关节扭力、按按钮或其他token。在这份报告中,我们描述了模型和数据,并记录了Gato的当前能力。

研究动机

  • 通用智能体:减少每个专门领域的人工策略设计;
  • Transformer:能处理任何可序列化的数据,且增加数据总量和数据多样性。

主要贡献

  • 同一网络、统一参数的模型可以完成604种不同的任务。
  • 证明了CV、NLP和RL的结合是切实可行的,通过序列预测能够解决一些决策智能的问题。
  • 朴素Gato模型的RL部分使用监督学习,但理论上其可以用于离线和在线强化学习。
  • Gato模型目前的参数量为1.2B,只能算中等,且实验证明随模型参数、训练数据增大,模型性能也能相应提升。

方法描述

image-20220718100521736

主要步骤:数据处理(序列化token)+ 模型(Transformer) + 损失函数 + 部署

  • Tokenization

    • tokenization scheme(表征)

      • 文本:用SentencePiece进行分词,并将32000个分词编码为$[0,32000)$整数
      • 图像:类似ViT将图片切分为无重叠的$16 \times 16$的patch,将patch内的像素归一化到$[-1,1]$并除以patch size的平方根($\sqrt{16}=4$)
      • 离散值:按行优先顺序编码为$[0,1024)$整数
      • 连续值:按行优先顺序编码为浮点值,并进行归一化和离散化
    • sequence ordering(序列化)

      • Episode按照时间步顺序排列
      • 观测Observation($[y_{1:k}, x_{1:m}, z_{1:n}]$):

        1. 文本($y_{1:k}$):原始输入顺序
        2. 图像($x_{1:m}$):栅格顺序(一行一行扫描)
        3. 张量($z_{1:n}$):row-major order行优先顺序
      • 间隔符Separator($'|'$):标识观测结束
      • 动作($a_{1:A}$):row-major order行优先顺序
    • 完整序列表示为:$s_{1: L}=\left[\left[y_{1: k}^{1}, x_{1: m}^{1}, z_{1: n}^{1},'|', a_{1: A}^{1}\right], \ldots,\left[y_{1: k}^{T}, x_{1: m}^{T}, z_{1: n}^{T},'|', a_{1: A}^{T}\right]\right]$

      • $L=T(k+m+n+1+A)$
  • Embedding input tokens and setting output targets

    • embedding function

      • 文本、离散值或连续值观测和动作:查表获得编码vector;根据token位置编码位置vector。
      • 图像:ResNet将每个patch编码为vector;图像内位置编码vector。
    • Target token

      • 将token作为前一个token的预测目标。目前只将文本、离散或连续值、动作作为预测目标。

image-20220720114252934

image-20220720114214137

image-20220720114334098

  • Training

    • 损失函数:$\mathcal{L}(\theta, \mathcal{B})=-\sum_{b=1}^{|\mathcal{B}|} \sum_{l=1}^{L} m(b, l) \log p_{\theta}\left(s_{l}^{(b)} \mid s_{1}^{(b)}, \ldots, s_{l-1}^{(b)}\right)$

      • $m(b,l)$:mask函数;$\mathcal{B}$:训练batch;$L$:sequence长度
    • 网络结构:Token Embedding + Sequence Model
    • 上下文属性:prompt conditioning

      • 25%的数据有prompt序列:由相同智能体完成同一任务的episode构成(其中50%选取末端数据,50%随机选取数据)
    • 训练成本:16x16 TPU v3 训练 1M步,batch size为512,sequence长度为1024,训练4天。
  • Deployment

image-20220718160924039

理论分析

实验验证

image-20220718161755110

  • 数据集:

    • Simulated control tasks:由SoTA或near-SoTA的RL方法生成数据;数据挑选训练过程的最好数据子集$\max _{j \in[0,1, \ldots, N-W]}\left(\sum_{i=j}^{j+M-1} \frac{R_{i}}{W}\right)$
    • Vision and language:采样五对(图像、文本),然后补全或随机截取sequence长度
    • Robotics - RGB Stacking Benchmark (real and sim) :

image-20220718164449935

  • 实验结果:

    • Simulated control tasks:450个超越50%专家分数阈值,200左右超越专家水平
    • Robotics技能泛化能力:有五类形状的物体在训练集中未出现,仅在测试集出现,测试结果超越当前SoTA算法。(不清楚新任务上的prompt如何处理,是否引入额外专家数据)
    • Text samples

image-20220718170124913

image-20220718170146237

  • 实验分析:

    • 随着模型规模的增加,性能也会提升。(受限于Robotics真实环境的20Hz控制频率,无法继续扩展模型)
    • 经过模型微调在其他任务上的表现。

image-20220718172137681

  • 在Robotics上进行微调的结果

image-20220718172951847

其他思考

  • 个人觉得本文最大的亮点是将所有类型的数据都进行了Tokenization和Sequence Ordering,是一种独特的特征处理方式。下面主要针对Tensor类型观测讨论,先将每个特征进行独立编码操作(即Tokenization),再按指定顺序进行排序,个人认为该顺序影响不大,位置编码的主要作用是指示不同的特征,即不同时间不下,同一位置编码可能都表示位置信息。但该顺序存在的影响为sequence长度的截断,即执行动作前L个token中,第一个token并不一定是观测位置编码为p0的token。
  • 不太理解为什么token的位置编码仅针对单时间步内不同特征进行,而并没有对完整episode上的时间步信息进行位置编码,个人觉得时间步信息也很重要(可能涉及到相对or绝对信息,例如走迷宫可能绝对时间步信息更重要,但对于自动驾驶而言可能相对时间步信息更重要)
  • 还有一个亮点是将动作和观测进行了一个统一,通常在进行特征设计时我们也会引入历史动作信息,Gato则是直接将动作也进行Tokenization,可以作为后续的动作的观测理解,也可以完美适配语言生成模型的Token预测,只需要进行对应的Detokenization就可以获取真实动作与环境交互。
  • Prompt Conditioning感觉像是一个trick,即将成功完成任务的示例作为模型的初始输入,所以泛化性实验中对比其他方法存在偏差。对于任务的表征方式很多,又例如one-hot编码、对任务进行自然语言描述并使用NLP预训练模型进行embedding,而本文提出的此类方法可能是比较适用于Transformer框架的。
  • 文章中并没也介绍文本、离散or连续值观测和动作的vector embedding space从何而来,是怎么训练获得的。
  • Gato仅使用监督学习进行训练,虽然文章表示理论上也可以应用于online RL,但二者的训练成本差距较大,监督学习并不需要考虑对于环境的探索。而多任务RL中,不同任务的探索方向上也存在冲突,怎样应用于大规模在线RL训练还需要进一步研究。而且在线RL训练则不存在Prompt Conditioning,因为该任务标识来源于专家数据。

原文链接:https://www.deepmind.com/publications/a-generalist-agent

参考资料:https://baijiahao.baidu.com/s?id=1732695949873185896&wfr=spider&for=pc & https://zhuanlan.zhihu.com/p/518115801

Last modification:July 20, 2022
如果觉得我的文章对你有用,请随意赞赏