RL Project

Reinforcement Learning Term Project for Fall Semester 2022

This project extends the method of the paper Deep Reinforcement Learning for Imbalanced Classification.

The original GitHub repository is here.

We aimed to extend the original method to a Korean multi-class text classification task using a pre-trained BERT model.

We used the KLAID dataset, which is a Korean legal document dataset.
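The original paper treats classification as a sequential decision process: each training sample is a state, the predicted label is the action, and the reward is scaled so that mistakes on rare classes cost more. As we understand the paper's binary setting, a minority-class sample earns +1 or -1 and a majority-class sample earns +λ or -λ, with λ = |minority| / |majority|, and an episode ends when a minority sample is misclassified. The sketch below is a hypothetical generalization of that scheme to a multi-class setting such as KLAID; the function names and the "rarest class ends the episode" rule are illustrative assumptions, not necessarily the exact scheme in this repository's code.

```python
from collections import Counter
from typing import Dict, List, Tuple


def class_reward_weights(labels: List[int]) -> Dict[int, float]:
    """Weight each class by (size of rarest class) / (size of this class).

    With two classes this yields 1.0 for the minority class and
    lambda = |minority| / |majority| for the majority class, i.e. the paper's
    reward scale. The multi-class form is an assumption for illustration.
    """
    counts = Counter(labels)
    n_min = min(counts.values())
    return {c: n_min / n for c, n in counts.items()}


def rarest_class(labels: List[int]) -> int:
    """Hypothetical helper: the label with the fewest samples."""
    counts = Counter(labels)
    return min(counts, key=counts.get)


def step_reward(action: int, label: int, weights: Dict[int, float],
                rarest: int) -> Tuple[float, bool]:
    """Return (reward, episode_done) for classifying a single sample."""
    w = weights[label]
    if action == label:
        return w, False
    # Hypothetical multi-class version of the paper's termination rule:
    # misclassifying a sample of the rarest class ends the episode.
    return -w, label == rarest


if __name__ == "__main__":
    labels = [0] * 8 + [1] * 2  # toy imbalanced label list
    weights = class_reward_weights(labels)
    print(weights)  # {0: 0.25, 1: 1.0}
    rare = rarest_class(labels)
    print(step_reward(0, 0, weights, rare))  # (0.25, False)
    print(step_reward(0, 1, weights, rare))  # (-1.0, True)
```

Scaling rewards by inverse class frequency keeps the rarest class at a reward magnitude of 1 and shrinks the reward for frequent classes, which is what pushes the agent toward the minority classes.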

Differences from the original paper

  • Implemented the original paper's method using PyTorch Lightning, instead of Keras-RL
  • Converted to a multi-class classification task (10 classes), instead of a binary classification task
  • Used a Korean dataset, instead of the original English dataset
  • Used a pre-trained BERT encoder, instead of the network proposed in the original paper
  • Applied not only DQN, but also Double DQN, Dueling DQN, and Policy Gradient
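The DQN variants listed above differ mainly in how the bootstrap target and the Q-values are computed. The plain-Python sketch below shows the vanilla DQN target, the Double DQN target (the online network selects the next action and the target network evaluates it), and the dueling aggregation Q(s, a) = V(s) + A(s, a) - mean_a A(s, a). In the paper's formulation the "next state" is simply the next sample shown to the agent. `gamma` plays the role of the `model.loss.gamma` setting used in the experiments below; the function names are hypothetical, and the actual implementation works on batched PyTorch tensors.

```python
from typing import List, Sequence


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def dqn_target(reward: float, done: bool, gamma: float,
               q_target_next: Sequence[float]) -> float:
    """Vanilla DQN: bootstrap from the target network's best action value."""
    if done:
        return reward
    return reward + gamma * max(q_target_next)


def double_dqn_target(reward: float, done: bool, gamma: float,
                      q_online_next: Sequence[float],
                      q_target_next: Sequence[float]) -> float:
    """Double DQN: the online net picks the action, the target net scores it."""
    if done:
        return reward
    best = argmax(q_online_next)
    return reward + gamma * q_target_next[best]


def dueling_q(value: float, advantages: Sequence[float]) -> List[float]:
    """Dueling head: combine V(s) and mean-centred advantages into Q(s, a)."""
    mean_adv = sum(advantages) / len(advantages)
    return [value + a - mean_adv for a in advantages]


if __name__ == "__main__":
    q_online_next = [0.2, 0.9, 0.1]
    q_target_next = [0.5, 0.25, 0.75]
    print(dqn_target(1.0, False, 0.5, q_target_next))  # 1.375
    print(double_dqn_target(1.0, False, 0.5, q_online_next, q_target_next))  # 1.125
    print(dueling_q(1.0, [1.0, 2.0, 3.0]))  # [0.0, 1.0, 2.0]
```

In the toy numbers, vanilla DQN bootstraps from the target network's maximum (0.75), while Double DQN uses the target network's value for the online network's choice (0.25); decoupling selection from evaluation is what reduces the overestimation bias of the max operator.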

Installation

pip install -r requirements.txt

Training

  • Run vanilla DQN
mkdir checkpoint
python train.py
  • Run Double DQN
python train.py --config-name double_dqn
  • Run Dueling DQN
python train.py --config-name dueling_dqn
  • Run Policy Gradient
python policy_train.py --config-name policy_gradient
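Unlike the DQN variants, policy_train.py learns a policy over classes directly. Its exact objective is not documented in this README; the sketch below assumes a plain REINFORCE update, in which the log-probability of each chosen class is weighted by the discounted return from that step onward. It computes the scalar loss in plain Python; in PyTorch the same expression would be built from tensors so that autograd can differentiate it. All function names are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over class logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def discounted_returns(rewards: Sequence[float], gamma: float) -> List[float]:
    """G_t = r_t + gamma * G_{t+1}, computed backwards over one episode."""
    returns: List[float] = []
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    return returns[::-1]


def reinforce_loss(episode_logits: Sequence[Sequence[float]],
                   actions: Sequence[int], rewards: Sequence[float],
                   gamma: float) -> float:
    """REINFORCE loss: -sum_t G_t * log pi(a_t | s_t), averaged over steps."""
    returns = discounted_returns(rewards, gamma)
    total = 0.0
    for logits, a, g in zip(episode_logits, actions, returns):
        total -= g * log_softmax(logits)[a]
    return total / len(actions)


if __name__ == "__main__":
    print(discounted_returns([1.0, 0.0, 1.0], 0.5))  # [1.25, 0.5, 1.0]
    # Two steps of a toy 2-class episode: logits, chosen classes, rewards.
    print(reinforce_loss([[2.0, 0.0], [0.0, 1.0]], [0, 1], [1.0, -1.0], 0.9))
```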

Commands for Experiment

  • gamma
python train.py --config-name dqn model.loss.gamma=0.1 name=DQN_base_0.1
python train.py --config-name dqn model.loss.gamma=0.5 name=DQN_base_0.5
python train.py --config-name dqn model.loss.gamma=0.9 name=DQN_base_0.9
  • Smooth L1 Loss
python train.py --config-name dqn model.loss.criterion._target_=torch.nn.SmoothL1Loss name=DQN_smooth
python train.py --config-name double_dqn model.loss.criterion._target_=torch.nn.SmoothL1Loss name=double_smooth
python train.py --config-name dueling_dqn model.loss.criterion._target_=torch.nn.SmoothL1Loss name=dueling_smooth
  • Random Seed
python train.py --config-name dqn seed=1111 name=dqn_1111
python train.py --config-name dqn seed=2222 name=dqn_2222
python train.py --config-name dqn seed=3333 name=dqn_3333
python train.py --config-name dqn seed=4444 name=dqn_4444
python train.py --config-name dqn seed=5555 name=dqn_5555
python train.py --config-name dqn seed=666 name=dqn_666
python train.py --config-name dqn seed=777 name=dqn_777 
python train.py --config-name dqn seed=6203 name=dqn_6203 
python train.py --config-name dqn seed=3040 name=dqn_3040 
python train.py --config-name dqn seed=6427 name=dqn_6427
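The Smooth L1 experiments swap the criterion applied to the TD error, i.e. the gap between a predicted Q-value and its bootstrap target. Smooth L1 is quadratic for small errors and linear for large ones, so a single large TD error has less influence than under squared error. The sketch below compares the two on a toy batch of TD errors. It assumes the default criterion is mean squared error (as computed by `torch.nn.MSELoss`), which should be checked against the config files, and it follows the formula of `torch.nn.SmoothL1Loss` with its default `beta=1.0`; the function names are hypothetical.

```python
from typing import Sequence


def mse(errors: Sequence[float]) -> float:
    """Mean squared error over a batch of TD errors."""
    return sum(e * e for e in errors) / len(errors)


def smooth_l1(errors: Sequence[float], beta: float = 1.0) -> float:
    """Mean Smooth L1 loss: 0.5 * e^2 / beta if |e| < beta, else |e| - 0.5 * beta."""
    total = 0.0
    for e in errors:
        a = abs(e)
        total += 0.5 * a * a / beta if a < beta else a - 0.5 * beta
    return total / len(errors)


if __name__ == "__main__":
    td_errors = [0.5, -0.5, 4.0]  # toy batch with one outlier
    print(mse(td_errors))  # 5.5
    print(smooth_l1(td_errors))  # 1.25
```

Almost all of the gap between 5.5 and 1.25 comes from the outlier of 4.0: its squared-error term is 16, while its Smooth L1 term is only 3.5, and its per-element gradient is capped at 1 instead of growing with the error.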

Inference

To test a trained model, download and unzip the checkpoint files from Link, then pass the path of the checkpoint to infer.py as shown below.

  • Infer with vanilla DQN
python infer.py --config-name infer_dqn checkpoint_path={checkpoint_path}
  • Infer with Double DQN
python infer.py --config-name infer_double checkpoint_path={checkpoint_path}
  • Infer with Dueling DQN
python infer.py --config-name infer_dueling checkpoint_path={checkpoint_path}
  • Infer with Policy Gradient
python infer.py --config-name infer_policy checkpoint_path={checkpoint_path}
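At inference time a Q-network classifier is typically used greedily: the predicted class is the action with the highest Q-value. Since the task targets imbalanced classification, plain accuracy can hide poor performance on rare classes, and macro-averaged F1, which weights every class equally, is a common complement. This README does not state which metrics infer.py reports; the sketch below is an illustrative plain-Python version with hypothetical function names, and it scores a class with no true or predicted samples as 0.0, which is one of several possible conventions.

```python
from typing import List, Sequence


def greedy_predictions(q_values: Sequence[Sequence[float]]) -> List[int]:
    """Predicted class = action with the highest Q-value for each sample."""
    return [max(range(len(q)), key=lambda a: q[a]) for q in q_values]


def macro_f1(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> float:
    """Unweighted mean of per-class F1, so every class counts equally."""
    f1s = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom else 0.0)
    return sum(f1s) / num_classes


if __name__ == "__main__":
    q = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]  # toy Q-values for 3 samples
    preds = greedy_predictions(q)
    print(preds)  # [1, 0, 1]
    print(macro_f1([1, 0, 0], preds, num_classes=2))
```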

References

PyTorch Lightning DQN tutorial

Contributors

jh-debug, lakahaga

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.