
diffses's Introduction

Symbolic Visual Reinforcement Learning: A Scalable Framework with Object-Level Abstraction and Differentiable Expression Search

Wenqing Zheng*, S P Sharan*, Zhiwen Fan, Kevin Wang, Yihan Xi, Atlas Wang

Website: https://vita-group.github.io/DiffSES/ | arXiv: https://arxiv.org/abs/2212.14849

Introduction

Learning efficient and interpretable policies has been a challenging task in reinforcement learning (RL), particularly in the visual RL setting with complex scenes. While deep neural networks have achieved competitive performance, the resulting policies are often over-parameterized black boxes that are difficult to interpret and deploy efficiently. More recent symbolic RL frameworks have shown that high-level domain-specific programming logic can be designed to handle both policy learning and symbolic planning. However, these approaches often rely on human-coded primitives with little feature learning, and when applied to high-dimensional continuous spaces such as visual scenes, they can suffer from scalability issues and perform poorly when images have complicated compositions and object interactions. To address these challenges, we propose Differentiable Symbolic Expression Search (DiffSES), a novel symbolic learning approach that discovers discrete symbolic policies using partially differentiable optimization. By using object-level abstractions instead of raw pixel-level inputs, DiffSES is able to leverage the simplicity and scalability advantages of symbolic expressions, while also incorporating the strengths of neural networks for feature learning and optimization. Our experiments demonstrate that DiffSES is able to generate symbolic policies that are more interpretable and scalable than state-of-the-art symbolic RL methods, even with a reduced amount of symbolic prior knowledge.

Inference procedure of the learned symbolic policy

Results

A subset of our trained environments

Here is a comparison of the models in a transfer learning setting. The teacher DRL model is trained on AdventureIsland3, and the symbolic agent is learned from it. Both agents are then applied to AdventureIsland2 without fine-tuning. The performance of the symbolic policy drops less than that of the DRL model.

video-3-compressed.mp4
Symbolic policies enable ease of transfer owing to the disentanglement of control policies and feature extraction steps.
Visualization of a trained DiffSES policy

Usage

Stage I - Neural Policy Learning

Training a Visual RL agent as a teacher

  • Stable-Baselines3 for PPO training

Run train_visual_rl.py with the appropriate environment selected. Refer to the RL Baselines3 Zoo for additional configuration options. Use the wrappers provided in retro_utils.py when running on Retro environments with multiple lives and stages.

The trained model will be saved in the logs/ folder (along with TensorBoard logs).
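
For concreteness, here is a minimal sketch of the kind of PPO teacher training train_visual_rl.py performs with Stable-Baselines3; the environment id, timestep budget, and save path are assumptions for illustration, not the repo's exact settings.

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

# Minimal PPO teacher-training sketch (assumed settings, not the repo defaults).
env = VecFrameStack(make_atari_env("SpaceInvadersNoFrameskip-v4", n_envs=8, seed=0), n_stack=4)
model = PPO("CnnPolicy", env, verbose=1, tensorboard_log="logs/")
model.learn(total_timesteps=1_000_000)
model.save("logs/ppo_teacher")  # this checkpoint is the teacher distilled in Stage II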

Stage II - Symbolic Fitting

Distillation of Teacher agent into Symbolic Student

  • gplearn on an offline dataset of the teacher's actions

Part A: Training a self-supervised object detector

Training images for multiple Atari environments can be found here. If you would like to run on custom or other environments, consider generating them using the provided script save_frames.py. We then train the OD module using these frames.
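
As a rough illustration of what such a frame-collection script might do (the environment id, frame count, and output directory are assumptions, and the classic Gym API is assumed):

import os
import gym
from PIL import Image

env = gym.make("SpaceInvadersNoFrameskip-v4")  # assumed environment
out_dir = "frames/spaceinvaders"
os.makedirs(out_dir, exist_ok=True)

obs = env.reset()  # classic Gym (<0.26) reset/step API assumed
for i in range(10_000):
    Image.fromarray(obs).save(f"{out_dir}/{i:06d}.png")  # raw RGB frame for the OD module
    obs, _, done, _ = env.step(env.action_space.sample())  # random exploration policy
    if done:
        obs = env.reset()
env.close()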

For more training parameters, consider referring to the scripts and the SPACE project's documentation.

cd space/
python main.py --task train --config configs/atari_spaceinvaders.yaml resume True device 'cuda:0'

This should generate weights in the space/output/logs folder. Pretrained models from SPACE are available here.

Part B: Generating the offline dataset

Save the teacher model's behavior (state-action pairs), with the OD module processing all such states. This creates a JSON file of the following form; sample.json contains a dummy dataset for demonstration purposes.

[
  {
    "state": [
      {
        "type": int,
        "x_velocity": float,
        "y_velocity": float,
        "x_position": int,
        "y_position": int
      },
      {
        "type": int,
        "x_velocity": float,
        ...
      }
      ...
    ],
    "teacher_action": int
  }
  ...
]
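
As a sketch of how such a file could be assembled (the teacher handle, object-detection wrapper, and rollout of frames are hypothetical placeholders, not names from the repo):

import json

def save_offline_dataset(frames, teacher, detect_objects, out_path="offline_dataset.json"):
    """Hedged sketch: `frames` is a rollout of observations, `teacher` a trained
    Stable-Baselines3 agent, and `detect_objects` a hypothetical wrapper around
    the SPACE OD module that yields per-object (type, vx, vy, x, y) tuples."""
    dataset = []
    for frame in frames:
        state = [
            {
                "type": int(t),
                "x_velocity": float(vx),
                "y_velocity": float(vy),
                "x_position": int(x),
                "y_position": int(y),
            }
            for (t, vx, vy, x, y) in detect_objects(frame)
        ]
        action, _ = teacher.predict(frame, deterministic=True)  # SB3 predict API
        dataset.append({"state": state, "teacher_action": int(action)})
    with open(out_path, "w") as f:
        json.dump(dataset, f)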

Part C: Symbolic distillation

We use gplearn's symbolic regression API in distill_teacher.py to train a symbolic tree that mimics the teacher's actions. The operators are defined in that file and can easily be extended through gplearn's simple API; see see/judges.py for a few sample operator implementations. The operands are the state features stored in the JSON. We recommend running this experiment several times to achieve good performance, since convergence of such a random search is not guaranteed every time. Please refer to gplearn_optuna.py for a sample of automating such a search on random data.
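
A minimal sketch of this step, assuming the dataset has already been flattened into a feature matrix X and a teacher-action vector y; the gt operator below only illustrates how custom operators are added and is not the exact set used in distill_teacher.py:

import numpy as np
from gplearn.genetic import SymbolicRegressor
from gplearn.functions import make_function

# Custom gplearn operators must return a NumPy float array with the same shape
# as their inputs; this element-wise "greater than" is an illustrative example.
def _gt(x1, x2):
    return np.where(x1 > x2, 1.0, 0.0)

gt = make_function(function=_gt, name="gt", arity=2)

def distill(X, y):
    est = SymbolicRegressor(
        population_size=2000,          # assumed hyperparameters
        generations=20,
        function_set=("add", "sub", "mul", "div", gt),
        parsimony_coefficient=0.001,
        random_state=0,
        verbose=1,
    )
    est.fit(X, y)                      # fit the tree to mimic teacher actions
    print(est._program)                # the distilled symbolic expression
    return est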

Stage III - Fine-tuning Symbolic Tree

Neural Guided Differentiable Search

Lastly, our symbolic fine-tuning stage consists of symbolic_finetuning.py, which uses a custom implementation of gplearn, modified to support the following:

  • RL-style training: rewards as the fitness metric rather than MSE with respect to teacher behavior (a sketch follows below).
  • Differentiable constant optimization: a new mutation scheme in which the constants are made differentiable, the tree acts as the policy network of a PPO agent, and those constants are optimized.
  • Soft expert supervision in the loss: an add-on to the previous bullet that adds an extra loss term, namely the difference between the teacher's action and the symbolic tree's prediction.

Before running that file, please run pip install -e . inside the custom implementation of gplearn to install the local version instead of the prebuilt wheel from PyPI. As with Part C of Stage II, we recommend running this experiment several times to achieve acceptable convergence.
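
As a rough sketch of the reward-as-fitness idea from the first bullet above (the policy and feature-extraction hooks are hypothetical placeholders, and the classic Gym API is assumed):

import gym

def reward_fitness(policy_fn, extract_state, env_id="SpaceInvadersNoFrameskip-v4", episodes=3):
    """Score a candidate symbolic tree by its average episodic return instead of
    its MSE against the teacher. `policy_fn` maps object-level features to a
    discrete action; `extract_state` stands for running the OD module on a frame."""
    env = gym.make(env_id)
    total = 0.0
    for _ in range(episodes):
        obs, done = env.reset(), False
        while not done:
            action = int(policy_fn(extract_state(obs)))
            obs, reward, done, _ = env.step(action)  # classic Gym (<0.26) API
            total += reward
    env.close()
    return total / episodes  # higher return -> fitter individual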

Citation

If you find our code implementation helpful for your own research or work, please cite our paper.

@article{zheng2022symbolic,
  title={Symbolic Visual Reinforcement Learning: A Scalable Framework with Object-Level Abstraction and Differentiable Expression Search},
  author={Zheng, Wenqing and Sharan, SP and Fan, Zhiwen and Wang, Kevin and Xi, Yihan and Wang, Zhangyang},
  journal={arXiv preprint arXiv:2212.14849},
  year={2022}
}

Contact

For any queries, please raise an issue or contact Wenqing Zheng.

License

This project is open-sourced under the MIT License.


diffses's Issues

AttributeError: module 'numpy' has no attribute 'int'.

When I run gplearn_optuna.py, I get:

AttributeError: module 'numpy' has no attribute 'int'.
np.int was a deprecated alias for the builtin int. To avoid this error in existing code, use int by itself. Doing this will not modify any behavior and is safe. When replacing np.int, you may wish to use e.g. np.int64 or np.int32 to specify the precision. If you wish to review your current use, check the release note link for additional information.
The alias was originally deprecated in NumPy 1.20; for more details and guidance see the original release note at:
https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
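
A minimal illustration of the replacement the deprecation note suggests (the exact offending line in gplearn_optuna.py may differ):

import numpy as np

values = np.asarray([1.9, 2.1])
ints = values.astype(int)         # instead of the removed values.astype(np.int)
ints64 = values.astype(np.int64)  # or pick an explicit precision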

Missing code for extracting position and velocity using OD module

This is a wonderful code repository with many impressive capabilities. I sincerely appreciate you sharing it.

When testing the code, I noticed that the code for using the OD module to extract velocity and position in save_offline_dataset.py and symbolic_finetuning_guided.py does not seem to be included; the same applies to the "save teacher dataset" code. If it's not too much trouble, adding that code would significantly improve the functionality. Please let me know if you would be willing to add it at your convenience.

Thank you again for providing this great resource. Your work is incredibly valuable to the community. I look forward to seeing future updates!

No module named 'space'

When I run python main.py --task train --config configs/atari_spaceinvaders.yaml resume True device 'cuda:0', the following error occurs:

Traceback (most recent call last):
  File "main.py", line 1, in <module>
    from engine.utils import get_config
  File "/home/l/Downloads/DiffSES/space/engine/utils.py", line 4, in <module>
    from space.config import cfg
ModuleNotFoundError: No module named 'space'

ValueError: supplied function logical does not return a numpy array.

When I run distill_teacher.py, I get:

Traceback (most recent call last):
  File "/home/l/Downloads/DiffSES/distill_teacher.py", line 63, in <module>
    logical = make_function(function=_logical, name="logical", arity=2)
  File "/home/l/Downloads/DiffSES/gplearn/functions.py", line 99, in make_function
    raise ValueError("supplied function %s does not return a numpy array." % name)
ValueError: supplied function logical does not return a numpy array.
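
gplearn's make_function validates that a custom operator returns a NumPy array of the same shape as its inputs; a conforming _logical could look like the sketch below (the intended semantics in distill_teacher.py may differ):

import numpy as np

def _logical(x1, x2):
    # Element-wise AND returned as a float array, which satisfies the check.
    return np.where(np.logical_and(x1 > 0.5, x2 > 0.5), 1.0, 0.0)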
