
rl4lms's Introduction

🤖 RL4LMs 🚀

A modular RL library to fine-tune language models to human preferences


We provide easily customizable building blocks for training language models, including implementations of on-policy algorithms, reward functions, metrics, datasets, and LM-based actor-critic policies.

Paper Link: https://arxiv.org/abs/2210.01241

Website Link: https://rl4lms.apps.allenai.org/

Thoroughly tested and benchmarked with over 2000 experiments 🔥 (GRUE benchmark 🏆) on a comprehensive set of:

  • 7 different Natural Language Processing (NLP) Tasks:
    • Summarization
    • Generative Commonsense Reasoning
    • IMDB Sentiment-based Text Continuation
    • Table-to-text generation
    • Abstractive Question Answering
    • Machine Translation
    • Dialogue Generation
  • 20+ NLG metrics of different types that can be used as reward functions:
    • Lexical metrics (e.g. ROUGE, BLEU, SacreBLEU, METEOR)
    • Semantic metrics (e.g. BERTScore, BLEURT)
    • Task-specific metrics (e.g. PARENT, CIDEr, SPICE)
    • Scores from pre-trained classifiers (e.g. sentiment scores)
  • On-policy algorithms: PPO, A2C, TRPO, and the novel NLPO (Natural Language Policy Optimization)
  • Actor-critic policies supporting causal LMs (e.g. GPT-2/3) and seq2seq LMs (e.g. T5, BART)

All of these building blocks are customizable, allowing users to train transformer-based LMs to optimize any arbitrary reward function on any dataset of their choice.

Recent updates (v0.2.0) on 23-Nov-22

  • Added daily dialog task
  • Fixed compatibility issues with some seq2seq models such as BART, BlenderBot, etc.
  • Implemented data parallel support
  • Refactored policy classes

Recent updates (v0.2.1)

  • Minor logging updates

Install

Local Installation

git clone https://github.com/allenai/RL4LMs.git
cd RL4LMs
pip install -e .

Docker

We also provide a Dockerfile for development using Docker containers with all the dependencies.

docker build . -t rl4lms

Additional dependencies

Optionally, CoreNLP libraries are required for certain metric computations (e.g. SPICE). They can be downloaded with:

cd rl4lms/envs/text_generation/caption_metrics/spice && bash get_stanford_models.sh


Quick Start - Train PPO/NLPO using pre-defined YAML configs

We provide a simple training API, invoked via a train script, that trains PPO, NLPO, or a supervised model from a YAML config file.

For example, to train T5-base on CNN/DM summarization with PPO using ROUGE-1 as the reward function, run:

python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/summarization/t5_ppo.yml

Config files for all tasks can be found in scripts/training/task_configs.

YAML file schema - Configuring building blocks

The config file contains hyper-parameter settings for each of the building blocks described below:

  • Dataset/Task: a dataset containing samples with input prompts and reference sentences. Available datasets are listed in the class DataPoolRegistry in the registry. (See Adding dataset below for how to create your own dataset.)

    datapool:
      id: cnn_daily_mail
      args:
        prompt_prefix: "Summarize: "
  • Tokenizer: a pre-trained tokenizer used to (de)tokenize input and output sequences, with settings for padding and truncation.

    tokenizer:
      model_name: t5-base
      padding_side: left
      truncation_side: left
      pad_token_as_eos_token: False
  • Reward Function: a reward function that computes token-level scores at each time step of the MDP. Available reward functions are listed in the class RewardFunctionRegistry. (See Adding reward function below for how to create your own.)

    reward_fn:
      id: rouge
      args:
        rouge_type: "rouge1"
  • Environment: configures a gym-style text generation environment that simulates MDP episodes. Rollouts are generated using train samples from the dataset, consisting of input and reference texts. We wrap our env with SubprocVecEnv from stable-baselines3, which processes n_envs episodes in parallel using multiprocessing to compute step-wise rewards.
    Further configuration settings include:

    • max_episode_length: maximum length of the episode
    • max_prompt_length: maximum length of the input text to consider
    • terminate_on_eos: whether to terminate the episode as soon as the EOS action is performed
    • prompt_truncation_side: truncation side for the prompt text
    • context_start_token: id of the context token (corresponds to the initial token given to the decoder in encoder-decoder models)

    env:
      n_envs: 10
      args:
        max_prompt_length: 512
        max_episode_length: 100
        terminate_on_eos: True
        prompt_truncation_side: "right"
        context_start_token: 0
  • On-policy alg: we provide implementations of 4 on-policy algorithms (PPO, NLPO, A2C, and TRPO) adapted from stable-baselines3 and tailored to NLP tasks. They can be used out-of-the-box with either a causal LM policy or a seq2seq LM policy. (See Adding custom on-policy algorithms and Adding custom policies below.)

    • We also provide a supervised trainer for benchmarking purposes. Supervised warm-start models are already uploaded to the Hugging Face Hub and specified in the respective config files.

    • Hyper-parameters for the algorithm can be specified at alg/args.

    • Further, all RL algorithms use an adaptive KL controller to keep the LM close to the original LM. It is configured by setting the initial KL coefficient (alg/kl_div/coeff) and the target KL (alg/kl_div/target_kl).

    • We support two types of LM policy: a causal LM policy (for decoder-only models) and a seq2seq LM policy (for encoder-decoder models). For NLPO, we also provide maskable variants of these. A policy is attached to an algorithm by specifying alg/policy/id and alg/policy/args.

      alg:
        id: ppo
        args: 
          n_steps: 512
          batch_size: 64
          verbose: 1
          learning_rate: 0.000002
          n_epochs: 5
          ent_coef: 0.0
        kl_div:
          coeff: 0.001
          target_kl: 0.2
        policy:
          id: seq2seq_lm_actor_critic_policy
          args:
            model_name: t5-base
            apply_model_parallel: True
            prompt_truncation_side: "right"
            generation_kwargs:
              do_sample: True
              top_k: 50
              min_length: 50
              max_new_tokens: 100          
  • Trainer Config: we provide an on-policy trainer, a feature-complete wrapper that instantiates building blocks from their corresponding configs and provides an outer training loop of train and eval iterations (train_evaluation/n_iters).

    • Each iteration collects alg/args/n_steps x env/n_envs rollout steps and performs updates of the chosen algorithm on them.
    • Every eval_every iterations, the LM is evaluated on the validation split using the metrics listed in train_evaluation/metrics, with the generation kwargs provided in train_evaluation/generation_kwargs (these override the rollout kwargs in alg/policy/generation_kwargs, for inference purposes only).
    # train and evaluation
    train_evaluation:
      eval_batch_size: 100
      n_iters: 100
      eval_every: 10
      save_every: 1
      metrics:
        - id: meteor
          args: {}
        - id: rouge
        - id: bleu
          args: {}
        - id: bert_score
          args:
            language: en
        - id: diversity
          args: {}
      generation_kwargs: 
        do_sample: True
        top_k: 0
        temperature: 0.7
        min_length: 50
        max_new_tokens: 100
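As a worked example of the iteration arithmetic described above, the sketch below plugs in the example values (n_steps: 512, n_envs: 10, batch_size: 64, n_epochs: 5, n_iters: 100). It assumes stable-baselines3-style semantics: n_steps is counted per environment, each step generates one token, and every epoch is one pass over the rollout buffer in minibatches of batch_size. The helper name rollout_budget is hypothetical.

```python
def rollout_budget(n_steps: int, n_envs: int, batch_size: int,
                   n_epochs: int, n_iters: int) -> dict:
    """Hypothetical helper: per-iteration and total rollout sizes for a config."""
    # rollout steps (generated tokens) collected per iteration across all envs
    steps_per_iter = n_steps * n_envs
    # minibatches per epoch (ceiling division: the last minibatch may be smaller)
    minibatches_per_epoch = -(-steps_per_iter // batch_size)
    return {
        "steps_per_iter": steps_per_iter,
        "grad_steps_per_iter": n_epochs * minibatches_per_epoch,
        "total_steps": steps_per_iter * n_iters,
    }


budget = rollout_budget(n_steps=512, n_envs=10, batch_size=64, n_epochs=5, n_iters=100)
print(budget)
# {'steps_per_iter': 5120, 'grad_steps_per_iter': 400, 'total_steps': 512000}
```

Raising n_envs therefore enlarges each iteration's rollout buffer proportionally. It does not add separate losses.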
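The adaptive KL controller can be sketched as a proportional controller in the style of Ziegler et al. (2019). Under this assumption, the coefficient scales a KL penalty between the current and original LM that is subtracted from the task reward. RL4LMs's own controller may use different constants or update details, and the names AdaptiveKLController and penalized_reward are hypothetical. The example uses coeff: 0.001 and target_kl: 0.2 from the config above:

```python
class AdaptiveKLController:
    """Hypothetical proportional KL controller (Ziegler et al., 2019 style)."""

    def __init__(self, init_coeff: float, target_kl: float, gain: float = 0.1) -> None:
        self.coeff = init_coeff
        self.target_kl = target_kl
        self.gain = gain

    def update(self, observed_kl: float) -> float:
        # relative deviation of the observed KL from the target, clipped to [-0.2, 0.2]
        error = (observed_kl - self.target_kl) / self.target_kl
        error = max(-0.2, min(0.2, error))
        # raise the penalty when KL overshoots the target, lower it when KL undershoots
        self.coeff *= 1.0 + self.gain * error
        return self.coeff


def penalized_reward(task_reward: float, kl: float, coeff: float) -> float:
    # KL-penalized reward: task reward minus coeff times the KL estimate
    return task_reward - coeff * kl


controller = AdaptiveKLController(init_coeff=0.001, target_kl=0.2)
controller.update(observed_kl=0.4)  # KL too high, so coeff grows to about 0.00102
```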

Custom Building Blocks 🔧

RL4LMs provides complete customizability with respect to adding new tasks/datasets, reward functions, evaluation metrics, on-policy algorithms, and actor-critic policies.

Adding dataset

Users can create their own datasets by sub-classing TextGenPool and overriding the method prepare(cls, split: str, **args) -> 'TextGenPool' to return an instance of TextGenPool. An example is shown below:

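from rl4lms.data_pools.text_generation_pool import Sample, TextGenPool


class MyDataPool(TextGenPool):
    @classmethod
    def prepare(cls, split: str, **args) -> "TextGenPool":
        # placeholder: load the raw items for `split` here, e.g. a list of
        # dicts with "document" and "target" keys
        items = ...
        samples = []
        for ix, item in enumerate(items):
            # one Sample per item: an id, the input text, and the list of references
            sample = Sample(id=f"{split}_{ix}",
                            prompt_or_input_text=item["document"],
                            references=[item["target"]]
                            )
            samples.append(sample)
        pool_instance = cls(samples)
        return pool_instance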

Adding reward function

Custom reward functions can be implemented by sub-classing RewardFunction (a callable). It takes the previous observation ($s$), the action ($a$), the current observation ($s'$), done (indicating whether the episode is finished), and meta info (containing other information about the textual input). Here, Observation is a data class object consisting of the generated text (at a particular step), the prompt text, the context text (at that step), and the reference text. These can be used to compute token-level or sentence-level rewards.

from typing import Any, Dict

from rl4lms.envs.text_generation.observation import Observation
from rl4lms.envs.text_generation.reward import RewardFunction


class MyRewardFunction(RewardFunction):
    def __init__(self, *args) -> None:
        super().__init__()

    def __call__(self, prev_observation: Observation,
                 action: int,
                 current_observation: Observation,
                 done: bool,
                 meta_info: Dict[str, Any] = None) -> float:
        if done:
            # placeholder: compute a sentence-level reward from the finished episode
            reward = ...
            return reward
        # no intermediate reward before the episode ends
        return 0.0
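As one illustration of what the `if done:` branch could compute, here is a crude ROUGE-1-style unigram F1 between the generated text and its references, written as a plain function. It is a hypothetical helper and not part of RL4LMs, which already ships a rouge reward. You would still need to pull the generated and reference texts out of current_observation; its field names are not shown here.

```python
from collections import Counter
from typing import List


def unigram_f1_reward(generated_text: str, references: List[str]) -> float:
    """Hypothetical helper: best unigram F1 between the generation and any reference."""
    gen_counts = Counter(generated_text.lower().split())
    gen_total = sum(gen_counts.values())
    best = 0.0
    for reference in references:
        ref_counts = Counter(reference.lower().split())
        # clipped unigram matches: multiset intersection of the two token bags
        overlap = sum((gen_counts & ref_counts).values())
        if overlap == 0:
            continue
        precision = overlap / gen_total
        recall = overlap / sum(ref_counts.values())
        best = max(best, 2 * precision * recall / (precision + recall))
    return best
```

Taking the maximum over references rewards matching any one reference well. An empty generation scores 0.0.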

💡 In addition to traditional NLG metrics, we provide two synthetic reward functions for quick prototyping, which train LMs to generate numbers in increasing order and to generate dates. These can be used to quickly test different algorithms and policies. Corresponding configs are provided for both (numbers, dates).

Adding custom metrics

Users can create their own evaluation metrics, which are then used to periodically evaluate the model on the validation split of the dataset. To do so, sub-class BaseMetric. Its compute method takes prompt texts, generated texts, reference texts, meta infos, the current LM model, and the split name as inputs. It returns a dict with the metric name as key and, as value, a tuple of sentence-level scores and the corpus-level score. An example is as follows:

from typing import Any, Dict, List

from transformers import PreTrainedModel

from rl4lms.envs.text_generation.metric import BaseMetric


class MyMetric(BaseMetric):
    def __init__(self) -> None:
        super().__init__()

    def compute(self,
                prompt_texts: List[str],
                generated_texts: List[str],
                reference_texts: List[List[str]],
                meta_infos: List[Dict[str, Any]] = None,
                model: PreTrainedModel = None,
                split_name: str = None):
        # value: (sentence-level scores, corpus-level score)
        metric_dict = {
            "custom_metrics/my_metric": ([0.4, 0.7, 0.9], 0.7)
        }
        return metric_dict
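For a metric with real per-sample scores, the body of compute could look like the hypothetical function below. It reports the whitespace-token length of each generation as the sentence-level scores and their mean as the corpus-level score, in the same (sentence-level list, corpus score) format as above. The key name custom_metrics/avg_length is illustrative.

```python
from typing import Dict, List, Tuple


def avg_length_metric(generated_texts: List[str]) -> Dict[str, Tuple[List[float], float]]:
    """Hypothetical metric body: generation length per sample and on average."""
    # sentence-level: whitespace token count of each generation
    per_sample = [float(len(text.split())) for text in generated_texts]
    # corpus-level: mean over all generations (0.0 for an empty batch)
    corpus = sum(per_sample) / len(per_sample) if per_sample else 0.0
    return {"custom_metrics/avg_length": (per_sample, corpus)}
```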

Adding custom on-policy algorithms

In addition to the supported on-policy algorithms (PPO, NLPO, A2C, TRPO), users can implement their own on-policy algorithms by sub-classing stable-baselines3's OnPolicyAlgorithm. We provide wrappers for on-policy algorithms that handle rollouts using LM policies, the environment, reward computation, etc., so users only need to implement the train() method with custom loss functions.

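from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm


class MyOnPolicyAlgorithm(OnPolicyAlgorithm):
    def __init__(self, n_epochs: int = 5, batch_size: int = 64, **kwargs):
        super().__init__(**kwargs)
        # OnPolicyAlgorithm itself does not define these (PPO adds them)
        self.n_epochs = n_epochs
        self.batch_size = batch_size

    def train(self) -> None:
        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            # do a complete pass on the rollout buffer in minibatches of batch_size
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                # placeholder: compute a custom loss from rollout_data,
                # then take an optimizer step on self.policy
                pass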
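As an example of a custom loss for the train() loop, the sketch below computes a PPO-style clipped surrogate objective over a minibatch. It uses plain Python floats for clarity. In practice you would use torch tensors of new and old log-probabilities and advantages taken from rollout_data; the exact field names depend on the rollout buffer in use. The function name and the clip_range default are assumptions.

```python
import math
from typing import Sequence


def clipped_surrogate_loss(log_probs: Sequence[float],
                           old_log_probs: Sequence[float],
                           advantages: Sequence[float],
                           clip_range: float = 0.2) -> float:
    """Hypothetical helper: negated PPO clipped surrogate objective, averaged over samples."""
    if not advantages:
        raise ValueError("need at least one sample")
    terms = []
    for log_prob, old_log_prob, advantage in zip(log_probs, old_log_probs, advantages):
        # probability ratio pi_new(a|s) / pi_old(a|s)
        ratio = math.exp(log_prob - old_log_prob)
        clipped_ratio = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
        # pessimistic (minimum) of the unclipped and clipped objectives
        terms.append(min(ratio * advantage, clipped_ratio * advantage))
    # negate because optimizers minimize while PPO maximizes the surrogate
    return -sum(terms) / len(terms)
```

Taking the minimum means a large policy change can never increase the objective beyond the clipped value. This keeps each update close to the policy that collected the rollouts.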

Adding custom policies

We provide LM-based actor-critic policy implementations that wrap causal LMs and seq2seq LMs. These can also be extended (e.g. to use a different critic architecture) by overriding the appropriate methods (e.g. evaluate_actions()).

Registry

Finally, register your custom components by adding them to the corresponding registry. They can then be used directly from configs, just like the pre-defined components 👋
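Conceptually, each registry maps the id strings used in the YAML configs to classes and instantiates them with the args block. A minimal, hypothetical sketch of that lookup pattern follows. It is not RL4LMs's registry code, so check the registry module for the actual class and method names before registering a component.

```python
from typing import Any, Dict, Type


class ToyRegistry:
    """Hypothetical id -> class registry mirroring the `id`/`args` config layout."""

    def __init__(self) -> None:
        self._components: Dict[str, Type] = {}

    def register(self, component_id: str, component_cls: Type) -> None:
        self._components[component_id] = component_cls

    def build(self, config_entry: Dict[str, Any]) -> Any:
        # look up the class by `id` and pass the optional `args` block as kwargs
        component_cls = self._components[config_entry["id"]]
        return component_cls(**config_entry.get("args", {}))


class MyRewardFunction:  # stand-in for a custom component
    def __init__(self, rouge_type: str = "rouge1") -> None:
        self.rouge_type = rouge_type


registry = ToyRegistry()
registry.register("my_reward", MyRewardFunction)
reward_fn = registry.build({"id": "my_reward", "args": {"rouge_type": "rouge2"}})
```

Treating args as optional mirrors config entries like `- id: rouge`, which have no args block.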

Crowdsourcing templates

We have provided the crowdsourcing templates we used on Mechanical Turk, along with example inputs, in scripts/crowdworking_templates. You might find these a helpful starting point either for evaluating your own model's generations or for gathering training data for a learned reward function.


Logging and Experiment Results

Additionally, we support WANDB logging and warm-starting of training by storing checkpoints and other training artifacts in a user-specified path. This is especially useful for running preemptible jobs on large, scheduled clusters.

Artifacts include:

  • a jsonl file containing rollout infos at specified intervals
  • a jsonl file containing training infos at specified intervals
  • a jsonl file containing validation metrics at specified intervals
  • a jsonl file containing test metrics before and after training
  • a json file with validation predictions at specified intervals
  • a json file with test predictions before and after training
  • the trained LM model
  • the config json used to run the experiment

Complete usage is as follows:

WANDB_API_KEY=<YOUR-WANDB-API-KEY-HERE>  python scripts/training/train_text_generation.py \
--config_path <PATH-TO-CONFIG-FILE> \
--experiment_name <EXPERIMENT-NAME> \
--base_path_to_store_results <PATH-TO-STORE-RESULTS> \
--log_to_wandb
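The jsonl artifacts hold one JSON object per line, so they can be loaded with the standard library alone. The helpers below are a hypothetical sketch. The exact file names under the results path and the keys inside each record are not listed here, so inspect your own output directory.

```python
import json
from typing import Any, Dict, Iterable, List


def parse_jsonl(lines: Iterable[str]) -> List[Dict[str, Any]]:
    """Parse one JSON object per non-empty line."""
    return [json.loads(line) for line in lines if line.strip()]


def read_jsonl(path: str) -> List[Dict[str, Any]]:
    # path: whichever artifact file you want to inspect under the results directory
    with open(path, "r", encoding="utf-8") as f:
        return parse_jsonl(f)
```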

Citation

@inproceedings{Ramamurthy2022IsRL,
  title={Is Reinforcement Learning (Not) for Natural Language Processing?: Benchmarks, Baselines, and Building Blocks for Natural Language Policy Optimization},
  author={Rajkumar Ramamurthy and Prithviraj Ammanabrolu and Kiant{\'e} Brantley and Jack Hessel and Rafet Sifa and Christian Bauckhage and Hannaneh Hajishirzi and Yejin Choi},
  journal={arXiv preprint arXiv:2210.01241},
  url={https://arxiv.org/abs/2210.01241},
  year={2022}
}

Questions/Discussion/Ideas?

For discussions, questions, and exchanging ideas, join our Slack channel.

rl4lms's People

Contributors

akifumi-wachi-4, creativeai-ws, jmhessel, julesgm, rajammanabrolu, rajcscw


rl4lms's Issues

Off-policy RL algorithms support

Hi, first of all, great work. This is a very useful library for research on RL and NLP. It would be very helpful if off-policy RL methods like Q-learning, SAC, etc. could be added, along with benchmarks.

Also, new offline RL methods applied to NLP, like ILQL, could be very interesting for human alignment, and support for such methods would further enhance the value of this codebase.

Reproducing IMDB results

Hi, I'm currently running the IMDB experiments and trying to reproduce the PPO and NLPO results from the paper. My PPO results are close, but NLPO is quite far from what is reported. Do you have any advice for reproducing the NLPO results?

I'm running the default config (scripts/training/task_configs/imdb_text_continuation/gpt2_{ppo,nlpo}.yml) and the final test results compared to the results from the paper are below.

Setting              Sentiment Score   Fluency (Perplexity)
zero-shot (ppo)      0.486             32.4
ppo                  0.604             33.0
zero-shot (nlpo)     0.497             32.7
nlpo                 0.496             40.8
paper's zero-shot    0.489             32.2
paper's ppo          0.605             33.5
paper's nlpo         0.637             32.7

PPO results are similar, with even slightly lower perplexity, but NLPO is not at all close. Here are the validation curves:
[Image: validation curves]

NLPO also seems to improve in sentiment for a bit, then suddenly stops and decreases, while the perplexity keeps going up. Comparing the training curves, the approx KL loss is much larger for NLPO, but this could be reasonable given the changes in NLPO. Do you see similar curves?

[Image: training curves]

Finally, the paper's Appendix Table 4 says that it runs for 10 epochs, but Figure 4 just below (and also the wandb logging) shows these experiments running for 50 epochs. Should I be running for 10 or 50 epochs?

Each experiment is being run on 4 A100 GPUs as per #12

model.generate.scores returning two scores

Dear contributors,

Thank you so much! This repo is excellent!

What is the difference between raw_logits and processed_logits?

How do they differ from the scores returned by the normal Hugging Face model.generate?

Thank you,
Debjit

Using GPT-2

In the README, it is mentioned that the actor-critic policies support causal LMs (e.g. GPT-2/3) and seq2seq LMs (e.g. T5, BART). I was wondering how I can use a GPT-2 model. Judging from the following call stack, instantiating and loading the model creates a seq2seq (encoder-decoder) model by default. How can I switch to a regular GPT model?

train_text_generation.py:84
train_text_generation.py:46 -> main()
training_utils.py:149 -> __init__()
training_utils.py:164 -> _setup()
training_utils.py:118 -> build_alg()
alg_wrappers.py:407 -> wrap_onpolicy_alg()
alg_wrappers.py:108 -> __init__()
ppo.py:166 -> __init__()
ppo.py:169 -> _setup_model()
on_policy_algorithm.py:117 -> _setup_model()
seq2seq_policy.py:51 -> __init__()
base_policy.py:135 -> __init__()
seq2seq_policy.py:67 -> _build_model_heads() (It gets a deepcopy of _policy_model as _ref_model. Then this calls from_pretrained for _policy_model, _value_model, and _ref_model to load the model parameters)
auto_factory.py:446 -> from_pretrained()

'BartForConditionalGeneration' has no attribute 'encoder'

I'm getting this error when trying to finetune Bart using PPO. Is this because BART isn't fully implemented yet, or because I'm using a wrong model?

My yml looks like this: (it's the default ppo one with only the model changed)

tokenizer:
  model_name: facebook/bart-large
  padding_side: left
  truncation_side: left
  pad_token_as_eos_token: False

reward_fn:
  id: rouge
  args:
    rouge_type: "rouge1"

datapool:
  id: cnn_daily_mail
  args:
    prompt_prefix: "Summarize: "

env:
  n_envs: 10
  args:
    max_prompt_length: 512
    max_episode_length: 100
    terminate_on_eos: True
    prompt_truncation_side: "right"
    context_start_token: 0

alg:
  id: ppo
  args:
    n_steps: 512
    batch_size: 64
    verbose: 1
    learning_rate: 0.000002
    n_epochs: 5
    ent_coef: 0.0
  kl_div:
    coeff: 0.001
    target_kl: 0.2
  policy:
    id: seq2seq_lm_actor_critic_policy
    args:
      model_name: facebook/bart-large-cnn
      apply_model_parallel: True
      prompt_truncation_side: "right"
      generation_kwargs:
        do_sample: True
        top_k: 50
        min_length: 50
        max_new_tokens: 100

train_evaluation:
  eval_batch_size: 100
  n_iters: 100
  eval_every: 10
  save_every: 1
  metrics:
    - id: meteor
      args: {}
    - id: rouge
    - id: bleu
      args: {}
    - id: bert_score
      args:
        language: en
    # - id: bleurt
    #   args:
    #     config_name: bleurt-large-512
    - id: diversity
      args: {}
    # - id: summaCZS
    #   args:
    #     granularity: sentence
    #     use_ent: True
    #     use_con: False
    # - id: summaCConv
    #   args:
    #     granularity: sentence
  generation_kwargs:
    do_sample: True
    top_k: 0
    temperature: 0.7
    min_length: 50
    max_new_tokens: 100

Some questions about n_steps, n_envs and padding_side.

Thanks for such great work!
I am familiar with NLP, but new to RL. What does n_steps mean, or what does it control? In text generation, does it mean generating n_steps tokens?

For n_envs, why do I need more than one env? Are the losses from different envs averaged?

Lastly, I saw that the update function in observation.py only considers left padding. Doesn't it also need a right-padding version?

Thanks again, and sorry for so many questions.

Evaluating a specific checkpoint

hey,
first of all thank you very much for this amazing library!
I was using it to fine-tune a model, and I am interested in evaluating one of the saved checkpoints on my test set.
Is there an easy way to do it?
Thanks.

CPU Support Minor Bug

Hello, I believe I found a minor bug in IntentAccuracyDailyDialog, lines 672-3 in envs/text_generation/metric.py. The device is currently set with the following two lines:

self._device = "cuda" if torch.cuda.is_available() else "cpu"
self._device = f"cuda:{torch.cuda.device_count() - 1}"

However, I believe the device should be f"cuda:{torch.cuda.device_count() - 1}" when on GPU and "cpu" otherwise. Currently it can never be set to CPU; on a machine without GPUs it instead tries to use "cuda:-1". Thanks so much for putting this library together!
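The fix proposed above amounts to choosing the device conditionally. Here is a minimal sketch, written as a hypothetical pure helper (pick_device is not an RL4LMs function) so it can be checked without a GPU. In the metric's constructor it would be called with torch.cuda.is_available() and torch.cuda.device_count().

```python
def pick_device(cuda_available: bool, device_count: int) -> str:
    """Hypothetical helper: last visible GPU if CUDA is usable, otherwise CPU."""
    if cuda_available and device_count > 0:
        return f"cuda:{device_count - 1}"
    return "cpu"


# e.g. inside IntentAccuracyDailyDialog.__init__ (requires `import torch`):
# self._device = pick_device(torch.cuda.is_available(), torch.cuda.device_count())
```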

Problem with BLEURT reward function

The BLEURT reward function fails with TypeError: cannot pickle '_thread.RLock' object in multiprocessing environments.
This is probably because the TensorFlow model can't be pickled to send it to the environment subprocess.

Tested in both local and Colab environments.

Here is the full stacktrace:

│ /home/eublefar/RL4LMs/rl4lms/envs/text_generation/training_utils.py:149 in __init__       │
│                                                                                           │
│   146 │   │   self._train_eval_config = train_eval_config                                 │
│   147 │   │   self._tracker = tracker                                                     │
│   148 │   │   self._experiment_name = experiment_name                                     │
│ ❱ 149 │   │   self._setup()                                                               │
│   150 │                                                                                   │
│   151 │   def _setup(self):                                                               │
│   152 │   │   # load trainer state from available previous checkpoint if available        │
│                                                                                           │
│ /home/eublefar/RL4LMs/rl4lms/envs/text_generation/training_utils.py:162 in _setup         │
│                                                                                           │
│   159 │   │   │   self._train_eval_config.get("metrics", []))                             │
│   160 │   │   self._samples_by_split = build_datapool(                                    │
│   161 │   │   │   self._datapool_config)                                                  │
│ ❱ 162 │   │   self._env = build_env(self._env_config, self._reward_fn,                    │
│   163 │   │   │   │   │   │   │     self._tokenizer, self._samples_by_split["train"])     │
│   164 │   │   self._alg = build_alg(self._on_policy_alg_config,                           │
│   165 │   │   │   │   │   │   │     self._env, self._tracker,                             │
│                                                                                           │
│ /home/eublefar/RL4LMs/rl4lms/envs/text_generation/training_utils.py:90 in build_env       │
│                                                                                           │
│    87 │   │   "samples": train_samples,                                                   │
│    88 │   }                                                                               │
│    89 │   env_kwargs = {**env_kwargs, **env_config.get("args", {})}                       │
│ ❱  90 │   env = make_vec_env(TextGenEnv,                                                  │
│    91 │   │   │   │   │      n_envs=env_config.get(                                       │
│    92 │   │   │   │   │   │      "n_envs", 1),                                            │
│    93 │   │   │   │   │      vec_env_cls=SubprocVecEnv,                                   │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/site-packages/stable_baselines3/common/e │
│ nv_util.py:105 in make_vec_env                                                            │
│                                                                                           │
│   102 │   │   # Default: use a DummyVecEnv                                                │
│   103 │   │   vec_env_cls = DummyVecEnv                                                   │
│   104 │                                                                                   │
│ ❱ 105 │   return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_en │
│   106                                                                                     │
│   107                                                                                     │
│   108 def make_atari_env(                                                                 │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/site-packages/stable_baselines3/common/v │
│ ec_env/subproc_vec_env.py:106 in __init__                                                 │
│                                                                                           │
│   103 │   │   │   args = (work_remote, remote, CloudpickleWrapper(env_fn))                │
│   104 │   │   │   # daemon=True: if the main process crashes, we should not cause things  │
│   105 │   │   │   process = ctx.Process(target=_worker, args=args, daemon=True)  # pytype │
│ ❱ 106 │   │   │   process.start()                                                         │
│   107 │   │   │   self.processes.append(process)                                          │
│   108 │   │   │   work_remote.close()                                                     │
│   109                                                                                     │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/process.py:121 in start  │
│                                                                                           │
│   118 │   │   assert not _current_process._config.get('daemon'), \                        │
│   119 │   │   │      'daemonic processes are not allowed to have children'                │
│   120 │   │   _cleanup()                                                                  │
│ ❱ 121 │   │   self._popen = self._Popen(self)                                             │
│   122 │   │   self._sentinel = self._popen.sentinel                                       │
│   123 │   │   # Avoid a refcycle if the target function holds an indirect                 │
│   124 │   │   # reference to the process object (see bpo-30775)                           │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/context.py:291 in _Popen │
│                                                                                           │
│   288 │   │   @staticmethod                                                               │
│   289 │   │   def _Popen(process_obj):                                                    │
│   290 │   │   │   from .popen_forkserver import Popen                                     │
│ ❱ 291 │   │   │   return Popen(process_obj)                                               │
│   292 │                                                                                   │
│   293 │   class ForkContext(BaseContext):                                                 │
│   294 │   │   _name = 'fork'                                                              │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/popen_forkserver.py:35   │
│ in __init__                                                                               │
│                                                                                           │
│   32 │                                                                                    │
│   33 │   def __init__(self, process_obj):                                                 │
│   34 │   │   self._fds = []                                                               │
│ ❱ 35 │   │   super().__init__(process_obj)                                                │
│   36 │                                                                                    │
│   37 │   def duplicate_for_child(self, fd):                                               │
│   38 │   │   self._fds.append(fd)                                                         │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/popen_fork.py:19 in      │
│ __init__                                                                                  │
│                                                                                           │
│   16 │   │   util._flush_std_streams()                                                    │
│   17 │   │   self.returncode = None                                                       │
│   18 │   │   self.finalizer = None                                                        │
│ ❱ 19 │   │   self._launch(process_obj)                                                    │
│   20 │                                                                                    │
│   21 │   def duplicate_for_child(self, fd):                                               │
│   22 │   │   return fd                                                                    │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/popen_forkserver.py:47   │
│ in _launch                                                                                │
│                                                                                           │
│   44 │   │   set_spawning_popen(self)                                                     │
│   45 │   │   try:                                                                         │
│   46 │   │   │   reduction.dump(prep_data, buf)                                           │
│ ❱ 47 │   │   │   reduction.dump(process_obj, buf)                                         │
│   48 │   │   finally:                                                                     │
│   49 │   │   │   set_spawning_popen(None)                                                 │
│   50                                                                                      │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/reduction.py:60 in dump  │
│                                                                                           │
│    57                                                                                     │
│    58 def dump(obj, file, protocol=None):                                                 │
│    59 │   '''Replacement for pickle.dump() using ForkingPickler.'''                       │
│ ❱  60 │   ForkingPickler(file, protocol).dump(obj)                                        │
│    61                                                                                     │
│    62 #                                                                                   │
│    63 # Platform specific definitions                                                     │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/site-packages/stable_baselines3/common/v │
│ ec_env/base_vec_env.py:371 in __getstate__                                                │
│                                                                                           │
│   368 │   │   self.var = var                                                              │
│   369 │                                                                                   │
│   370 │   def __getstate__(self) -> Any:                                                  │
│ ❱ 371 │   │   return cloudpickle.dumps(self.var)                                          │
│   372 │                                                                                   │
│   373 │   def __setstate__(self, var: Any) -> None:                                       │
│   374 │   │   self.var = cloudpickle.loads(var)                                           │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/site-packages/cloudpickle/cloudpickle_fa │
│ st.py:73 in dumps                                                                         │
│                                                                                           │
│    70 │   │   │   cp = CloudPickler(                                                      │
│    71 │   │   │   │   file, protocol=protocol, buffer_callback=buffer_callback            │
│    72 │   │   │   )                                                                       │
│ ❱  73 │   │   │   cp.dump(obj)                                                            │
│    74 │   │   │   return file.getvalue()                                                  │
│    75                                                                                     │
│    76 else:                                                                               │
│                                                                                           │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/site-packages/cloudpickle/cloudpickle_fa │
│ st.py:632 in dump                                                                         │
│                                                                                           │
│   629 │                                                                                   │
│   630 │   def dump(self, obj):                                                            │
│   631 │   │   try:                                                                        │
│ ❱ 632 │   │   │   return Pickler.dump(self, obj)                                          │
│   633 │   │   except RuntimeError as e:                                                   │
│   634 │   │   │   if "recursion" in e.args[0]:                                            │
│   635 │   │   │   │   msg = (                                                             │
╰───────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: cannot pickle '_thread.RLock' object

Is the construction of _value_model necessary?

Why do you need to define `_value_model` in the policy? I think you could use `_ref_model` plus `_value_head` to get the value, which would remove at least 1/3 of the parameters, along with the backward-pass gradient overhead, from GPU memory.

    def _build_model_heads(self,
                           model_name: str):
        # policy LM, with its generation routines overridden
        self._policy_model = AutoModelForCausalLM.from_pretrained(
            model_name)
        self._policy_model.__class__ = override_generation_routines(
            type(self._policy_model))

        # a second full copy of the LM, used only to compute values
        self._value_model = AutoModelForCausalLM.from_pretrained(
            model_name)
        # frozen copy of the policy, used as the KL reference
        self._ref_model = deepcopy(self._policy_model).eval()

        # scalar value head on top of the value model's hidden states
        self._value_head = nn.Linear(
            self._value_model.config.hidden_size, 1, bias=False)

        # apply model parallel
        ...
        self._value_head = self._value_head.to(self.device)
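A rough way to see the saving this question describes is to count model copies. The sketch below is plain arithmetic under stated assumptions: `n_lm_params` is the parameter count of one LM copy, the head is the `nn.Linear(hidden_size, 1, bias=False)` shown above, the value model is assumed to be trained together with the policy, and the reference model stays frozen. `policy_footprint` and `PolicyFootprint` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PolicyFootprint:
    total_params: int      # parameters resident on the GPU
    trainable_params: int  # parameters that also need gradients and optimizer state


def policy_footprint(n_lm_params: int, hidden_size: int,
                     separate_value_model: bool) -> PolicyFootprint:
    """Hypothetical helper: parameter counts for the two layouts in this issue."""
    head = hidden_size  # nn.Linear(hidden_size, 1, bias=False) holds hidden_size weights
    if separate_value_model:
        # policy LM (trained) + value LM (trained, assumed) + reference LM (frozen) + head
        return PolicyFootprint(total_params=3 * n_lm_params + head,
                               trainable_params=2 * n_lm_params + head)
    # policy LM (trained) + reference LM (frozen, reused for values) + head
    return PolicyFootprint(total_params=2 * n_lm_params + head,
                           trainable_params=n_lm_params + head)


if __name__ == "__main__":
    current = policy_footprint(1_000_000, 768, separate_value_model=True)
    proposed = policy_footprint(1_000_000, 768, separate_value_model=False)
    saved = 1 - proposed.total_params / current.total_params
    print(f"share of total parameters saved: {saved:.1%}")
```

The trade-off is that reusing the frozen `_ref_model` leaves only the head trainable, so value estimates may be less expressive than with a separately trained value model.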

`train` and `val` splits are not disjoint for IMDB

First, thanks for the great repo!

It seems that, because the train and val splits of IMDB are created by two separate calls to `_get_datapool_by_split` (which calls `IMDB.prepare`), and each call shuffles the data before sampling its split, the train and val splits will largely overlap within each run. This is highly problematic, because results on the validation set are then essentially invalid.
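The overlap follows directly from drawing each split from its own shuffle. The sketch below uses integer IDs in place of IMDB reviews; `sample_split` and `make_disjoint_splits` are hypothetical helpers, not RL4LMs' `IMDB.prepare`, and the 80/20 ratio is an arbitrary choice. With independent shuffles, each validation item lands in the train slice with probability 0.8, so roughly 80% of the validation set is expected to appear in train as well. Shuffling once with a fixed seed and slicing both splits from the same permutation makes them disjoint by construction.

```python
import random
from typing import List, Sequence, Tuple


def sample_split(data: Sequence[int], split: str, train_frac: float,
                 rng: random.Random) -> List[int]:
    """Hypothetical: mimics the reported pattern, reshuffling on every call."""
    items = list(data)
    rng.shuffle(items)  # a fresh permutation each time this is called
    n_train = int(len(items) * train_frac)
    return items[:n_train] if split == "train" else items[n_train:]


def make_disjoint_splits(data: Sequence[int], train_frac: float,
                         seed: int) -> Tuple[List[int], List[int]]:
    """Hypothetical fix: shuffle once with a fixed seed, then slice both splits."""
    items = list(data)
    random.Random(seed).shuffle(items)
    n_train = int(len(items) * train_frac)
    return items[:n_train], items[n_train:]


if __name__ == "__main__":
    data = list(range(1000))
    rng = random.Random(0)
    train = sample_split(data, "train", 0.8, rng)
    val = sample_split(data, "val", 0.8, rng)
    print("overlap with per-call shuffles:", len(set(train) & set(val)))

    train, val = make_disjoint_splits(data, 0.8, seed=0)
    print("overlap with one shared shuffle:", len(set(train) & set(val)))
```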

OOM on summarization example

Hi there, I'm getting OOM errors when running the summarization example on an 80GB A100 (CUDA 11.8).

I'm also getting some TensorFlow/TensorRT warnings, and I'm wondering whether they're related:

2022-11-08 22:44:46.878785: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-08 22:44:47.016183: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-08 22:44:47.979748: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2022-11-08 22:44:47.979824: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2022-11-08 22:44:47.979834: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

OOM error:

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│                                                                              │
│ /mnt/home/code/RL4LMs/scripts/training/train_text_generation.py:66 in        │
│ <module>                                                                     │
│                                                                              │
│   63 │   │   │   │   │   │   help="Whether to use wandb logging")            │
│   64 │   args = parser.parse_args()                                          │
│   65 │                                                                       │
│ ❱ 66 │   main(args.config_path,                                              │
│   67 │   │    args.project_name,                                             │
│   68 │   │    args.experiment_name,                                          │
│   69 │   │    args.base_path_to_store_results,                               │
│ /mnt/home/code/RL4LMs/scripts/training/train_text_generation.py:42 in main   │
│                                                                              │
│   39 │   │   │   │   │   │   │   │     on_policy_alg_config=config["alg"],   │
│   40 │   │   │   │   │   │   │   │     train_eval_config=config["train_evalu │
│   41 │   │   │   │   │   │   │   │     tracker=tracker)                      │
│ ❱ 42 │   trainer.train_and_eval()                                            │
│   43                                                                         │
│   44                                                                         │
│   45 if __name__ == "__main__":                                              │
│                                                                              │
│ /mnt/home/code/RL4LMs/rl4lms/envs/text_generation/training_utils.py:205 in   │
│ train_and_eval                                                               │
│                                                                              │
│   202 │   │   │   self._trainer_state["current_iter"] = epoch                │
│   203 │   │   │                                                              │
│   204 │   │   │   # inner rollout and learn loop for on-policy algorithm     │
│ ❱ 205 │   │   │   self._alg.learn(self._n_steps_per_iter)                    │
│   206 │   │   │                                                              │
│   207 │   │   │   # save the policy checkpoint                               │
│   208 │   │   │   if (epoch + 1) % self._train_eval_config.get("save_every", │
│                                                                              │
│ /mnt/home/code/RL4LMs/rl4lms/algorithms/ppo/ppo.py:347 in learn              │
│                                                                              │
│   344 │   │   reset_num_timesteps: bool = True,                              │
│   345 │   ) -> "PPO":                                                        │
│   346 │   │                                                                  │
│ ❱ 347 │   │   return super().learn(                                          │
│   348 │   │   │   total_timesteps=total_timesteps,                           │
│   349 │   │   │   callback=callback,                                         │
│   350 │   │   │   log_interval=log_interval,                                 │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/stable_baselines3/common/on │
│ _policy_algorithm.py:267 in learn                                            │
│                                                                              │
│   264 │   │   │   │   self.logger.record("time/total_timesteps", self.num_ti │
│   265 │   │   │   │   self.logger.dump(step=self.num_timesteps)              │
│   266 │   │   │                                                              │
│ ❱ 267 │   │   │   self.train()                                               │
│   268 │   │                                                                  │
│   269 │   │   callback.on_training_end()                                     │
│   270                                                                        │
│                                                                              │
│ /mnt/home/code/RL4LMs/rl4lms/algorithms/ppo/ppo.py:224 in train              │
│                                                                              │
│   221 │   │   │   │   if self.use_sde:                                       │
│   222 │   │   │   │   │   self.policy.reset_noise(self.batch_size)           │
│   223 │   │   │   │                                                          │
│ ❱ 224 │   │   │   │   values, log_prob, entropy = self.policy.evaluate_actio │
│   225 │   │   │   │   │   rollout_data.observations, actions)                │
│   226 │   │   │   │   values = values.flatten()                              │
│   227 │   │   │   │   # Normalize advantage                                  │
│                                                                              │
│ /mnt/home/code/RL4LMs/rl4lms/envs/text_generation/policy.py:211 in           │
│ evaluate_actions                                                             │
│                                                                              │
│    208 │   │                                                                 │
│    209 │   │   _, log_prob, entropy, _, _ = self.forward_policy(obs=obs,     │
│    210 │   │   │   │   │   │   │   │   │   │   │   │   │   │    actions=acti │
│ ❱  211 │   │   values, _ = self.forward_value(obs)                           │
│    212 │   │                                                                 │
│    213 │   │   return values, log_prob, entropy                              │
│    214                                                                       │
│                                                                              │
│ /mnt/home/code/RL4LMs/rl4lms/envs/text_generation/policy.py:447 in           │
│ forward_value                                                                │
│                                                                              │
│    444 │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │     │
│    445 │   │                                                                 │
│    446 │   │   # and forrward pass to get hidden states                      │
│ ❱  447 │   │   outputs = self._value_model(                                  │
│    448 │   │   │   **model_inputs,                                           │
│    449 │   │   │   output_hidden_states=True,                                │
│    450 │   │   │   decoder_attention_mask=decoder_attn_mask,                 │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py: │
│ 1110 in _call_impl                                                           │
│                                                                              │
│   1107 │   │   # this function, and just call forward.                       │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                     │
│   1111 │   │   # Do not call functions when jit is used                      │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ ling_t5.py:1648 in forward                                                   │
│                                                                              │
│   1645 │   │   │   │   decoder_attention_mask = decoder_attention_mask.to(se │
│   1646 │   │                                                                 │
│   1647 │   │   # Decode                                                      │
│ ❱ 1648 │   │   decoder_outputs = self.decoder(                               │
│   1649 │   │   │   input_ids=decoder_input_ids,                              │
│   1650 │   │   │   attention_mask=decoder_attention_mask,                    │
│   1651 │   │   │   inputs_embeds=decoder_inputs_embeds,                      │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py: │
│ 1110 in _call_impl                                                           │
│                                                                              │
│   1107 │   │   # this function, and just call forward.                       │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                     │
│   1111 │   │   # Do not call functions when jit is used                      │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ ling_t5.py:1040 in forward                                                   │
│                                                                              │
│   1037 │   │   │   │   │   None,  # past_key_value is always None with gradi │
│   1038 │   │   │   │   )                                                     │
│   1039 │   │   │   else:                                                     │
│ ❱ 1040 │   │   │   │   layer_outputs = layer_module(                         │
│   1041 │   │   │   │   │   hidden_states,                                    │
│   1042 │   │   │   │   │   attention_mask=extended_attention_mask,           │
│   1043 │   │   │   │   │   position_bias=position_bias,                      │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py: │
│ 1110 in _call_impl                                                           │
│                                                                              │
│   1107 │   │   # this function, and just call forward.                       │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                     │
│   1111 │   │   # Do not call functions when jit is used                      │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ ling_t5.py:699 in forward                                                    │
│                                                                              │
│    696 │   │   │   else:                                                     │
│    697 │   │   │   │   query_length = None                                   │
│    698 │   │   │                                                             │
│ ❱  699 │   │   │   cross_attention_outputs = self.layer[1](                  │
│    700 │   │   │   │   hidden_states,                                        │
│    701 │   │   │   │   key_value_states=encoder_hidden_states,               │
│    702 │   │   │   │   attention_mask=encoder_attention_mask,                │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py: │
│ 1110 in _call_impl                                                           │
│                                                                              │
│   1107 │   │   # this function, and just call forward.                       │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                     │
│   1111 │   │   # Do not call functions when jit is used                      │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ ling_t5.py:613 in forward                                                    │
│                                                                              │
│    610 │   │   output_attentions=False,                                      │
│    611 │   ):                                                                │
│    612 │   │   normed_hidden_states = self.layer_norm(hidden_states)         │
│ ❱  613 │   │   attention_output = self.EncDecAttention(                      │
│    614 │   │   │   normed_hidden_states,                                     │
│    615 │   │   │   mask=attention_mask,                                      │
│    616 │   │   │   key_value_states=key_value_states,                        │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py: │
│ 1110 in _call_impl                                                           │
│                                                                              │
│   1107 │   │   # this function, and just call forward.                       │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                     │
│   1111 │   │   # Do not call functions when jit is used                      │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ ling_t5.py:509 in forward                                                    │
│                                                                              │
│    506 │   │   )                                                             │
│    507 │   │                                                                 │
│    508 │   │   # compute scores                                              │
│ ❱  509 │   │   scores = torch.matmul(                                        │
│    510 │   │   │   query_states, key_states.transpose(3, 2)                  │
│    511 │   │   )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_stat │
│    512                                                                       │
╰──────────────────────────────────────────────────────────────────────────────╯
RuntimeError: CUDA out of memory. Tried to allocate 150.00 MiB (GPU 0; 79.35 GiB
total capacity; 76.08 GiB already allocated; 108.19 MiB free; 77.20 GiB reserved
in total by PyTorch) If reserved memory is >> allocated memory try setting
max_split_size_mb to avoid fragmentation.  See documentation for Memory
Management and PYTORCH_CUDA_ALLOC_CONF

Any clues what the issue is? 80GB seems like a lot for just a T5-base model.
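The traceback fails inside T5 cross-attention, at the `torch.matmul(query_states, key_states.transpose(3, 2))` that builds the scores tensor (`bnqd,bnkd->bnqk`), and it happens in `forward_value` on the separate `_value_model`, so the policy and value models each run their own forward pass in `evaluate_actions`. The size of that tensor scales with batch × heads × query length × key length, and during training such activations are generally kept for the backward pass. The sketch below estimates it; t5-base has 12 heads and 12 decoder layers, but the batch size and sequence lengths are hypothetical, not values taken from the summarization config.

```python
def attention_scores_bytes(batch_size: int, n_heads: int, query_len: int,
                           key_len: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one (batch, heads, query_len, key_len) score tensor (fp32 by default)."""
    return batch_size * n_heads * query_len * key_len * bytes_per_elem


def scores_over_layers(n_layers: int, **shape: int) -> int:
    """Score-tensor memory if every layer keeps one such tensor alive for backward."""
    return n_layers * attention_scores_bytes(**shape)


if __name__ == "__main__":
    # Hypothetical shapes for illustration only.
    shape = dict(batch_size=16, n_heads=12, query_len=100, key_len=512)
    print(f"per layer: {attention_scores_bytes(**shape) / 2**20:.1f} MiB")
    print(f"12 decoder layers: {scores_over_layers(12, **shape) / 2**20:.1f} MiB")
```

Because the size is linear in each factor, halving the training `batch_size` or the prompt and episode lengths shrinks these tensors proportionally.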

Reproducing existing results on NarrativeQA

I'm trying to reproduce the NarrativeQA results by directly running the command with the provided .yml configuration files. Below are the results, measured with ROUGE-L-Max.
For PPO with supervision, I got 0.581 and 0.588 for epochs 0 and 99, respectively.
For NLPO with supervision, I got 0.217 and 0.213 for epochs 0 and 99, respectively.

I'm wondering why the result for NLPO doesn't match the reported result in the paper.

I also tried taking the PPO config and changing only the RL algorithm to NLPO, and got the same result as above.

Please let me know if I'm missing something or if it's some other issue. Thanks!

Mixed-precision training

Hey,
Are there any plans to add support for mixed precision training?
I saw that a temporary solution was suggested in #12, but it still throws multiple exceptions related to operations between fp16 and fp32 values.
Thanks!
@rajcscw

NLPO Code Error and Query About gymnasium vs gym Usage

I hope this message finds you well. I am writing to report an issue I encountered in the NLPO project that you maintain on GitHub.

While executing the following block of code at line 278 of nlpo.py:

    if eval_env is not None and self.seed is not None:
        eval_env.seed(self.seed)
    eval_env = self._get_eval_env(eval_env)
    if not self._custom_logger:

I got:

    eval_env = self._get_eval_env(eval_env)
    AttributeError: 'NLPO' object has no attribute '_get_eval_env'

I also found no code related to `_get_eval_env` in this repository. Did I miss some key files? If so, please let me know.

In addition, I have a question about library usage: is it possible to use `gymnasium` in place of `gym` within the NLPO project? If so, could you please explain how to replace all relevant `gym` imports with `gymnasium`?
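Swapping the import alone is probably not enough, because the two libraries differ in their core API: gymnasium's `step()` returns `(obs, reward, terminated, truncated, info)` and `reset()` returns `(obs, info)`, while gym versions before 0.26 return `(obs, reward, done, info)` and just `obs`. A minimal sketch of an adapter that converts gymnasium-style results back to the old convention follows. The function names are hypothetical, it has not been tested against RL4LMs' environments, and the `TimeLimit.truncated` key follows the convention of old gym's `TimeLimit` wrapper.

```python
from typing import Any, Dict, Tuple


def to_legacy_step(result: Tuple[Any, float, bool, bool, Dict[str, Any]]
                   ) -> Tuple[Any, float, bool, Dict[str, Any]]:
    """Hypothetical: map a gymnasium 5-tuple step result to the old gym 4-tuple."""
    obs, reward, terminated, truncated, info = result
    info = dict(info)  # copy so the caller's dict is not mutated
    if truncated and not terminated:
        # old gym's TimeLimit wrapper flagged time-limit endings this way
        info["TimeLimit.truncated"] = True
    return obs, reward, terminated or truncated, info


def to_legacy_reset(result: Tuple[Any, Dict[str, Any]]) -> Any:
    """Hypothetical: map gymnasium's (obs, info) reset result to the old obs-only result."""
    obs, _info = result
    return obs
```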

Understanding the mask model in `_get_action_masks` of the LogitsProcessor

In this line of code, I saw that `MaskLogitsProcessorCasualLM` is initialized with `deepcopy(self._policy_model).eval()`, and that during generation `GenerationMixinWithRawScores.sampler` pre-processes the distribution and calls the custom LogitsProcessor through a hook. Comparing `next_token_logits_raw` from the policy model with `next_token_logits` from the mask model in the same generate pipeline, I found that they are indeed different. What is the purpose of this? I'd really like to understand.
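For context, the NLPO paper describes restricting the action space with top-p sampling over a separate masking policy, which is a copy of the current policy that is only refreshed periodically. If that is what the `deepcopy(...)` mask model implements, its logits are expected to lag behind, and therefore differ from, the current policy's logits between refreshes. The sketch below illustrates that mechanism on plain Python lists under that assumption; it is not RL4LMs' `MaskLogitsProcessorCasualLM`, and all names are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_allowed(mask_logits: Sequence[float], top_p: float) -> List[bool]:
    """Hypothetical: smallest token set whose mask-model probability reaches top_p."""
    probs = softmax(mask_logits)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    allowed = [False] * len(probs)
    cumulative = 0.0
    for idx in order:
        allowed[idx] = True
        cumulative += probs[idx]
        if cumulative >= top_p:
            break
    return allowed


def mask_policy_logits(policy_logits: Sequence[float],
                       allowed: Sequence[bool]) -> List[float]:
    """Keep the current policy's logits only for tokens the mask model allows."""
    return [x if ok else float("-inf") for x, ok in zip(policy_logits, allowed)]


if __name__ == "__main__":
    stale_mask_logits = [2.0, 1.0, 0.0, -1.0]     # from the periodically refreshed copy
    current_policy_logits = [0.5, 3.0, 4.0, 1.0]  # from the policy being trained
    allowed = top_p_allowed(stale_mask_logits, top_p=0.8)
    print(mask_policy_logits(current_policy_logits, allowed))
```

In this toy example the current policy's highest-scoring token (index 2) is masked out, because the stale copy gives it too little probability to enter the top-p set.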

Memory issue in metric evals?

Hi all,

I am encountering a GPU memory issue in metric evaluations.

I am using the following metrics:

  metrics:
    - id: meteor
      args: {}
    - id: rouge
    - id: bleu
      args: {}
    - id: bert_score # TODO AM running into cuda memory insufficient here
      args:
        language: en
    - id: cider
    - id: diversity
      args: {}

On monitoring the GPU usage for the card hosting the metric models, I see a steady increase in memory occupied:

initial:
+-------------------------------+----------------------+----------------------+
|   7  Tesla V100-SXM2...  On   | 00000000:00:1D.0 Off |                    0 |
| N/A   51C    P0    71W / 300W |  3514MiB / 32768MiB |      0%      Default |
|                               |                      |                  N/A |
at 200 epochs
+-------------------------------+----------------------+----------------------+
|   7  Tesla V100-SXM2...  On   | 00000000:00:1D.0 Off |                    0 |
| N/A   53C    P0    73W / 300W |  22171MiB / 32768MiB |      0%      Default |
|                               |                      |                  N/A |

Any idea what might be causing this?
Thanks

Pip install error with gym and torch

Hi, I encountered this error when installing the rl4lms library with `pip install -e .`. The message says:

'extras_require' must be a dictionary whose values are strings or lists of strings containing valid project/version requirement specifiers.

The fix I found is to first run `pip install setuptools==65.5.0 pip==21` (or even earlier versions, according to this issue), though I suspect downgrading pip is not necessary.

As for `torch==1.11.0`, it is a relatively old release and does not seem to support the latest Python versions. I didn't see the required Python version documented anywhere; it would be nice to share it so that others don't run into this dependency error. Downgrading from Python 3.11 to 3.8 resolved the error for me.

Hope someone can have a look at this and see if it's my mistake or worth a PR.
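The two workarounds above can be turned into a quick pre-install check. This is a sketch under the assumption that the thresholds reported here (Python 3.11 failing and 3.8 working with `torch==1.11.0`, and setuptools 65.5.0 avoiding the `extras_require` error) are the relevant ones; they come from this single report, not from the project's documentation. `version_tuple` and `preflight_warnings` are hypothetical helpers, and the version parsing is deliberately simplified (`packaging.version.Version` handles full PEP 440).

```python
import re
import sys
from typing import List, Sequence, Tuple


def version_tuple(version: str) -> Tuple[int, ...]:
    """Leading numeric components, e.g. '65.5.0' -> (65, 5, 0). Not full PEP 440."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def preflight_warnings(python_version: Sequence[int], setuptools_version: str) -> List[str]:
    """Hypothetical: flag the two conditions described in this report."""
    warnings = []
    if tuple(python_version[:2]) >= (3, 11):
        warnings.append("Python 3.11 failed for this reporter with torch==1.11.0; 3.8 worked.")
    if version_tuple(setuptools_version) > (65, 5, 0):
        warnings.append("setuptools newer than 65.5.0 may hit the 'extras_require' error; "
                        "the reporter pinned setuptools==65.5.0.")
    return warnings


if __name__ == "__main__":
    from importlib.metadata import version
    for message in preflight_warnings(sys.version_info, version("setuptools")):
        print("warning:", message)
```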

Bug while loading t5 base model

I am trying to load the t5-base model as per the t5_ppo config. Strangely, this error pops up. It works fine for t5-small.

	size mismatch for decoder.final_layer_norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for lm_head.weight: copying a param with shape torch.Size([32128, 512]) from checkpoint, the shape in current model is torch.Size([32128, 768]).
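Both lines point the same way: every checkpoint tensor has hidden size 512, which is t5-small's `d_model`, while the model built from the t5-base config expects 768. That suggests, from the shapes alone, that a t5-small checkpoint (for example, one saved by an earlier t5-small run) is being loaded into a freshly built t5-base model. A small stdlib helper (hypothetical) that reads such lines and names the size on each side:

```python
import re
from typing import Optional, Tuple

# d_model of the original T5 checkpoints; 1024 is shared by t5-large, t5-3b and t5-11b.
D_MODEL_TO_SIZE = {512: "t5-small", 768: "t5-base", 1024: "t5-large or bigger"}

MISMATCH = re.compile(
    r"size mismatch for (?P<name>[^\s:]+): copying a param with shape "
    r"torch\.Size\(\[(?P<ckpt>[\d, ]+)\]\) from checkpoint, "
    r"the shape in current model is torch\.Size\(\[(?P<model>[\d, ]+)\]\)"
)


def parse_mismatch(line: str) -> Optional[Tuple[str, Tuple[int, ...], Tuple[int, ...]]]:
    """Hypothetical: extract (param name, checkpoint shape, model shape) from one line."""
    match = MISMATCH.search(line)
    if match is None:
        return None
    ckpt = tuple(int(dim) for dim in match.group("ckpt").split(","))
    model = tuple(int(dim) for dim in match.group("model").split(","))
    return match.group("name"), ckpt, model


def guess_sizes(line: str) -> Optional[Tuple[str, str]]:
    """Map the last dimension (d_model for these two params) to a T5 size name."""
    parsed = parse_mismatch(line)
    if parsed is None:
        return None
    _, ckpt, model = parsed
    return (D_MODEL_TO_SIZE.get(ckpt[-1], "unknown"),
            D_MODEL_TO_SIZE.get(model[-1], "unknown"))
```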

Problems with models that don't have the parallelize() function

Hey,
First of all, thank you for this amazing repo!
I am trying to use this repo with a model that does not have the `parallelize()` function (LED, the Longformer Encoder-Decoder).
From what I have observed, such models are simply wrapped in `DataParallel`.
The problem is that this causes many bugs stemming from the missing `parallelize` function.
For example, `get_policy_first_device` is called in many places, and it looks for the `first_device` attribute on the models, which is only set when `parallelize` is called (and there are many other issues).
I noticed that a similar issue has already been addressed, so I was wondering whether there are plans to properly support such models.
Thanks!
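One way to make device lookup tolerant of models without `parallelize()` is to unwrap `DataParallel` (which exposes the wrapped model as `.module`) and, when `first_device` is missing, fall back to the device of the first parameter. This is a sketch of that idea, not RL4LMs' `get_policy_first_device`; `resolve_first_device` is a hypothetical name, and it relies only on duck typing, so it accepts any object with a `parameters()` method whose items have a `.device`.

```python
from typing import Any


def resolve_first_device(model: Any) -> Any:
    """Hypothetical: device on which a (possibly DataParallel-wrapped) model expects inputs."""
    inner = getattr(model, "module", model)  # DataParallel keeps the real model in .module
    first_device = getattr(inner, "first_device", None)
    if first_device is not None:
        return first_device  # present once parallelize() has been called on this object
    for param in inner.parameters():
        return param.device  # fallback: the device of the first parameter
    raise ValueError("model has no parameters and no first_device attribute")
```

For a `DataParallel` wrapper the fallback should land on `device_ids[0]`, since `DataParallel` requires the wrapped module's parameters to live there.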

Value is not broadcastable with batch_shape+event_shape

My yaml:

tokenizer:
  model_name: facebook/bart-large-cnn
  padding_side: left
  truncation_side: left
  pad_token_as_eos_token: False

reward_fn:
  id: rouge
  args:
    rouge_type: "rouge1"

datapool:
  id: cnn_daily_mail
  args:
    prompt_prefix: "Summarize: "
    max_size: 500


env:
  n_envs: 1
  args:
    max_prompt_length: 64
    max_episode_length: 100
    terminate_on_eos: True
    prompt_truncation_side: "right"
    context_start_token: 0

alg:
  id: ppo
  args: 
    n_steps: 512
    batch_size: 4
    verbose: 1
    learning_rate: 0.000002
    n_epochs: 5
    ent_coef: 0.0
  kl_div:
    coeff: 0.001
    target_kl: 0.2
  policy:
    id: seq2seq_lm_actor_critic_policy
    args:
      model_name: facebook/bart-large-cnn
      apply_model_parallel: False
      prompt_truncation_side: "right"
      generation_kwargs:
        do_sample: True
        top_k: 50
        min_length: 50
        max_new_tokens: 100          
    
train_evaluation:
  eval_batch_size: 16
  n_iters: 100
  eval_every: 10
  save_every: 1
  metrics:
    - id: meteor
      args: {}
    - id: rouge
    - id: bleu
      args: {}
    - id: bert_score
      args:
        language: en
    # - id: bleurt
    #   args:
    #     config_name: bleurt-large-512
    - id: diversity
      args: {}
    # - id: summaCZS
    #   args:
    #     granularity: sentence
    #     use_ent: True
    #     use_con: False
    # - id: summaCConv
    #   args:
    #     granularity: sentence
  generation_kwargs: 
    do_sample: True
    top_k: 0
    temperature: 0.7
    min_length: 50
    max_new_tokens: 100

My error:

/content/RL4LMs/scripts/training/train_text_generation.py in main(config_path, project_name, experiment_name, base_path_to_store_results, entity_name, log_to_wandb)
     53             tracker=tracker,
     54         )
---> 55     trainer.train_and_eval()
     56 
     57 

/content/RL4LMs/rl4lms/envs/text_generation/training_utils.py in train_and_eval(self)
    195         # evaluate on val and test set before fine-tuning once
    196         iter_start = self._trainer_state["current_iter"]
--> 197         self._evaluate_on_datapools(epoch=iter_start)
    198 
    199         # train for given number of iters

/content/RL4LMs/rl4lms/envs/text_generation/training_utils.py in _evaluate_on_datapools(self, epoch, splits)
    181                                splits: List[str] = ["val", "test"]):
    182         for split in splits:
--> 183             evaluate_on_samples(policy=self._alg.policy,
    184                                 tokenizer=self._tokenizer,
    185                                 samples=self._samples_by_split[split],

/content/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py in evaluate_on_samples(policy, tokenizer, samples, batch_size, max_prompt_length, metrics, epoch, split_name, tracker, dt_control_token, gen_kwargs)
     39     n_samples = len(samples)
     40     for batch in tqdm(list(get_batch(samples, batch_size)), desc="Evaluating"):
---> 41         batch_generated_texts = generate_text(
     42             policy, tokenizer, batch, max_prompt_length, dt_control_token, gen_kwargs
     43         )

/content/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py in generate_text(policy, tokenizer, samples, max_prompt_length, dt_control_token, gen_kwargs)
    109         dt_control_token + sample.prompt_or_input_text for sample in samples
    110     ]
--> 111     generated_texts = policy.generate(
    112         tokenizer, prompt_texts, max_prompt_length, gen_kwargs=gen_kwargs
    113     ).gen_texts

/content/RL4LMs/rl4lms/envs/text_generation/policy/base_policy.py in generate(self, tokenizer, texts, max_prompt_length, input_ids, attention_mask, gen_kwargs)
    254             actions_at_step = gen_tokens[:, step]
    255             distribution = Categorical(logits=raw_logits)
--> 256             log_probs = distribution.log_prob(actions_at_step)
    257             step_wise_logprobs.append(log_probs)
    258             step_wise_actions.append(actions_at_step)

/usr/local/lib/python3.8/dist-packages/torch/distributions/categorical.py in log_prob(self, value)
    121     def log_prob(self, value):
    122         if self._validate_args:
--> 123             self._validate_sample(value)
    124         value = value.long().unsqueeze(-1)
    125         value, log_pmf = torch.broadcast_tensors(value, self.logits)

/usr/local/lib/python3.8/dist-packages/torch/distributions/distribution.py in _validate_sample(self, value)
    280         for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
    281             if i != 1 and j != 1 and i != j:
--> 282                 raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
    283                                  format(actual_shape, expected_shape))
    284         try:

ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([100]) vs torch.Size([400]).
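The shapes in this error can be reproduced with `torch.distributions.Categorical` alone. A minimal sketch, assuming (the top of the config is not shown here) that generation expanded each of the 100 prompts into 4 sequences, for example via `num_beams` or `num_return_sequences`, so the per-step logits have 400 rows while the chosen tokens have 100. The helper name `try_log_prob`, the expansion factor and the vocabulary size are illustrative:

```python
import torch
from torch.distributions import Categorical


def try_log_prob(logits: torch.Tensor, actions: torch.Tensor):
    """Return per-row log-probabilities, or the ValueError text if the shapes do not broadcast."""
    dist = Categorical(logits=logits, validate_args=True)
    try:
        return dist.log_prob(actions)
    except ValueError as err:
        return str(err)


batch_size, expansion, vocab_size = 100, 4, 32  # expansion: assumed sequences per prompt
logits = torch.randn(batch_size * expansion, vocab_size)  # one row per expanded sequence

# One action per prompt (100) against 400 logit rows: argument validation rejects it.
print(try_log_prob(logits, torch.randint(vocab_size, (batch_size,))))
# Value is not broadcastable with batch_shape+event_shape: torch.Size([100]) vs torch.Size([400]).

# One action per logit row: log_prob returns one value per row.
print(try_log_prob(logits, torch.randint(vocab_size, (batch_size * expansion,))).shape)
# torch.Size([400])
```

If the assumption holds, the logits and the chosen tokens need to agree on one row per sequence before `log_prob` is called, for example by generating without beam search during these rollouts (an untested suggestion).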

By the way, can I use

datapool:
  id: cnn_daily_mail
  args:
    prompt_prefix: "Summarize: "
    max_size: 500

to limit how much data I use?
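We have not verified whether the `cnn_daily_mail` pool accepts a `max_size` argument. If it does not, one option is a custom pool that caps its sample list before building the pool instance. A minimal sketch with a hypothetical helper `cap_samples` (not part of RL4LMs) that keeps a reproducible random subset in the original order:

```python
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def cap_samples(samples: Sequence[T], max_size: Optional[int], seed: int = 0) -> List[T]:
    """Return at most `max_size` samples, chosen reproducibly and kept in original order.

    Hypothetical helper, not part of RL4LMs.
    """
    if max_size is None or max_size >= len(samples):
        return list(samples)
    if max_size < 0:
        raise ValueError("max_size must be non-negative")
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    keep = sorted(rng.sample(range(len(samples)), max_size))
    return [samples[i] for i in keep]
```

Inside a custom `prepare`, something like `samples = cap_samples(samples, 500)` just before `cls(samples)` would cap the pool at 500 samples.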

Metric version incompatible

The latest metrics loaded from Hugging Face, such as rouge, require rouge_score>=0.1.2, but rl4lms 0.2.1 pins rouge_score==0.0.4, so the two are incompatible.
This causes errors when running the example in the README.
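A quick way to see the conflict locally is to compare the installed rouge_score version with the minimum the newer metric script expects. A minimal sketch; the helpers `parse_version`, `satisfies_min` and `installed_version` are hypothetical and only handle plain numeric versions (`packaging.version` is the usual tool for full PEP 440 semantics):

```python
from importlib import metadata
from typing import Optional, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Leading numeric components only: "0.1.2" -> (0, 1, 2); suffixes like "rc1"/"post1" are ignored."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) < len(piece):
            break  # stop at the first component that carries a suffix
    return tuple(parts)


def satisfies_min(installed: str, minimum: str) -> bool:
    """True if `installed` >= `minimum`, comparing zero-padded numeric tuples."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def installed_version(package: str) -> Optional[str]:
    """Installed distribution version, or None if the package is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None
```

`satisfies_min("0.0.4", "0.1.2")` is False, which is the pinned-versus-required mismatch described above; `installed_version("rouge-score")` reports what is actually installed (how distribution names are normalized can vary between Python versions).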

Large difference between val and test on CommonGen

Hi! Thanks for making this amazing library, this can enable / inspire so much further research!

I have a question about the val and test performance on CommonGen. I trained a model with python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/common_gen/t5_ppo.yml

At the end of training, I get

2022-11-05 07:29:35,664 [INFO] val metrics: {'epoch': 199, 'metrics': {'lexical/meteor': 0.2056212306139174, 'lexical/rouge_rouge1': 0.40577034125587874, 'lexical/rouge_rouge2': 0.07797580787551397, 'lexical/rouge_rougeL': 0.306489902042443, 'lexical/rouge_rougeLsum': 0.30644831049690746, 'lexical/bleu': 0.023794832248399733, 'semantic/bert_score': 0.8717959486585849, 'lexical/cider': 0.5675097699920747, 'diversity_metrics/msttr-100': 0.55397, 'diversity_metrics/msttr-100_nopunct': 0.55153, 'diversity_metrics/total_length': 14111, 'diversity_metrics/mean_pred_length': 14.210473313192347, 'diversity_metrics/std_pred_length': 2.524205293570917, 'diversity_metrics/median_pred_length': 14.0, 'diversity_metrics/min_pred_length': 4, 'diversity_metrics/max_pred_length': 20, 'diversity_metrics/distinct-1': 0.16363120969456452, 'diversity_metrics/vocab_size-1': 2309, 'diversity_metrics/unique-1': 1217, 'diversity_metrics/entropy-1': 7.808842726642299, 'diversity_metrics/distinct-2': 0.553438024089038, 'diversity_metrics/vocab_size-2': 7260, 'diversity_metrics/unique-2': 5812, 'diversity_metrics/entropy-2': 11.56123473112017, 'diversity_metrics/cond_entropy-2': 3.8366863460482588, 'diversity_metrics/distinct-3': 0.8145154639175258, 'diversity_metrics/vocab_size-3': 9876, 'diversity_metrics/unique-3': 8887, 'diversity_metrics/entropy-3': 12.997942530275049, 'diversity_metrics/cond_entropy-3': 1.4999314341877525, 'diversity_metrics/total_length-nopunct': 13729, 'diversity_metrics/mean_pred_length-nopunct': 13.8257804632427, 'diversity_metrics/std_pred_length-nopunct': 2.466463891696835, 'diversity_metrics/median_pred_length-nopunct': 14.0, 'diversity_metrics/min_pred_length-nopunct': 4, 'diversity_metrics/max_pred_length-nopunct': 20, 'diversity_metrics/distinct-1-nopunct': 0.16767426615194114, 'diversity_metrics/vocab_size-1-nopunct': 2302, 'diversity_metrics/unique-1-nopunct': 1215, 'diversity_metrics/entropy-1-nopunct': 7.811421016883972, 
'diversity_metrics/distinct-2-nopunct': 0.5537845477386935, 'diversity_metrics/vocab_size-2-nopunct': 7053, 'diversity_metrics/unique-2-nopunct': 5661, 'diversity_metrics/entropy-2-nopunct': 11.50109557284849, 'diversity_metrics/cond_entropy-2-nopunct': 3.8349530854466862, 'diversity_metrics/distinct-3-nopunct': 0.8112918334326833, 'diversity_metrics/vocab_size-3-nopunct': 9527, 'diversity_metrics/unique-3-nopunct': 8564, 'diversity_metrics/entropy-3-nopunct': 12.939825963348007, 'diversity_metrics/cond_entropy-3-nopunct': 1.5130166326851369}}

and

2022-11-05 07:30:32,912 [INFO] test metrics: {'epoch': 199, 'metrics': {'lexical/meteor': 0.0028241679555033243, 'lexical/rouge_rouge1': 0.00017256735693609444, 'lexical/rouge_rouge2': 0.0, 'lexical/rouge_rougeL': 0.00017256735693609444, 'lexical/rouge_rougeLsum': 0.00017256735693609444, 'lexical/bleu': 0.0, 'semantic/bert_score': 0.8327858297922011, 'lexical/cider': 0.0, 'diversity_metrics/msttr-100': 0.55321, 'diversity_metrics/msttr-100_nopunct': 0.55286, 'diversity_metrics/total_length': 21570, 'diversity_metrics/mean_pred_length': 14.408817635270541, 'diversity_metrics/std_pred_length': 2.316881153916265, 'diversity_metrics/median_pred_length': 15.0, 'diversity_metrics/min_pred_length': 4, 'diversity_metrics/max_pred_length': 20, 'diversity_metrics/distinct-1': 0.14149281409364858, 'diversity_metrics/vocab_size-1': 3052, 'diversity_metrics/unique-1': 1521, 'diversity_metrics/entropy-1': 7.9544201122946445, 'diversity_metrics/distinct-2': 0.5344990783639715, 'diversity_metrics/vocab_size-2': 10729, 'diversity_metrics/unique-2': 8611, 'diversity_metrics/entropy-2': 11.958045655879152, 'diversity_metrics/cond_entropy-2': 4.0659309616801425, 'diversity_metrics/distinct-3': 0.7975344530577089, 'diversity_metrics/vocab_size-3': 14815, 'diversity_metrics/unique-3': 13359, 'diversity_metrics/entropy-3': 13.506774146114655, 'diversity_metrics/cond_entropy-3': 1.6102318925554278, 'diversity_metrics/total_length-nopunct': 21021, 'diversity_metrics/mean_pred_length-nopunct': 14.042084168336673, 'diversity_metrics/std_pred_length-nopunct': 2.2644693508955878, 'diversity_metrics/median_pred_length-nopunct': 14.0, 'diversity_metrics/min_pred_length-nopunct': 4, 'diversity_metrics/max_pred_length-nopunct': 19, 'diversity_metrics/distinct-1-nopunct': 0.14476000190285904, 'diversity_metrics/vocab_size-1-nopunct': 3043, 'diversity_metrics/unique-1-nopunct': 1517, 'diversity_metrics/entropy-1-nopunct': 7.962618110817082, 'diversity_metrics/distinct-2-nopunct': 0.5329850440483508, 
'diversity_metrics/vocab_size-2-nopunct': 10406, 'diversity_metrics/unique-2-nopunct': 8363, 'diversity_metrics/entropy-2-nopunct': 11.891663674273689, 'diversity_metrics/cond_entropy-2-nopunct': 4.066937157170587, 'diversity_metrics/distinct-3-nopunct': 0.7936428690297886, 'diversity_metrics/vocab_size-3-nopunct': 14307, 'diversity_metrics/unique-3-nopunct': 12868, 'diversity_metrics/entropy-3-nopunct': 13.449966343691857, 'diversity_metrics/cond_entropy-3-nopunct': 1.6298900499140245}}

Why is the difference so large? Am I doing anything incorrectly? Thanks!
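One hypothesis worth ruling out, an assumption rather than a confirmed diagnosis, is that the test references are empty or placeholders: n-gram metrics such as ROUGE, BLEU and CIDEr fall to roughly zero when there is nothing to overlap with, while embedding-based BERTScore can still report a moderate value, which resembles the pattern above. A minimal check with a hypothetical helper `empty_reference_rate`:

```python
from typing import Iterable, Sequence


def empty_reference_rate(references_per_sample: Iterable[Sequence[str]]) -> float:
    """Fraction of samples whose references are missing or all blank (hypothetical helper)."""
    total = 0
    empty = 0
    for refs in references_per_sample:
        total += 1
        # all() over an empty list is True, so samples with no references count as empty
        if all(not ref.strip() for ref in refs):
            empty += 1
    return empty / total if total else 0.0
```

Applied to the test split, e.g. `empty_reference_rate(s.references for s in test_samples)` where `test_samples` (a hypothetical name) is the list of `Sample` objects the datapool produced, a rate near 1.0 would support this hypothesis.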

Persistent Variance in IMDB

While running experiments on IMDB, I found very high variance in the validation and test set results. I don't fully understand it, so I'm looking for some advice.

Here, I've run PPO with 10 seeds using the default hyperparameters.


First of all, it's clear that:

  1. there is a large variance in performance at epoch 0, which could be explained by randomness in the eval sampling during decoding
  2. there is a large variance in performance at epoch 50, which could be explained by randomness in RL

However, the runs that perform best at epoch 0 generally also perform best on perplexity at epoch 50, which I can't explain. Here are the top 5 and bottom 5 runs, ranked by initial perplexity, plotted against each other.


Since all models are initialized from the same pretrained model, there should be no randomness in initialization, so I'm confused about how this is possible. Getting a lucky random seed for the initial validation should not affect the random seed for RL training, so why does the model that performs best at epoch 0 generally also perform best at epoch 50?

Finally, I think the variance in results is high enough that I would recommend using 10 seeds for RL4LMs experiments.

Error with Accelerate integration + NLPO

Hi, I'm trying to use the Accelerate integration because, without it, I cannot run NLPO with even a small model (200M parameters) and 512-token sequences, not even on an 80GB A100. That makes NLPO impractical for almost any problem unless you can use Accelerate, DeepSpeed, or another integration that splits models across GPUs and CPUs. However, when I try to do so, I receive the following error:

Traceback (most recent call last):
  File "/home/alejandro.vaca/nlp_rl/RL4LMs/scripts/training/train_text_generation.py", line 95, in <module>
    main(
  File "/home/alejandro.vaca/nlp_rl/RL4LMs/scripts/training/train_text_generation.py", line 56, in main
    trainer = OnPolicyTrainer(
  File "/home/alejandro.vaca/nlp_rl/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 190, in __init__
    self._setup()
  File "/home/alejandro.vaca/nlp_rl/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 207, in _setup
    self._alg = build_alg(
  File "/home/alejandro.vaca/nlp_rl/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 151, in build_alg
    alg = wrapper(
  File "/home/alejandro.vaca/nlp_rl/RL4LMs/rl4lms/envs/text_generation/alg_wrappers.py", line 451, in wrap_onpolicy_alg
    alg = OnPolicyAlgText(alg_kwargs, kl_coeff, tracker, accelerator, target_kl, norm_reward)
  File "/home/alejandro.vaca/nlp_rl/RL4LMs/rl4lms/envs/text_generation/alg_wrappers.py", line 126, in __init__
    super().__init__(**alg_kwargs)
TypeError: __init__() got an unexpected keyword argument 'accelerator'
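The traceback suggests that `alg_kwargs` carries an `accelerator` entry which the wrapped algorithm's constructor does not declare, so `super().__init__(**alg_kwargs)` fails. The sketch below reproduces that failure mode with a hypothetical stand-in class `BaseAlg` and shows one defensive pattern, a hypothetical `filter_kwargs` helper built on `inspect.signature`:

```python
import inspect
from typing import Any, Callable, Dict


class BaseAlg:
    """Stand-in for an algorithm whose __init__ has no `accelerator` parameter (hypothetical)."""

    def __init__(self, policy: str, learning_rate: float = 1e-6):
        self.policy = policy
        self.learning_rate = learning_rate


def filter_kwargs(fn: Callable[..., Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the keyword arguments `fn` declares (all of them if it accepts **kwargs)."""
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in params}


alg_kwargs = {"policy": "t5-base", "learning_rate": 1e-5, "accelerator": object()}  # illustrative values

try:
    BaseAlg(**alg_kwargs)
except TypeError as err:
    print(err)  # the message names the unexpected 'accelerator' keyword

alg = BaseAlg(**filter_kwargs(BaseAlg, alg_kwargs))
print(alg.learning_rate)
```

Filtering only removes the crash: the dropped `accelerator` is then simply unused, so the model is not actually split across devices. A real fix would have the algorithm accept and use the accelerator.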

Also, I'm curious how you were able to carry out the benchmarks in your paper, as I don't know of any single GPU with more than 80GB of memory, and even on an 80GB GPU I cannot run NLPO with sequences longer than 128 tokens. How did you do it? Is there something I'm missing? @rajcscw @jmhessel @rajammanabrolu @JulesGM @akifumi-wachi-4

Thank you very much for this amazing work!! :)
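For the memory question, a back-of-the-envelope estimate can show which part dominates. A rough sketch under stated assumptions: fp32 weights (4 bytes per parameter), Adam keeping two moment buffers per trainable parameter, frozen copies (such as a reference model) storing weights only, and activations ignored apart from a single logits tensor. The function names, the batch size of 16 and the number of copies are hypothetical; 32,128 is the vocabulary size in t5-base's config:

```python
def training_memory_gb(n_params: float, trainable_copies: int = 1,
                       frozen_copies: int = 0, bytes_per_param: int = 4) -> float:
    """Rough weight + gradient + Adam-state memory in GB (1 GB = 1e9 bytes); activations ignored."""
    # A trainable copy holds weights, gradients and two Adam moment buffers: 4 values per parameter.
    trainable = trainable_copies * 4 * bytes_per_param * n_params
    # A frozen copy (e.g. a reference model) holds weights only.
    frozen = frozen_copies * bytes_per_param * n_params
    return (trainable + frozen) / 1e9


def logits_memory_gb(batch: int, seq_len: int, vocab: int, bytes_per_elem: int = 4) -> float:
    """Size in GB of one fp32 tensor of shape (batch, seq_len, vocab)."""
    return batch * seq_len * vocab * bytes_per_elem / 1e9


print(training_memory_gb(200e6))                   # trainable policy only
print(training_memory_gb(200e6, frozen_copies=1))  # plus one frozen copy
print(logits_memory_gb(16, 512, 32128))            # one logits tensor, hypothetical batch of 16
```

Under these assumptions a 200M-parameter policy needs about 3.2 GB for weights and optimizer state (4.0 GB with one frozen copy), while a single fp32 logits tensor for 16 sequences of 512 tokens is about 1.05 GB; since generation keeps per-step scores of that size, long sequences can plausibly account for far more memory than the weights themselves.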

CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`

@jmhessel @dirkgr @schmmd @iellenberger

I ran python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/iwslt2017/t5_ppo.yml

with the following config:

tokenizer:
  model_name: "google/mt5-base"
  padding_side: right
  truncation_side: right
  truncation: True
  padding: True
  max_length: 128
  # pad_token_as_eos_token: False

reward_fn:
  id: meteor 
  
datapool:
  id: wmt16
  args:
    train_path: "data/train.csv"
    eval_path: "data/eval.csv"
    test_path: "data/test.xlsx"


env:
  n_envs: 10
  args:
    max_prompt_length: 128
    max_episode_length: 128
    terminate_on_eos: True
    prompt_truncation_side: "right"
    context_start_token: 0

alg:
  id: ppo
  args: 
    n_steps: 2
    batch_size: 20
    verbose: 2
    learning_rate: 0.000001
    n_epochs: 5
    ent_coef: 0.0
  kl_div:
    coeff: 0.001
    target_kl: 0.2
  policy:
    id: seq2seq_lm_actor_critic_policy
    args:
      model_name: "google/mt5-base"
      apply_model_parallel: True
      prompt_truncation_side: "right"
      generation_kwargs:
        do_sample: True
        num_beams: 3
        max_length: 128
        length_penalty: 0.85
        repetition_penalty: 2.0
        max_new_tokens: 128

    
train_evaluation:
  eval_batch_size: 1
  n_iters: 10
  eval_every: 10
  save_every: 1
  metrics:
    - id: meteor
      args: {}
    - id: sacre_bleu
      args:
        tokenize: "intl"
  generation_kwargs:
    do_sample: True
    num_beams: 3
    max_length: 128
    length_penalty: 0.85
    max_new_tokens: 128
    repetition_penalty: 2.0

Data pool:

import pandas as pd
from tqdm import tqdm

from rl4lms.data_pools.text_generation_pool import Sample, TextGenPool


class WMT(TextGenPool):

    @classmethod
    def prepare(cls,
                split: str,
                train_path: str,
                eval_path: str,
                test_path: str
                ):
        # train/val come from CSV (first 100 rows only); test comes from an Excel file
        if split == 'train':
            dataset = pd.read_csv(train_path, nrows=100)
        elif split == 'val':
            dataset = pd.read_csv(eval_path, nrows=100)
        elif split == 'test':
            dataset = pd.read_excel(test_path, engine='openpyxl')
        else:
            raise ValueError(f"Unknown split: {split}")

        samples = []
        for ix, item in tqdm(dataset.iterrows(),
                             desc="Preparing dataset",
                             total=len(dataset)):

            # prompt = task prefix followed by the source sentence
            prompt = item['prefix'] + item['input_text']
            reference = item['target_text']

            sample = Sample(id=f"{split}_{ix}",
                            prompt_or_input_text=prompt,
                            references=[reference]
                            )
            samples.append(sample)

        pool_instance = cls(samples)
        return pool_instance

However, it results in:

/pytorch/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectLargeIndex: block: [595,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectLargeIndex: block: [595,0,0], thread: [65,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
...
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectLargeIndex: block: [596,0,0], thread: [124,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectLargeIndex: block: [596,0,0], thread: [125,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectLargeIndex: block: [596,0,0], thread: [126,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectLargeIndex: block: [596,0,0], thread: [127,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
Evaluating:   0%|                                                                                            | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "scripts/training/train_text_generation.py", line 71, in <module>
    args.log_to_wandb)
  File "scripts/training/train_text_generation.py", line 42, in main
    trainer.train_and_eval()
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 198, in train_and_eval
    self._evaluate_on_datapools(epoch=iter_start)
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 193, in _evaluate_on_datapools
    gen_kwargs=self._eval_gen_kwargs)
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py", line 41, in evaluate_on_samples
    dt_control_token, gen_kwargs)
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py", line 99, in generate_text
    gen_kwargs=gen_kwargs)["gen_texts"]
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/policy.py", line 304, in generate
    **generation_kwargs_)
  File "/home/user/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/hf_generation_utils.py", line 1199, in generate
    inputs_tensor, model_kwargs, model_input_name
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/hf_generation_utils.py", line 535, in _prepare_encoder_decoder_kwargs_for_generation
    **encoder_kwargs)
  File "/home/user/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/conda/lib/python3.7/site-packages/transformers/models/t5/modeling_t5.py", line 1044, in forward
    output_attentions=output_attentions,
  File "/home/user/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/conda/lib/python3.7/site-packages/transformers/models/t5/modeling_t5.py", line 675, in forward
    output_attentions=output_attentions,
  File "/home/user/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/conda/lib/python3.7/site-packages/transformers/models/t5/modeling_t5.py", line 581, in forward
    output_attentions=output_attentions,
  File "/home/user/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/conda/lib/python3.7/site-packages/transformers/models/t5/modeling_t5.py", line 500, in forward
    query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
  File "/home/user/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/conda/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/user/conda/lib/python3.7/site-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`
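The repeated `srcIndex < srcSelectDimSize` assertions are raised by an index-select kernel, which embedding lookups rely on, so a likely (unconfirmed) explanation is that some token id is at least as large as the embedding table. Once that device-side assertion fires, later CUDA calls such as creating the cuBLAS handle can fail too, so the CUBLAS error may be a symptom rather than the cause. A minimal, framework-free check with a hypothetical helper `find_out_of_range_ids`:

```python
from typing import List, Sequence, Tuple


def find_out_of_range_ids(input_ids: Sequence[Sequence[int]],
                          num_embeddings: int) -> List[Tuple[int, int, int]]:
    """Return (row, position, token_id) for every id outside [0, num_embeddings)."""
    bad = []
    for row, ids in enumerate(input_ids):
        for pos, tok in enumerate(ids):
            if tok < 0 or tok >= num_embeddings:
                bad.append((row, pos, int(tok)))
    return bad
```

With transformers, the inputs could come from `tokenizer(prompts)["input_ids"]` and the table size from `model.get_input_embeddings().num_embeddings`; ids injected by the config, such as `context_start_token: 0`, are worth checking too. Running once with `CUDA_LAUNCH_BLOCKING=1` (or on CPU) usually makes the traceback point at the operation that actually failed.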

Error encountered when running the scripts in the README

Hi,
Following your README scripts, I ran into an error after executing "!python scripts/training/train_text_generation.py --config_path scripts/task_configs/summarization/t5_ppo.yml".
I suspect the first 4 lines come from my current TensorFlow library. Does anyone have some advice?

The error is given below
2022-10-09 01:13:11.489647: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-10-09 01:13:11.945129: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-10-09 01:13:13.866126: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-10-09 01:13:13.866414: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-10-09 01:13:13.866456: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
[10/09/22 01:13:18] DEBUG 2022-10-09 01:13:18,221 connectionpool.py:1003 [DEBUG] Starting new HTTPS connection (1): huggingface.co:443
DEBUG 2022-10-09 01:13:18,301 connectionpool.py:456 [DEBUG] https://huggingface.co:443 "HEAD /t5-base/resolve/main/tokenizer_config.json HTTP/1.1" 404 0
DEBUG 2022-10-09 01:13:18,309 connectionpool.py:1003 [DEBUG] Starting new HTTPS connection (1): huggingface.co:443
DEBUG 2022-10-09 01:13:18,373 connectionpool.py:456 [DEBUG] https://huggingface.co:443 "HEAD /t5-base/resolve/main/config.json HTTP/1.1" 200 0
DEBUG 2022-10-09 01:13:18,384 connectionpool.py:1003 [DEBUG] Starting new HTTPS connection (1): huggingface.co:443
DEBUG 2022-10-09 01:13:18,454 connectionpool.py:456 [DEBUG] https://huggingface.co:443 "HEAD /t5-base/resolve/main/tokenizer_config.json HTTP/1.1" 404 0
DEBUG 2022-10-09 01:13:18,464 connectionpool.py:1003 [DEBUG] Starting new HTTPS connection (1): huggingface.co:443
DEBUG 2022-10-09 01:13:18,518 connectionpool.py:456 [DEBUG] https://huggingface.co:443 "HEAD /t5-base/resolve/main/spiece.model HTTP/1.1" 200 0
DEBUG 2022-10-09 01:13:18,529 connectionpool.py:1003 [DEBUG] Starting new HTTPS connection (1): huggingface.co:443
DEBUG 2022-10-09 01:13:18,584 connectionpool.py:456 [DEBUG] https://huggingface.co:443 "HEAD /t5-base/resolve/main/tokenizer.json HTTP/1.1" 200 0
DEBUG 2022-10-09 01:13:18,596 connectionpool.py:1003 [DEBUG] Starting new HTTPS connection (1): huggingface.co:443
DEBUG 2022-10-09 01:13:18,652 connectionpool.py:456 [DEBUG] https://huggingface.co:443 "HEAD /t5-base/resolve/main/added_tokens.json HTTP/1.1" 404 0
DEBUG 2022-10-09 01:13:18,662 connectionpool.py:1003 [DEBUG] Starting new HTTPS connection (1): huggingface.co:443
DEBUG 2022-10-09 01:13:18,720 connectionpool.py:456 [DEBUG] https://huggingface.co:443 "HEAD /t5-base/resolve/main/special_tokens_map.json HTTP/1.1" 404 0
DEBUG 2022-10-09 01:13:18,730 connectionpool.py:1003 [DEBUG] Starting new HTTPS connection (1): huggingface.co:443
DEBUG 2022-10-09 01:13:18,786 connectionpool.py:456 [DEBUG] https://huggingface.co:443 "HEAD /t5-base/resolve/main/tokenizer_config.json HTTP/1.1" 404 0
DEBUG 2022-10-09 01:13:18,797 connectionpool.py:1003 [DEBUG] Starting new HTTPS connection (1): huggingface.co:443
DEBUG 2022-10-09 01:13:18,867 connectionpool.py:456 [DEBUG] https://huggingface.co:443 "HEAD /t5-base/resolve/main/config.json HTTP/1.1" 200 0
/tigress/pe3955/RL4LMs/rl4lms/envs/text_generation/reward.py:133: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate
  self._metric = load_metric("rouge")
[10/09/22 01:13:19] DEBUG 2022-10-09 01:13:19,023 connectionpool.py:1003 [DEBUG] Starting new HTTPS connection (1): s3.amazonaws.com:443
DEBUG 2022-10-09 01:13:19,107 connectionpool.py:456 [DEBUG] https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/metrics/rouge/rouge.py HTTP/1.1" 200 0
DEBUG 2022-10-09 01:13:19,118 connectionpool.py:1003 [DEBUG] Starting new HTTPS connection (1): raw.githubusercontent.com:443
DEBUG 2022-10-09 01:13:19,230 connectionpool.py:456 [DEBUG] https://raw.githubusercontent.com:443 "HEAD /huggingface/datasets/2.5.1/metrics/rouge/rouge.py HTTP/1.1" 200 0
DEBUG 2022-10-09 01:13:19,283 connectionpool.py:1003 [DEBUG] Starting new HTTPS connection (1): s3.amazonaws.com:443
DEBUG 2022-10-09 01:13:19,392 connectionpool.py:456 [DEBUG] https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/metrics/meteor/meteor.py HTTP/1.1" 200 0
DEBUG 2022-10-09 01:13:19,402 connectionpool.py:1003 [DEBUG] Starting new HTTPS connection (1): raw.githubusercontent.com:443
DEBUG 2022-10-09 01:13:19,523 connectionpool.py:456 [DEBUG] https://raw.githubusercontent.com:443 "HEAD /huggingface/datasets/2.5.1/metrics/meteor/meteor.py HTTP/1.1" 200 0
[nltk_data] Downloading package wordnet to /home/pe3955/nltk_data...
[nltk_data] Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /home/pe3955/nltk_data...
[nltk_data] Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/pe3955/nltk_data...
[nltk_data] Package omw-1.4 is already up-to-date!
DEBUG 2022-10-09 01:13:19,878 connectionpool.py:1003 [DEBUG] Starting new HTTPS connection (1): s3.amazonaws.com:443
DEBUG 2022-10-09 01:13:19,953 connectionpool.py:456 [DEBUG] https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/metrics/rouge/rouge.py HTTP/1.1" 200 0
DEBUG 2022-10-09 01:13:19,963 connectionpool.py:1003

Evaluating: 0%| | 0/134 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/tigress/pe3955/RL4LMs/scripts/training/train_text_generation.py", line 66, in <module>
    main(args.config_path,
  File "/tigress/pe3955/RL4LMs/scripts/training/train_text_generation.py", line 42, in main
    trainer.train_and_eval()
  File "/tigress/pe3955/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 197, in train_and_eval
    self._evaluate_on_datapools(epoch=iter_start)
  File "/tigress/pe3955/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 183, in _evaluate_on_datapools
    evaluate_on_samples(policy=self._alg.policy,
  File "/tigress/pe3955/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py", line 39, in evaluate_on_samples
    batch_generated_texts = generate_text(
  File "/tigress/pe3955/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py", line 96, in generate_text
    generated_texts = policy.generate(tokenizer,
  File "/tigress/pe3955/RL4LMs/rl4lms/envs/text_generation/policy.py", line 297, in generate
    self.get_policy_first_device()),
  File "/tigress/pe3955/RL4LMs/rl4lms/envs/text_generation/policy.py", line 519, in get_policy_first_device
    return self._policy_model.encoder.first_device
  File "/home/pe3955/.conda/envs/rl4lm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1185, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'T5Stack' object has no attribute 'first_device'
,,,
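In transformers 4.18, `T5Stack` (and `GPT2Model`, in the GPT-2 report further down) appear to set `first_device` only inside `parallelize()`, so the lookup in `get_policy_first_device` fails whenever the model was never parallelized, for example on a machine without CUDA. A minimal workaround sketch, assuming you are free to patch that method: fall back to the device of the model's first parameter. `get_first_device` is a hypothetical helper, not part of RL4LMs or transformers.

```python
from typing import Any


def get_first_device(module: Any, default: Any = "cpu") -> Any:
    """Hypothetical helper: the device a (possibly parallel) stack expects inputs on.

    Prefers `first_device` (which, in transformers 4.18, appears to be set only
    by `parallelize()`), then the device of the first parameter, then `default`.
    """
    # getattr with a default also swallows the AttributeError raised by
    # torch.nn.Module.__getattr__ for unknown attributes.
    device = getattr(module, "first_device", None)
    if device is not None:
        return device
    parameters = getattr(module, "parameters", None)
    if callable(parameters):
        # A non-parallelized model usually keeps all parameters on one device.
        for param in parameters():
            return param.device
    return default
```

Inside the policy this might become `get_first_device(self._policy_model.encoder)`; whether other code paths in RL4LMs also assume a parallelized model is untested here.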

BART supervised

I have tried using BART (facebook/bart-large from Hugging Face) as a seq2seq model. However, this throws an error saying that `.parallelize` doesn't exist. Has anyone been able to fine-tune BART with this repo?
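In transformers 4.x, `parallelize()` is implemented only by a handful of architectures such as T5 and GPT-2; BART does not define it, which matches the error above. A minimal sketch of a guard, assuming you can edit the place where the policy calls it (setting `apply_model_parallel: False` in the policy args, as in the config of the next issue, may also avoid the call). `maybe_parallelize` is a hypothetical helper.

```python
from typing import Any, Optional


def maybe_parallelize(model: Any, device_map: Optional[dict] = None) -> bool:
    """Hypothetical guard: call `model.parallelize()` only if the model has it.

    Returns True if the model was parallelized, False if the caller should
    fall back to placing the whole model on a single device instead.
    """
    parallelize = getattr(model, "parallelize", None)
    if not callable(parallelize):
        return False
    if device_map is None:
        parallelize()  # let transformers choose a balanced device map
    else:
        parallelize(device_map)
    return True
```

A caller would then do something like `if not maybe_parallelize(model): model.to(device)`.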

Bloom support

The repository uses transformers version 4.18, which does not support BLOOM. Is there any way to use BLOOM as the initial policy for training?
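Before trying BLOOM, it can help to confirm whether the installed transformers exposes a BLOOM model class at all. The sketch below checks by class name; `has_model_class` is a hypothetical helper, and `BloomForCausalLM` is the class name transformers uses in versions that ship BLOOM. Upgrading transformers past the version RL4LMs pins may break other parts of the library, which this check does not cover.

```python
import importlib


def has_model_class(class_name: str, package: str = "transformers") -> bool:
    """Hypothetical helper: does `package` expose an attribute named `class_name`?"""
    try:
        module = importlib.import_module(package)
    except ImportError:
        return False  # package not installed
    try:
        # transformers resolves model classes lazily; a broken optional
        # dependency can surface here as an exception other than AttributeError.
        return hasattr(module, class_name)
    except Exception:
        return False
```

For example, `has_model_class("BloomForCausalLM")` returns False when the installed transformers predates BLOOM.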

num_beams

I tried setting the num_beams generation parameter to 3, but got an error.

config:

tokenizer:
  model_name: "t5-base"
  padding_side: right
  truncation_side: right
  truncation: True
  padding: True
  max_length: 128
  # pad_token_as_eos_token: False

reward_fn:
  id: meteor 
  
datapool:
  id: wmt16
  args:
    train_path: "data/train.csv"
    eval_path: "data/eval.csv"
    test_path: "data/test.xlsx"


env:
  n_envs: 10
  args:
    max_prompt_length: 128
    max_episode_length: 128
    terminate_on_eos: True
    prompt_truncation_side: "right"
    context_start_token: 0

alg:
  id: ppo
  args: 
    n_steps: 2
    batch_size: 20
    verbose: 2
    learning_rate: 0.000001
    n_epochs: 5
    ent_coef: 0.0
  kl_div:
    coeff: 0.001
    target_kl: 0.2
  policy:
    id: seq2seq_lm_actor_critic_policy
    args:
      model_name: "t5-base"
      apply_model_parallel: True
      prompt_truncation_side: "right"
      generation_kwargs:
        do_sample: True
        num_beams: 3
        max_length: 128
        length_penalty: 0.85
        repetition_penalty: 2.0
        max_new_tokens: 128

    
train_evaluation:
  eval_batch_size: 1
  n_iters: 10
  eval_every: 10
  save_every: 1
  metrics:
    - id: meteor
      args: {}
    - id: sacre_bleu
      args:
        tokenize: "intl"
  generation_kwargs:
    do_sample: True
    num_beams: 3
    max_length: 128
    length_penalty: 0.85
    max_new_tokens: 128
    repetition_penalty: 2.0

error:

Evaluating:   0%|                                                                                                                               | 0/1 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "scripts/training/train_text_generation.py", line 71, in <module>
    args.log_to_wandb)
  File "scripts/training/train_text_generation.py", line 42, in main
    trainer.train_and_eval()
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 198, in train_and_eval
    self._evaluate_on_datapools(epoch=iter_start)
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 193, in _evaluate_on_datapools
    gen_kwargs=self._eval_gen_kwargs)
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py", line 41, in evaluate_on_samples
    dt_control_token, gen_kwargs)
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py", line 99, in generate_text
    gen_kwargs=gen_kwargs)["gen_texts"]
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/policy.py", line 324, in generate
    log_probs = distribution.log_prob(actions_at_step)
  File "/home/user/conda/lib/python3.7/site-packages/torch/distributions/categorical.py", line 117, in log_prob
    self._validate_sample(value)
  File "/home/user/conda/lib/python3.7/site-packages/torch/distributions/distribution.py", line 277, in _validate_sample
    format(actual_shape, expected_shape))
ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([10]) vs torch.Size([30]).
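The shapes in the error are consistent with beam search: with `num_beams: 3`, Hugging Face `generate` tracks 3 hypotheses per input, so its per-step scores have 3 × 10 = 30 rows, while the policy calls `log_prob` on 10 actions (one per sequence) at `policy.py` line 324. Since evaluation also goes through `policy.generate` here, setting `num_beams: 1` in both `generation_kwargs` blocks of the config appears to be the simplest workaround. A sketch of a pre-flight check that fails early with a clearer message; the function names are hypothetical, not RL4LMs API.

```python
from typing import Any, Mapping


def beam_score_rows(n_sequences: int, num_beams: int) -> int:
    """Rows in each per-step scores tensor under beam search.

    Beam search keeps `num_beams` hypotheses per input, so every step's
    scores cover n_sequences * num_beams rows.
    """
    return n_sequences * num_beams


def validate_policy_gen_kwargs(gen_kwargs: Mapping[str, Any], n_sequences: int) -> None:
    """Hypothetical check: reject beam search for policy generation.

    The policy scores one sampled action per sequence per step, so the
    per-step distributions must have exactly one row per sequence.
    """
    num_beams = int(gen_kwargs.get("num_beams", 1))
    if num_beams > 1:
        rows = beam_score_rows(n_sequences, num_beams)
        raise ValueError(
            f"num_beams={num_beams} gives {rows} score rows per step for "
            f"{n_sequences} sequences; set num_beams to 1"
        )
```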

'GPT2Model' object has no attribute 'first_device'

I get the following error when running `python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/dialog/gpt2_ppo.yml`. I have double-checked that transformers==4.18.0 is installed.

Traceback (most recent call last):
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/scripts/training/train_text_generation.py", line 84, in <module>
    main(
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/scripts/training/train_text_generation.py", line 55, in main
    trainer.train_and_eval()
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 232, in train_and_eval
    self._alg.learn(self._n_steps_per_iter)
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/algorithms/ppo/ppo.py", line 342, in learn
    return super().learn(
  File "/opt/anaconda3/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 247, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/alg_wrappers.py", line 384, in collect_rollouts
    rollout_info = self.generate_batch(
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/alg_wrappers.py", line 159, in generate_batch
    gen_output = self.policy.generate(
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/policy/base_policy.py", line 230, in generate
    inputs=input_ids.to(self.get_policy_first_device()),
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/policy/causal_policy.py", line 259, in get_policy_first_device
    self._policy_model.transformer.first_device
  File "/opt/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1185, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'GPT2Model' object has no attribute 'first_device'

Implementing self-play

Hello

I would like to implement self-play dialogue training.
For that, I guess I need to modify the episode rollout process by adding formatting such as a speaker ID at the start of each line. I'd also like to try keeping a buffer of previous model checkpoints and using them as one of the conversants, to keep the model from overfitting to itself.

The obvious place for it is implementing a new policy that provides formatted generation results and holds previous checkpoints in the buffer.

Is there any better place to implement this? Anything I should consider library-wise while implementing it?

Any advice would be appreciated, thanks in advance!
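Independently of where it is wired in (a custom policy, as suggested, or the rollout logic), the checkpoint buffer and speaker-tagged formatting could look like the sketch below. Everything here is hypothetical and not part of RL4LMs: snapshots are stored as opaque objects (for example `copy.deepcopy(model.state_dict())`), and the `speaker: text` line format is an assumption.

```python
import random
from collections import deque
from typing import Deque, Generic, List, Optional, Tuple, TypeVar

T = TypeVar("T")


class CheckpointPool(Generic[T]):
    """Hypothetical fixed-size FIFO of past policy snapshots used as conversants."""

    def __init__(self, max_size: int, seed: Optional[int] = None) -> None:
        # deque(maxlen=...) silently drops the oldest snapshot when full.
        self._pool: Deque[T] = deque(maxlen=max_size)
        self._rng = random.Random(seed)

    def add(self, snapshot: T) -> None:
        self._pool.append(snapshot)

    def sample(self) -> T:
        """Pick a past snapshot uniformly at random to play the other speaker."""
        if not self._pool:
            raise IndexError("checkpoint pool is empty")
        return self._rng.choice(self._pool)

    def __len__(self) -> int:
        return len(self._pool)


def format_dialogue(turns: List[Tuple[str, str]]) -> str:
    """Prefix each turn with its speaker ID, one turn per line."""
    return "\n".join(f"{speaker}: {text}" for speaker, text in turns)
```

Sampling uniformly from older snapshots, rather than always playing against the latest one, is one common way to reduce the self-overfitting mentioned above.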

Self-designed model

Hi there, I have a simple question: can I use this tool to train a self-designed model that is not available on Hugging Face? Our model's input differs from that of a regular LM.

_pickle.UnpicklingError: pickle data was truncated

I am trying to get RL4LMs to work, so I built the Docker image following the instructions in the README file. After building it, I ran the following command inside the container (from the Quick start section):

python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/summarization/t5_ppo.yml

(i.e. `docker run -it rl4lms /bin/sh`, followed by `python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/summarization/t5_ppo.yml`)

However, I'm getting an UnpicklingError. The complete debugging info right before the error is as follows:

To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-02 17:33:13.379149: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-03-02 17:33:13.379385: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-03-02 17:33:13.379401: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-03-02 17:33:45.988942: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-02 17:33:48.806522: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-03-02 17:33:48.806802: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-03-02 17:33:48.806862: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Killed
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/forkserver.py", line 280, in main
    code = _serve_one(child_r, fds,
  File "/opt/conda/lib/python3.8/multiprocessing/forkserver.py", line 319, in _serve_one
    code = spawn._main(child_r, parent_sentinel)
  File "/opt/conda/lib/python3.8/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
_pickle.UnpicklingError: pickle data was truncated

I am using a 2019 MacBook Pro with a 2.4 GHz Quad-Core Intel Core i5. Should I make any changes to the Dockerfile to be able to run this, or is there another script I can try?

Note: I've also tried installing this directly on my machine (and on a different machine with an M1 chip), but ran into incompatibilities between specific package versions and the Python 3.10 and 3.11 installations on those two machines.
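The `Killed` line right before the traceback suggests, though the log alone cannot prove it, that the kernel's out-of-memory killer terminated a process; the truncated pickle would then be consistent with a worker losing its data stream mid-transfer. On macOS, Docker Desktop runs containers inside a VM whose memory cap is set in its Resources settings, and loading T5 plus the summarization data may exceed it. A small sketch to check the memory limit the container actually sees; it reads the standard Linux cgroup v2 and v1 limit files, and both function names are hypothetical.

```python
from pathlib import Path
from typing import Optional

# Memory-limit files inside a Linux container (cgroup v2 first, then v1).
_CGROUP_LIMIT_FILES = (
    Path("/sys/fs/cgroup/memory.max"),
    Path("/sys/fs/cgroup/memory/memory.limit_in_bytes"),
)


def parse_limit(text: str) -> Optional[int]:
    """Parse a cgroup limit; cgroup v2 writes 'max' when there is no limit."""
    value = text.strip()
    if not value or value == "max":
        return None
    return int(value)


def container_memory_limit_bytes() -> Optional[int]:
    """Return the memory limit in bytes, or None if unlimited or not found."""
    for path in _CGROUP_LIMIT_FILES:
        try:
            return parse_limit(path.read_text())
        except (OSError, ValueError):
            continue  # file missing, unreadable, or not a number
    return None
```

Running `print(container_memory_limit_bytes())` inside the container and comparing it with the memory the training run needs would show whether raising Docker Desktop's limit is worth trying.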

Error when trying to load a checkpoint from Transformers after RL training

Hi, I trained a MarianMT model by maximizing bert_score, and everything worked fine until I tried to load the resulting weights with transformers.

I created a folder containing the last checkpoint binary and ran AutoModelForSeq2SeqLM.from_pretrained(<folder_name>) on it, but it threw the following error:

OSError: Error no file named pytorch_model.bin found in directory <directory> but there is a file for Flax weights. Use `from_flax=True` to load this model from those weights.

Then I looked into the transformers documentation and saw that loading Flax models requires a flax_model.msgpack file in the directory, so I renamed the checkpoint binary to that and retried with from_flax=True in the from_pretrained call. However, loading the model still fails:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 446, in from_pretrained
    return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1843, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/transformers/models/marian/modeling_marian.py", line 1281, in __init__
    self.model = MarianModel(config)
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/transformers/models/marian/modeling_marian.py", line 1090, in __init__
    self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/torch/nn/modules/sparse.py", line 132, in __init__
    assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
AssertionError: Padding_idx must be within num_embeddings

I assume that if this library is meant for training transformers models, there should be no problem loading those models with transformers after RL training. So I would like to ask whether I am doing something wrong when loading them, and what the correct way is to load checkpoints with transformers after training with RL4LMs.

Thanks in advance :)
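The report does not show the checkpoint's internal layout, so the following is a sketch under assumptions: that the saved file is a PyTorch state dict loadable with `torch.load`, with the Marian weights under a common key prefix, here assumed to be `_policy_model.` (inspect `state.keys()` to find the real one). The idea is to rebuild the model from its original Hugging Face name, load the stripped weights into it, and write a proper `save_pretrained` directory, which includes the `config.json` and weights file that `from_pretrained` looks for. Both functions are hypothetical helpers.

```python
from typing import Dict, List, Mapping, Tuple, TypeVar

V = TypeVar("V")


def extract_submodule_state(state: Mapping[str, V], prefix: str) -> Dict[str, V]:
    """Keep only keys under `prefix` and strip it ('a.b.w' -> 'b.w' for prefix 'a')."""
    if not prefix.endswith("."):
        prefix += "."
    return {key[len(prefix):]: value for key, value in state.items() if key.startswith(prefix)}


def export_policy_checkpoint(
    checkpoint_path: str,
    base_model_name: str,
    out_dir: str,
    prefix: str = "_policy_model",  # assumed prefix; verify against your keys
) -> Tuple[List[str], List[str]]:
    """Hypothetical exporter: RL checkpoint -> standard Hugging Face directory."""
    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    state = torch.load(checkpoint_path, map_location="cpu")
    model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
    weights = extract_submodule_state(state, prefix)
    # strict=False reports mismatched keys instead of raising, so they can be inspected.
    result = model.load_state_dict(weights, strict=False)
    model.save_pretrained(out_dir)
    AutoTokenizer.from_pretrained(base_model_name).save_pretrained(out_dir)
    return list(result.missing_keys), list(result.unexpected_keys)
```

If the returned missing-keys list covers most of the model, the prefix (or the assumed checkpoint layout) is wrong; otherwise `AutoModelForSeq2SeqLM.from_pretrained(out_dir)` should then load the exported directory.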

passing extra variable to the forward function

Hey,
I am currently using your repo to fine-tune a Longformer model.
The problem is that this model requires a predefined global attention mask (in addition to the regular attention mask), which specifies which tokens get an extra "global attention head".
So my question is: is there an easy way to pass this variable that does not require combing through the code and locating every call to the forward functions?
In other words, is there an easy way to pass extra model_kwargs?
Thanks!
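One low-touch option, sketched here under assumptions, is to wrap the model's `forward` once so that every call site receives the extra argument without being edited. `inject_forward_kwargs` and `longformer_first_token_global` are hypothetical helpers; `global_attention_mask` is the keyword Hugging Face Longformer models accept, but giving global attention only to the first token is an assumption about your task. Generation may call `forward` with inputs of a different length than the prompt, so the mask is derived from each call's `input_ids`; calls that pass `input_ids` positionally are not handled by this sketch.

```python
import functools
from typing import Any, Callable, Dict, Mapping


def inject_forward_kwargs(
    model: Any,
    make_kwargs: Callable[[Mapping[str, Any]], Dict[str, Any]],
) -> Any:
    """Hypothetical helper: add keyword arguments to every model.forward call.

    `make_kwargs` receives the call's keyword arguments and returns extra
    ones; values the caller passed explicitly are never overwritten.
    """
    original_forward = model.forward

    @functools.wraps(original_forward)
    def forward(*args: Any, **kwargs: Any) -> Any:
        for name, value in make_kwargs(kwargs).items():
            kwargs.setdefault(name, value)
        return original_forward(*args, **kwargs)

    # nn.Module.__call__ dispatches to self.forward, so this instance
    # attribute takes precedence over the class's forward method.
    model.forward = forward
    return model


def longformer_first_token_global(kwargs: Mapping[str, Any]) -> Dict[str, Any]:
    """Assumed policy: global attention on the first token of each sequence."""
    import torch

    input_ids = kwargs.get("input_ids")
    if input_ids is None:
        return {}
    mask = torch.zeros_like(input_ids)
    mask[:, 0] = 1
    return {"global_attention_mask": mask}
```

Usage would be a single `inject_forward_kwargs(model, longformer_first_token_global)` right after the model is created, before any training or generation calls.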
