
q-transformer's Introduction

Q-transformer

Implementation of Q-Transformer, Scalable Offline Reinforcement Learning via Autoregressive Q-Functions, out of Google DeepMind

I will keep the logic for Q-learning on a single action around for a final comparison with the proposed autoregressive Q-learning on multiple actions, and as an educational reference for myself and the public.

Install

$ pip install q-transformer

Usage

import torch

from q_transformer import (
    QRoboticTransformer,
    QLearner,
    Agent,
    ReplayMemoryDataset
)

# the attention model

model = QRoboticTransformer(
    vit = dict(
        num_classes = 1000,
        dim_conv_stem = 64,
        dim = 64,
        dim_head = 64,
        depth = (2, 2, 5, 2),
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1
    ),
    num_actions = 8,
    action_bins = 256,
    depth = 1,
    heads = 8,
    dim_head = 64,
    cond_drop_prob = 0.2,
    dueling = True
)

# you need to supply your own environment, by overriding BaseEnvironment

from q_transformer.mocks import MockEnvironment

env = MockEnvironment(
    state_shape = (3, 6, 224, 224),
    text_embed_shape = (768,)
)

# env.init()     should return instructions and initial state: Tuple[str, Tensor[*state_shape]]
# env(actions)   should return rewards, next state, and done flag: Tuple[Tensor[()], Tensor[*state_shape], Tensor[()]]
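
# for a real task, you would subclass BaseEnvironment instead of using MockEnvironment,
# implementing init() and forward() with the signatures described above. the sketch
# below is only an illustration - the BaseEnvironment import path and the constructor
# arguments are assumptions that mirror MockEnvironment, so check the source

from q_transformer.agent import BaseEnvironment   # import path is an assumption

class MyRobotEnvironment(BaseEnvironment):
    def init(self):
        # return the language instruction and the initial state, shaped like state_shape
        instruction = 'pick up the red block'
        state = torch.zeros(3, 6, 224, 224)
        return instruction, state

    def forward(self, actions):
        # step your simulator or robot with the discretized action bins here
        reward = torch.tensor(0.)
        next_state = torch.zeros(3, 6, 224, 224)
        done = torch.tensor(False)
        return reward, next_state, done

# env = MyRobotEnvironment(state_shape = (3, 6, 224, 224), text_embed_shape = (768,))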

# agent is a class that allows the q-model to interact with the environment to generate a replay memory dataset for learning

agent = Agent(
    model,
    environment = env,
    num_episodes = 1000,
    max_num_steps_per_episode = 100,
)

agent()

# Q-learning on the replay memory dataset, using the model above

q_learner = QLearner(
    model,
    dataset = ReplayMemoryDataset(),
    num_train_steps = 10000,
    learning_rate = 3e-4,
    batch_size = 4,
    grad_accum_every = 16,
)

q_learner()

# after much learning
# your robot should be better at selecting optimal actions

video = torch.randn(2, 3, 6, 224, 224)

instructions = [
    'bring me that apple sitting on the table',
    'please pass the butter'
]

actions = model.get_optimal_actions(video, instructions)

Appreciation

Todo

  • first work way towards single action support

  • offer batchnorm-less variant of maxvit, as done in SOTA weather model metnet3

  • add optional deep dueling architecture

  • add n-step Q learning

  • build the conservative regularization

  • build out main proposal in paper (autoregressive discrete actions until last action, reward given only on last)

  • improvise decoder head variant, instead of concatenating previous actions at the frames + learned tokens stage. in other words, use classic encoder - decoder

    • allow for cross attention to fine frame / learned tokens
  • redo maxvit with axial rotary embeddings + sigmoid gating for attending to nothing. enable flash attention for maxvit with this change

  • build out a simple dataset creator class, taking in the environment and model and returning a folder that can be accepted by a ReplayDataset

    • finish basic environment loop
    • store memories to memmapped files in designated folder
    • ReplayDataset that takes in folder
      • 1 time step option
      • n-time steps
  • handle multiple instructions correctly

  • show a simple end-to-end example, in the same style as all other repos

  • handle no instructions, leverage null conditioner in CFG library

  • cache kv for action decoding

  • for exploration, allow for finely randomizing a subset of actions, and not all actions at once

    • also allow for gumbel based sampling of actions, with annealing of gumbel noise
  • consult some RL experts and figure out if there are any new headways into resolving delusional bias

  • figure out if one can train with randomized orders of actions - order could be sent as a conditioning that is concatted or summed before attention layers

    • offer an improvised variant where the first action token suggests the action ordering. all actions aren't made equal, and some may need to attend to past actions more than others
  • simple beam search function for optimal actions

  • improvise cross attention to past actions and states of timestep, transformer-xl fashion (w/ structured memory dropout)

  • see if the main idea in this paper is applicable to language models here

Citations

@inproceedings{qtransformer,
    title   = {Q-Transformer: Scalable Offline Reinforcement Learning via Autoregressive Q-Functions},
    author  = {Yevgen Chebotar and Quan Vuong and Alex Irpan and Karol Hausman and Fei Xia and Yao Lu and Aviral Kumar and Tianhe Yu and Alexander Herzog and Karl Pertsch and Keerthana Gopalakrishnan and Julian Ibarz and Ofir Nachum and Sumedh Sontakke and Grecia Salazar and Huong T Tran and Jodilyn Peralta and Clayton Tan and Deeksha Manjunath and Jaspiar Singh and Brianna Zitkovich and Tomas Jackson and Kanishka Rao and Chelsea Finn and Sergey Levine},
    booktitle = {7th Annual Conference on Robot Learning},
    year   = {2023}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}


q-transformer's Issues

memmap can only handle max 2GB on certain systems

When I run the usage example code with the latest release of the q-transformer pip package, I get the following error:

python example1.py 
using memory efficient attention
/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
Traceback (most recent call last):
  File "/home/ram/github/q-transformer/example1.py", line 47, in <module>
    agent = Agent(
  File "<@beartype(q_transformer.agent.Agent.__init__) at 0x7f21fa1563a0>", line 145, in __init__
  File "/home/ram/github/q-transformer/q_transformer/agent.py", line 208, in __init__
    self.states      = open_memmap(str(states_path), dtype = 'float32', mode = 'w+', shape = (*prec_shape, *state_shape))
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/numpy/lib/format.py", line 945, in open_memmap
    marray = numpy.memmap(filename, dtype=dtype, shape=shape, order=order,
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/numpy/core/memmap.py", line 254, in __new__
    fid.seek(bytes - 1, 0)
OSError: [Errno 22] Invalid argument

pip show q-transformer
Name: q-transformer
Version: 0.1.8
Summary: Q-Transformer
Home-page: https://github.com/lucidrains/q-transformer
Author: Phil Wang
Author-email: [email protected]
License: MIT
Location: /home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages
Requires: accelerate, beartype, classifier-free-guidance-pytorch, einops, ema-pytorch, numpy, torch, torchtyping
Required-by: 
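
With the README's default settings the states memmap is enormous, which is why the allocation can fail at fid.seek on systems or filesystems that cannot create files past 2GB. Below is a rough size estimate, assuming prec_shape in agent.py is on the order of (num_episodes, max_num_steps_per_episode); the exact shape may differ by a step.

import numpy as np

num_episodes, max_steps = 1000, 100            # values from the usage example
state_shape = (3, 6, 224, 224)

n_bytes = num_episodes * max_steps * int(np.prod(state_shape)) * np.dtype('float32').itemsize
print(f'{n_bytes / 1e9:.1f} GB')               # ~361 GB for the states memmap alone

Reducing num_episodes, max_num_steps_per_episode, or the state resolution shrinks the file accordingly.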

A simple question about the code

Hi @lucidrains, I'm a beginner trying to use Q-Transformer and I ran into a question while reading the code. In the QHeadMultipleActions class, I noticed that each bin is encoded into an embedding using self.action_bin_embeddings. However, when obtaining the Q-values, the attention output is multiplied with self.action_bin_embeddings once again. Is there a specific reason for deriving the Q-values this way, rather than applying a new MLP layer to the attention output? I've shared the relevant code below. Thank you!

def maybe_append_actions(self, sos_tokens, actions: Optional[Tensor] = None):
    if not exists(actions):
        return sos_tokens

    batch, num_actions = actions.shape
    action_embeddings = self.action_bin_embeddings[:num_actions]

    action_embeddings = repeat(action_embeddings, 'n a d -> b n a d', b = batch)
    past_action_bins = repeat(actions, 'b n -> b n 1 d', d = action_embeddings.shape[-1])

    bin_embeddings = action_embeddings.gather(-2, past_action_bins)
    bin_embeddings = rearrange(bin_embeddings, 'b n 1 d -> b n d')

    tokens, _ = pack((sos_tokens, bin_embeddings), 'b * d')
    tokens = tokens[:, :self.num_actions] # last action bin not needed for the proposed q-learning
    return tokens

def get_q_values(self, embed):
    num_actions = embed.shape[-2]
    action_bin_embeddings = self.action_bin_embeddings[:num_actions]

    if self.dueling:
        advantages = einsum('b n d, n a d -> b n a', embed, action_bin_embeddings)

        values = einsum('b n d, n d -> b n', embed, self.to_values[:num_actions])
        values = rearrange(values, 'b n -> b n 1')

        q_values = values + (advantages - reduce(advantages, '... a -> ... 1', 'mean'))
    else:
        q_values = einsum('b n d, n a d -> b n a', embed, action_bin_embeddings)

    return q_values.sigmoid()
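
For reference, the two options in the question can be written side by side. The first projection mirrors the non-dueling branch of get_q_values above; to_q_values is a hypothetical fresh head, not code from the repo:

import torch
from torch import nn

batch, num_actions, action_bins, dim = 2, 8, 256, 64

embed = torch.randn(batch, num_actions, dim)                 # attention output, one embedding per action slot

# weight-tied head, as in the repo: reuse the bin embeddings as the output projection
action_bin_embeddings = nn.Parameter(torch.randn(num_actions, action_bins, dim))
q_tied = torch.einsum('bnd,nad->bna', embed, action_bin_embeddings).sigmoid()

# alternative raised in the question: an independent linear head over the bins
to_q_values = nn.Linear(dim, action_bins)
q_mlp = to_q_values(embed).sigmoid()                         # (batch, num_actions, action_bins)

Weight tying keeps the readout in the same space as the action-bin embeddings used for the autoregressive conditioning and adds no extra parameters; a separate head would decouple the two. Which works better is an empirical question.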

Running the latest main branch with given usage example

Running the latest main branch with the given usage example results in:

episode 0
99%|█████████████████████████████████████▎| 99/100 [01:13<00:00, 1.35it/s]
episode 1
0%| | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/ram/github/q-transformer/example2.py", line 54, in <module>
    agent()
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ram/github/q-transformer/q_transformer/agent.py", line 255, in forward
    self.text_embeds[episode, step] = text_embed
  File "/home/ram/anaconda3/envs/q-transformer/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'Agent' object has no attribute 'text_embeds'

question about Q-head

Hi,
Thank you for the code, really nice work.

I am new to the transformer architecture, and I am using this code as guidance to implement a simple Q-transformer for a single task (i.e. not language conditioned), using states as observations rather than images.

So I think only the QHeadMultipleActions class is needed in my case. However, do I still need the cross attention layer? There is no language or images in my case.

Thank you
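
If there is no language conditioning and no image encoder, you likely do not need cross attention at all; a single state token can simply be prepended to the action tokens and attended to causally. Below is a minimal, self-contained sketch of a state-only autoregressive Q-head. It is not the repo's QHeadMultipleActions, and every name in it is hypothetical:

import torch
from torch import nn

class StateOnlyQHead(nn.Module):
    def __init__(self, state_dim, num_actions, action_bins, dim = 256, depth = 2, heads = 4):
        super().__init__()
        self.to_state_token = nn.Linear(state_dim, dim)

        # one embedding table per action dimension, one row per discrete bin
        self.action_bin_embeddings = nn.Parameter(torch.randn(num_actions, action_bins, dim) * 0.02)

        layer = nn.TransformerEncoderLayer(dim, heads, dim_feedforward = dim * 4, batch_first = True, norm_first = True)
        self.transformer = nn.TransformerEncoder(layer, depth)
        self.to_q_values = nn.Linear(dim, action_bins)

    def forward(self, state, actions = None):
        # state: (batch, state_dim) - actions: (batch, n) bin indices already chosen, n < num_actions
        tokens = self.to_state_token(state).unsqueeze(1)                      # (b, 1, d)

        if actions is not None:
            n = actions.shape[1]
            picked = self.action_bin_embeddings[torch.arange(n, device = actions.device), actions]   # (b, n, d)
            tokens = torch.cat((tokens, picked), dim = 1)

        # causal mask so each action slot only sees the state and earlier actions
        seq = tokens.shape[1]
        causal = torch.full((seq, seq), float('-inf'), device = tokens.device).triu(1)

        attended = self.transformer(tokens, mask = causal)
        return self.to_q_values(attended).sigmoid()                           # (b, seq, action_bins)

# greedy autoregressive action selection, one action dimension at a time
head = StateOnlyQHead(state_dim = 17, num_actions = 4, action_bins = 256)
state = torch.randn(2, 17)

chosen = None
for i in range(4):
    q = head(state, chosen)                                                   # Q-values for dimension i live at index i
    next_bin = q[:, i].argmax(dim = -1, keepdim = True)
    chosen = next_bin if chosen is None else torch.cat((chosen, next_bin), dim = 1)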

The rest of the code?

Hi, is this the official implementation of the paper 'Q-Transformer: Scalable Offline Reinforcement Learning via Autoregressive Q-Functions'? Could you please upload the rest of the code? I really appreciate the idea in the paper and hope to reproduce it soon.
