
Language-guided Skill Learning with Temporal Variational Inference (LAST)

Published at ICML 2024: Paper | Website

We present an algorithm for discovering skills from expert demonstrations. It first uses Large Language Models (LLMs) to propose an initial segmentation of each trajectory. A hierarchical variational inference framework then incorporates this LLM-generated segmentation and discovers reusable skills by merging trajectory segments. To control the trade-off between compression and reusability, we introduce a novel auxiliary objective based on the Minimum Description Length (MDL) principle that guides this skill-discovery process.
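To make the compression/reusability trade-off concrete, here is a toy two-part description-length calculation. It is not the paper's objective: the cost model (a fixed, hypothetical number of bits per distinct skill in the codebook, plus the optimal code length of the skill-label sequence under its empirical frequencies), the helper `description_length`, and the skill names are all assumptions made for illustration.

```python
import math
from collections import Counter


def description_length(skill_labels, bits_per_skill=16.0):
    """Toy two-part MDL cost, in bits, of a trajectory labelled with skills.

    Model cost: every distinct skill in the codebook costs `bits_per_skill`
    (a hypothetical constant). Data cost: the label sequence encoded with an
    optimal code under the labels' empirical frequencies.
    """
    counts = Counter(skill_labels)
    n = len(skill_labels)
    model_bits = bits_per_skill * len(counts)
    data_bits = -sum(c * math.log2(c / n) for c in counts.values())
    return model_bits + data_bits


# Every LLM-proposed segment treated as its own skill: no reuse.
fine = ["pick_apple", "goto_counter", "pick_mug", "goto_sink", "pick_cup", "goto_table"]
# Segments mapped onto two shared, reusable skills.
merged = ["pick", "goto", "pick", "goto", "pick", "goto"]
print(description_length(fine), description_length(merged))
```

The shared-skill labelling is cheaper here because it needs fewer codebook entries. A real model would also pay a reconstruction cost when skills become too coarse, which this toy deliberately leaves out.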

TODO: online hierarchical RL with the learned skills.

Setup

Clone the repository:

$ git clone https://github.com/Minusadd/LAST.git LAST

Create a Python 3.9 virtual environment and install the requirements:

$ virtualenv -p $(which python3.9) last
$ source last/bin/activate
$ cd LAST
$ pip install --upgrade pip
$ pip install -r requirements.txt
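The setup expects Python 3.9 inside the `last` virtualenv. If you want to confirm that before running anything else, a small check like the one below may help. The `check_environment` helper is our own illustration, not part of the repository; it uses the standard idiom that an active virtual environment makes `sys.prefix` differ from `sys.base_prefix` (older virtualenv releases set `sys.real_prefix` instead).

```python
import sys


def check_environment(required=(3, 9)):
    """Return a list of problems with the current interpreter (empty if none)."""
    problems = []
    if tuple(sys.version_info[:2]) != tuple(required):
        problems.append(
            f"expected Python {required[0]}.{required[1]}, "
            f"found {sys.version_info.major}.{sys.version_info.minor}"
        )
    # virtualenv >= 20 and venv: sys.prefix != sys.base_prefix.
    # Legacy virtualenv: sys.real_prefix is defined.
    in_venv = hasattr(sys, "real_prefix") or sys.prefix != sys.base_prefix
    if not in_venv:
        problems.append("no virtual environment is active (run `source last/bin/activate`)")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print("warning:", problem)
```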

Downloading data and checkpoints

Install ALFRED and download the dataset.

$ git clone https://github.com/askforalfred/alfred.git alfred
$ cd alfred/data
$ sh download_data.sh json_feat
$ cd ../..
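After the download finishes, you can sanity-check that the trajectories landed where the later commands expect them (`./alfred/data/json_2.1.0`, relative to the `LAST` directory). This sketch assumes the usual ALFRED layout `<split>/<task>/<trial>/traj_data.json`; `count_trajectories` is a hypothetical helper, so adjust the glob pattern if your copy is organised differently.

```python
from pathlib import Path


def count_trajectories(data_dir="alfred/data/json_2.1.0"):
    """Count traj_data.json files in each split directory under `data_dir`.

    Assumes the layout <split>/<task>/<trial>/traj_data.json.
    """
    root = Path(data_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"{root} not found; did download_data.sh finish?")
    counts = {}
    for split_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[split_dir.name] = sum(1 for _ in split_dir.glob("*/*/traj_data.json"))
    return counts
```

Calling `count_trajectories()` from the `LAST` directory returns a dictionary mapping each split name to its number of trajectories. A split with a count of zero usually means the archive was not fully extracted.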

(Optional) Download the preprocessed features and LLM-generated data from Google Drive.

LLM-generated initial segmentation

Set up the OpenAI API key:

$ export OPENAI_API_KEY='your api key'
$ export OPENAI_API_BASE='your api base'
$ export OPENAI_API_TYPE='your api type'
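A missing variable tends to surface only after the first API request fails. If you write your own tooling around these variables, a check like the following fails early instead. `load_openai_settings` is a hypothetical helper; the README does not say how `alfred_steps.py` reads these variables, and which values you need depends on your API provider.

```python
import os

REQUIRED_VARS = ("OPENAI_API_KEY", "OPENAI_API_BASE", "OPENAI_API_TYPE")


def load_openai_settings(env=None):
    """Return the three exported settings, raising if any is unset or empty."""
    env = os.environ if env is None else env
    missing = [name for name in REQUIRED_VARS if not env.get(name)]
    if missing:
        raise RuntimeError("missing environment variables: " + ", ".join(missing))
    return {name: env[name] for name in REQUIRED_VARS}
```

Passing an explicit mapping, rather than always reading `os.environ`, keeps the check testable without touching the real environment.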

Generate trajectory data using GPT-4:

$ mkdir data_gpt4
$ python alfred_steps.py --data_dir ./alfred/data/json_2.1.0 --output_dir data_gpt4/ --n_workers 4

Note 1: Because of JPEG compression, the .jpeg images in the full ALFRED dataset differ from the frames rendered during evaluation, so we generated images for all trajectories ourselves. We are still working out how to share them, but you can generate them yourself with the code provided in ET.

Note 2: You can download the GPT-4-generated dataset we used from Google Drive and skip this step.
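The `--n_workers 4` flag suggests that trajectories are processed in parallel. The README does not describe how, so the following is only a sketch of the general pattern. It uses threads, which suit I/O-bound API calls, though the real script may use processes. `segment_trajectory` is a hypothetical placeholder for the per-trajectory LLM request.

```python
from concurrent.futures import ThreadPoolExecutor
from pathlib import PurePosixPath


def segment_trajectory(traj_path):
    # Placeholder for the per-trajectory LLM call; here it just returns
    # the task directory name from <split>/<task>/<trial>/traj_data.json.
    return PurePosixPath(traj_path).parent.parent.name


def run_parallel(traj_paths, n_workers=4):
    # executor.map returns results in input order, whatever order the calls finish in.
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        return list(executor.map(segment_trajectory, traj_paths))


paths = [
    "json_2.1.0/train/task_a/trial_1/traj_data.json",
    "json_2.1.0/train/task_b/trial_2/traj_data.json",
]
print(run_parallel(paths, n_workers=2))
```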

Preprocess the data

Process the image and language data using the initial segmentation results:

$ python process_data.py

Note: You can skip this step by downloading the processed data we used from Google Drive: FasterRCNN, MaskRCNN, Image features, Language features, Goal features, Masks, Action sequences, Switching points, and Processed trajectory data (GPT-4). Put all the downloaded files into the data/ folder, which you can create with:

$ mkdir data
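The processed data includes "switching points", but the README does not document their format. One natural representation, assumed here purely for illustration, is a per-time-step binary mask that marks where each segment starts. `switching_points` and `segments_from_mask` are hypothetical helpers showing how segment lengths and such a mask convert into each other.

```python
def switching_points(segment_lengths):
    """Binary mask over time steps: 1 where a new segment starts, else 0.

    `segment_lengths` gives the number of low-level actions in each segment,
    e.g. from an initial LLM segmentation.
    """
    mask = []
    for length in segment_lengths:
        if length <= 0:
            raise ValueError("segment lengths must be positive")
        mask.extend([1] + [0] * (length - 1))
    return mask


def segments_from_mask(mask):
    """Recover (start, end) index pairs, end exclusive. Assumes mask[0] == 1."""
    starts = [t for t, flag in enumerate(mask) if flag]
    ends = starts[1:] + [len(mask)]
    return list(zip(starts, ends))


print(switching_points([3, 1, 2]))  # [1, 0, 0, 1, 1, 0]
```

Under this representation, merging two adjacent segments amounts to clearing the 1 at the start of the second one.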

Skill discovery with temporal variational inference

Train a LAST agent:

$ python algorithm.py --name train_last --train 1 --include_goal 1 --ent_weight 0.1 --kl_weight 0.0001
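The `--ent_weight` and `--kl_weight` flags weight auxiliary terms in the training objective. The README does not give the exact objective, so the following sketch shows only one common way such weights combine with a reconstruction loss: an entropy bonus on the skill posterior and a KL penalty toward a prior. `weighted_loss` and the distributions are hypothetical, and the loss in `algorithm.py` may differ in form and sign.

```python
import math


def entropy(probs):
    # Shannon entropy, in nats, of a categorical distribution.
    return -sum(p * math.log(p) for p in probs if p > 0)


def kl_divergence(q, p):
    # KL(q || p), in nats, for two categorical distributions over the same support.
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)


def weighted_loss(nll, q_skill, prior_skill, ent_weight=0.1, kl_weight=0.0001):
    """Hypothetical objective: reconstruction NLL, minus a weighted entropy
    bonus, plus a weighted KL penalty. The defaults mirror the flag values above."""
    return nll - ent_weight * entropy(q_skill) + kl_weight * kl_divergence(q_skill, prior_skill)


q = [0.7, 0.2, 0.1]
uniform = [1 / 3] * 3
print(weighted_loss(nll=2.5, q_skill=q, prior_skill=uniform))
```

With `kl_weight` three orders of magnitude smaller than `ent_weight`, the KL term only lightly regularises the skill posterior in this sketch.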

Evaluate the agent on the dataset:

First, download our checkpoint and place it in saved_nets/, then run:

$ python algorithm.py --name test_last --train 0 --include_goal 1 --ent_weight 0.1 --kl_weight 0.0001 --model saved_nets/Model_epoch70
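The training and evaluation commands share `--include_goal`, `--ent_weight` and `--kl_weight`, and presumably need matching values for the checkpoint to load correctly. One way to keep them in sync is to generate both commands from one place. The wrapper below is our own convenience sketch: it only builds the argument list, which you could then pass to `subprocess.run(argv, check=True)`.

```python
import shlex

# Flag values taken from the commands above; evaluation reuses the training values.
SHARED_FLAGS = {"include_goal": 1, "ent_weight": 0.1, "kl_weight": 0.0001}


def build_command(name, train, model=None, flags=None):
    """Return the argv list for algorithm.py with consistent shared flags."""
    flags = SHARED_FLAGS if flags is None else flags
    argv = ["python", "algorithm.py", "--name", name, "--train", str(int(train))]
    for key, value in flags.items():
        argv += [f"--{key}", str(value)]
    if model is not None:
        argv += ["--model", model]
    return argv


print(shlex.join(build_command("test_last", train=False, model="saved_nets/Model_epoch70")))
```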

Citation

If you find this repository useful, please cite our work:

@inproceedings{fu2024languageskill,
  title     = {Language-guided Skill Learning with Temporal Variational Inference},
  author    = {Haotian Fu and Pratyusha Sharma and Elias Stengel-Eskin and George Konidaris and Nicolas Le Roux and Marc-Alexandre Côté and Xingdi Yuan},
  booktitle = {ICML},
  year      = {2024},
}
