
halos's People

Contributors

kawin-contextual-ai, kawine, samuelzxu, xwinxu

halos's Issues

Request for details and assistance on PPO Experiments with SFT+PPO training

Hello Developers,

Firstly, I would like to thank you for the excellent work on this repository and for sharing the plots on other issues. I'm currently using your library to train a model with SFT+PPO, and I've successfully replicated the SFT experiment as per the results shared in ContextualAI/HALOs/issues/13.

However, I'm experiencing negligible improvement with the PPO part of the training. Could you provide the details of your PPO experiments? I noted in a previous comment that the preferential tuning was run significantly longer, so I adjusted the PPO epochs to 3 in my experiments. Are these adjustments in line with what was done in your experiments? Additionally, could you elaborate on how and when to decide which checkpoint to use for downstream tasks, especially for PPO, DPO, and KTO scenarios?

Link to my plots: l7b_ppo_0419

Here are the commands I used for my experiments:

  • SFT Training Command
python train.py loss=sft model=llama7b datasets=[shp,hh,oasst] exp_name=l7b_sft_0416 mode=train ++cache_dir=/data/models wandb.project=l7b_sft_0416
  • PPO Training (3 Epochs)
# Updated n_epochs in config.yaml to 3
python train.py loss=ppo model=llama7b datasets=[shp,hh,oasst] exp_name=l7b_ppo_0419 mode=train ++cache_dir=./data/models ++model.load_from=l7b_sft_0416/LATEST/policy.pt wandb.project=l7b_ppo_0419

Additional Query:

  • SFT training calculates train and validation losses using splits of the train dataset. If I use the same dataset for PPO, how can I ensure that I am not inadvertently retraining on the train split? Furthermore, when using stratified datasets for both SFT and preferential tuning, do you recommend holding out different data points for each, and is this approach considered best practice?
  • Based on your plots and results shared, am I correct in understanding that you had a batch size of 32 and conducted 200k steps of SFT training, which equates to training on approximately 6.4 million datapoints, DPO on roughly 9.6 million datapoints (300k * 32), and KTO on about 17.6 million datapoints?
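The arithmetic behind these estimates can be checked with a short sketch. It assumes the reading used in the question, namely that the plot axis counts optimizer steps, that the batch size is 32, and that there is no gradient accumulation; the helper names are hypothetical. Note that 300k * 32 works out to 9.6 million.

```python
# Examples seen = optimizer steps * batch size
# (assumes gradient_accumulation_steps = 1, as in the posted config).
BATCH_SIZE = 32

def examples_seen(steps: int, batch_size: int = BATCH_SIZE) -> int:
    """Number of training examples consumed after `steps` optimizer steps."""
    return steps * batch_size

def steps_for(examples: int, batch_size: int = BATCH_SIZE) -> int:
    """Number of optimizer steps needed to consume `examples` examples."""
    return examples // batch_size

sft_examples = examples_seen(200_000)   # 6,400,000
dpo_examples = examples_seen(300_000)   # 9,600,000
kto_steps = steps_for(17_600_000)       # 550,000 steps at batch size 32

print(sft_examples, dpo_examples, kto_steps)
```

If the axis instead counts examples rather than steps, these totals would be 32 times smaller, so the maintainers' answer on what the axis means decides which figure applies.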

Thank you for any guidance or insights you can provide.

Comments in KTO Trainer `forward()`

Hi there,

I'm reading through the forward() function in the KTO trainer, and the comments in the function signature state that, if the data was read in correctly, the chosen and rejected logps should each have size batch_size/2. This doesn't make sense to me, because it sounds like a constraint of paired preference training rather than of KTO's unpaired training method.

Here's the comment from lines 875-877 of trainers.py:

chosen_logps: log probabilities of chosen examples (should be batch size / 2 if data was read in correctly)
rejected_logps: log probabilities of rejected examples (should be batch size / 2 if data was read in correctly)
KL_logps: log probabilities of the unmatched y'|x (used to estimate the KL divergence between policy and reference; should be batch size)

Please let me know if this makes sense; I'm happy to open a PR.
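To make the concern concrete, here is a minimal sketch of an unpaired batch. The tuple layout and the `split_by_status` helper are hypothetical and are not taken from trainers.py; the point is only that an unpaired batch need not split evenly.

```python
# Hypothetical unpaired batch: each example is (prompt, response, status).
batch = [
    ("p1", "r1", "chosen"),
    ("p2", "r2", "chosen"),
    ("p3", "r3", "chosen"),
    ("p4", "r4", "rejected"),
]

def split_by_status(examples):
    """Partition an unpaired batch into chosen and rejected examples."""
    chosen = [ex for ex in examples if ex[2] == "chosen"]
    rejected = [ex for ex in examples if ex[2] == "rejected"]
    return chosen, rejected

chosen, rejected = split_by_status(batch)
print(len(chosen), len(rejected))  # 3 1
```

With a batch of 4, the split here is 3 and 1 rather than 2 and 2, which is why a "batch size / 2" expectation reads like a paired-data assumption.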

Error fix for evaluation script `eval.py`

Thank you for maintaining such an important repository. While using it, I ran into some issues with eval.py and wanted to check whether the following changes are indeed valid.

  1. In line 86 of eval.py, we need to modify reference_model.resize_token_embeddings(len(tokenizer)) into the following:
if config.loss.name == 'ppo':
    reference_model.resize_token_embeddings(len(tokenizer))

This is because the token embeddings of the reference model are resized only if the loss is ppo (see line 170 of train.py).

  2. When we use the eval mode of eval.py, we need to change the trainer type from BasicTrainer to DPOTrainer; otherwise, calling the get_batch_metrics method raises NotImplementedError.

Once again, thank you for maintaining an amazing repository which is very easy to use. Thank you so much!

Is there a problem with training?

Initially, I used KTO for training, and the loss did not converge at all, as shown in the following training result graph.
[Image: KTO training result graph]

Later, based entirely on llama7b and hh data, I used the script you provided exactly: python train.py loss=sft model=llama7b datasets=[hh] exp_name=llama7b_sft mode=train ++cache_dir=/data/models, and the training result graph is as follows:
[Images: SFT training result graphs]

The only difference in the SFT training from yours is that I set use_flash_attention to false.

Llama 3 compatibility

I tried two Llama 3 8B models from Hugging Face by creating a new config/model/.yaml with

name_or_path: meta-llama/Meta-Llama-3-8B
and also
name_or_path: NousResearch/Meta-Llama-3-8B

I am able to train the SFT model using the command
python train.py loss=sft model=llama_3_8b datasets=[shp,hh] exp_name=l38b_sft_0505 mode=train ++cache_dir=./data/models

However, I get the following error when I train the PPO model using the command
python train.py loss=ppo model=llama_3_8b datasets=[shp,hh] exp_name=l38b_ppo_0517 mode=train ++cache_dir=./data/models ++model.load_from=l38b_sft_0505/LATEST/policy.pt

Stacktrace:

Making experiment directory ./data/models/l38b_ppo_0517
no FSDP port specified; using open port for FSDP: 46451
seed: 1
exp_name: l38b_ppo_0517
datasets:
- shp
- hh
mode: train
debug: false
use_fsdp: true
fsdp_port: 46451
wandb:
  enabled: true
  entity: null
  project: l38b_ppo_0517
cache_dir: ./data/models
local_run_dir: ./data/models/l38b_ppo_0517
do_first_eval: true
minimum_log_interval_secs: 1.0
intermediate_checkpoints: false
trainer: BasicTrainer
lr: 5.0e-07
n_epochs: 1
n_examples: null
optimizer: RMSprop
warmup_steps: 150
eval_every: 20000
n_samples: 128
samples_dir: samples/
n_eval_examples: 512
saved_policy: ./data/models/l38b_ppo_0517/LATEST/policy.pt
top_p: 0.95
human_prefix: '

  <|user|>

  '
assistant_prefix: '

  <|assistant|>

  '
human_suffix: ''
assistant_suffix: ''
frac_unique_desirable: 1.0
frac_unique_undesirable: 1.0
model:
  name_or_path: NousResearch/Meta-Llama-3-8B
  tokenizer_name_or_path: null
  load_from: l38b_sft_0505/LATEST/policy.pt
  block_name: LlamaDecoderLayer
  policy_dtype: bfloat16
  fsdp_policy_mp: null
  reference_dtype: bfloat16
  max_grad_norm: 10.0
  v_head_max_grad_norm: 0.1
  max_length: 2048
  max_prompt_length: 1024
  activation_checkpointing: true
  batch_size: 32
  gradient_accumulation_steps: 1
  eval_batch_size: 16
  use_flash_attention: true
loss:
  name: ppo
  ppo_epochs: 1
  cliprange: 0.5
  trainer: PPOTrainer
  dataloader: UnpairedPreferenceDataLoader
  lam: 0.95
  gamma: 0.99
  critic_coef: 0.01
  KL_coef: 0.1
  use_reference_model: true

================================================================================
Writing to aid-nrt-slurm-bm-gpu-b4-8-ad1-005:./data/models/l38b_ppo_0517
================================================================================
building policy
You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 10.43it/s]
Error executing job with overrides: ['loss=ppo', 'model=llama_3_8b', 'datasets=[shp,hh]', 'exp_name=l38b_ppo_0517', 'mode=train', '++cache_dir=./data/models', '++model.load_from=l38b_sft_0505/LATEST/policy.pt', '++wandb.project=l38b_ppo_0517']
Traceback (most recent call last):
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py", line 304, in hf_raise_for_status
    response.raise_for_status()
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/requests/models.py", line 1021, in raise_for_status
    raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 404 Client Error: Not Found for url: https://huggingface.co/NousResearch/Meta-Llama-3-8B/resolve/main/pytorch_model.bin

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/rosridha/POCs/rlhf/HALOs/models.py", line 95, in from_pretrained
    filename = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin")
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 119, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1261, in hf_hub_download
    metadata = get_hf_file_metadata(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 119, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1674, in get_hf_file_metadata
    r = _request_wrapper(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 369, in _request_wrapper
    response = _request_wrapper(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 393, in _request_wrapper
    hf_raise_for_status(response)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py", line 315, in hf_raise_for_status
    raise EntryNotFoundError(message, response) from e
huggingface_hub.utils._errors.EntryNotFoundError: 404 Client Error. (Request ID: Root=1-66483dd3-1fff68591f1d7b9b7a1e9d08;04442bd4-7883-4026-8e20-334b328b960d)

Entry Not Found for url: https://huggingface.co/NousResearch/Meta-Llama-3-8B/resolve/main/pytorch_model.bin.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py", line 304, in hf_raise_for_status
    response.raise_for_status()
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/requests/models.py", line 1021, in raise_for_status
    raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 404 Client Error: Not Found for url: https://huggingface.co/NousResearch/Meta-Llama-3-8B/resolve/main/pytorch_model.bin.index.json

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/rosridha/POCs/rlhf/HALOs/train.py", line 231, in <module>
    main()
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "/home/rosridha/POCs/rlhf/HALOs/train.py", line 132, in main
    policy = model_class.from_pretrained(
  File "/home/rosridha/POCs/rlhf/HALOs/models.py", line 101, in from_pretrained
    index_file_name = hf_hub_download(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 119, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1261, in hf_hub_download
    metadata = get_hf_file_metadata(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 119, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1674, in get_hf_file_metadata
    r = _request_wrapper(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 369, in _request_wrapper
    response = _request_wrapper(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 393, in _request_wrapper
    hf_raise_for_status(response)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py", line 315, in hf_raise_for_status
    raise EntryNotFoundError(message, response) from e
huggingface_hub.utils._errors.EntryNotFoundError: 404 Client Error. (Request ID: Root=1-66483dd3-495ad9e044bc65914d189248;d2584051-b677-439d-9982-cc1f1fd96021)

Entry Not Found for url: https://huggingface.co/NousResearch/Meta-Llama-3-8B/resolve/main/pytorch_model.bin.index.json.
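The traceback shows models.py requesting pytorch_model.bin and then pytorch_model.bin.index.json, both of which return 404 for this repository. A likely explanation, stated here as an assumption, is that the repository ships its weights only in safetensors format. The sketch below shows one way a loader could choose among candidate filenames; `CANDIDATES` and `pick_weight_file` are hypothetical names, not part of HALOs. The list of files could be obtained with `huggingface_hub.list_repo_files(repo_id)`.

```python
# Candidate checkpoint filenames, in the order they would be tried.
# The first two are the names requested in the traceback; the safetensors
# names are the standard ones used by newer Hub repositories.
CANDIDATES = [
    "pytorch_model.bin",
    "pytorch_model.bin.index.json",
    "model.safetensors",
    "model.safetensors.index.json",
]

def pick_weight_file(repo_files):
    """Return the first candidate filename present in repo_files, or None."""
    available = set(repo_files)
    for name in CANDIDATES:
        if name in available:
            return name
    return None

# Example with a made-up file listing for a sharded safetensors repository.
listing = [
    "config.json",
    "model-00001-of-00004.safetensors",
    "model.safetensors.index.json",
]
print(pick_weight_file(listing))  # model.safetensors.index.json
```

A loader built this way would still need a safetensors-aware path for reading the shards, so this only addresses the file lookup step that fails in the log.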

Hidden State mapping to two value nodes instead of 1

Hi,

I'm confused about why you've defined the value head as you did in models.py. As written, the value head outputs two numbers instead of one, since it maps the (2,4096) final hidden state to a (2,1) tensor for the final value. It looks like half of the hidden states are being ignored. I would expect it to map a flattened version of the final hidden state to a single node.

As a sanity check I looked for where this was used and in line 1114 of trainers.py, I noticed that you're only taking in the first value in this (2,1) vector.

Can you tell me why you've made this design choice? I feel like I'm misinterpreting something here.
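One possible reading, offered as an assumption and not confirmed against models.py: if the leading dimension of 2 is a batch or sequence axis, then a linear layer with one output feature maps each row to its own scalar, and a (2,1) result is expected. The pure-Python sketch below mimics that behavior with a hidden size of 4 standing in for 4096; `value_head` is a hypothetical stand-in, not the repository's module.

```python
# Pure-Python stand-in for a linear layer with in_features=4, out_features=1
# (4 stands in for 4096 to keep the example small).
def value_head(hidden_states, weight, bias=0.0):
    """Map each hidden-state row to one scalar: value = w . h + b."""
    return [[sum(w * h for w, h in zip(weight, row)) + bias] for row in hidden_states]

hidden = [
    [1.0, 0.0, 2.0, 0.0],   # row 0: one hidden state
    [0.0, 1.0, 0.0, 3.0],   # row 1: another hidden state
]
weight = [0.5, 0.5, 0.5, 0.5]

values = value_head(hidden, weight)
print(values)  # [[1.5], [2.0]]
```

Under this reading no hidden state is dropped by the head itself; which of the per-row values the trainer then uses is a separate question for line 1114 of trainers.py.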

Compatibility with quantized embeddings

Hi,

Firstly, thanks for the awesome work!

I want to use KTO with a quantized Mistral model but am getting pickle errors from the multiprocessing thread, probably since that changes the Embedding layers to be nn.Linear4bit instead of just nn.Linear.

File "~/HALOs/train.py", line 250, in main
    mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, tokenizer, train_iterator, eval_iterator, policy, reference_model), join=True)
  File "~/conda/env/halos/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 241, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "~/conda/env/halos/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    process.start()
  File "~/conda/env/halos/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <function Embedding.forward at 0x7f75fa914160>: it's not the same object as torch.nn.modules.sparse.Embedding.forward

I'm thinking a workaround would be to use multiprocess instead of multiprocessing to use dill instead of pickle, but haven't been successful with that yet... do you have any suggestions?

Gradient Clipping for FSDP

Hi! Thank you for maintaining such a valuable repository.
I would like to suggest a minor fix regarding gradient clipping. For FSDP, we should not use torch.nn.utils.clip_grad_norm_ (relevant issue), but instead directly call the clip_grad_norm_ method of the FSDP module. Thus, I would like to suggest modifying the following:

HALOs/trainers.py

Lines 453 to 455 in f9f7826

def clip_gradient(self):
    """Clip the gradient norm of the parameters of a non-FSDP policy."""
    return torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.model.max_grad_norm).item()

into the following:

def clip_gradient(self):
    """Clip the gradient norm of the parameters."""
    if self.fsdp:
        return self.policy.clip_grad_norm_(self.config.model.max_grad_norm).item()
    return torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.model.max_grad_norm).item()

Thank you so much!

How to sample from HF models?

Hi, this is a great project!

I'd like to start by reproducing the scores in your paper, specifically for the model ContextualAI/archangel_sft_llama7b. How do I sample from this HF model? Right now eval.py takes a path under /data/..., which exists only if I trained the model myself.
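As a rough sketch of sampling directly from the Hub with the transformers library, independent of eval.py: the code below assumes the model loads with `AutoModelForCausalLM`, that prompts use the `<|user|>` and `<|assistant|>` prefixes that appear in the training config, and that top_p=0.95 matches the config default. The helper names `format_prompt` and `sample` are hypothetical.

```python
def format_prompt(user_message: str) -> str:
    """Wrap a single user turn in the <|user|>/<|assistant|> prefixes (assumed format)."""
    return f"\n<|user|>\n{user_message}\n<|assistant|>\n"

def sample(model_name: str, user_message: str, max_new_tokens: int = 128) -> str:
    """Draw one nucleus sample from a Hub model for a single-turn prompt."""
    # Imported lazily so format_prompt can be used without these packages installed.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    inputs = tokenizer(format_prompt(user_message), return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(
            **inputs, do_sample=True, top_p=0.95, max_new_tokens=max_new_tokens
        )
    # Drop the prompt tokens and decode only the generated continuation.
    new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

# Usage (downloads the weights, so it is left commented out):
# print(sample("ContextualAI/archangel_sft_llama7b", "How do I bake bread?"))
```

Reproducing the paper's scores would additionally require the same prompts, sampling settings, and evaluation procedure the authors used, which this sketch does not cover.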

In dataloader

pairs: List[Tuple[int, int]] = field(default_factory=list) # indices in responses, where i > j in pair (i,j) is a preference

In this line, does "i > j" mean "score(i) > score(j)"? I'm a bit confused, thank you for your clarification!
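One reading of the comment, under the assumption that ">" denotes preference rather than a comparison of indices: a pair (i, j) says that responses[i] is preferred over responses[j]. The sketch below uses a simplified stand-in for the dataloader's example class, keeping only two fields; `chosen_rejected` is a hypothetical helper.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class Example:
    # Simplified stand-in for the dataloader's example class (two fields only).
    responses: List[str] = field(default_factory=list)
    pairs: List[Tuple[int, int]] = field(default_factory=list)

def chosen_rejected(ex: Example) -> List[Tuple[str, str]]:
    """Read each (i, j) as: responses[i] is preferred over responses[j]."""
    return [(ex.responses[i], ex.responses[j]) for i, j in ex.pairs]

ex = Example(responses=["Python", "C++", "Java"], pairs=[(0, 1), (2, 1)])
print(chosen_rejected(ex))  # [('Python', 'C++'), ('Java', 'C++')]
```

In the pair (0, 1) the first index is numerically smaller than the second, so under this reading the ordering inside the tuple carries the preference, not the magnitude of the indices.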

Can you provide a clear description of the dataset structure we can use for our custom dataset?

Something like this?

From the Hugging Face DPOTrainer docs:

dpo_dataset_dict = {
    "prompt": [
        "hello",
        "how are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ],
    "chosen": [
        "hi nice to meet you",
        "I am fine",
        "My name is Mary",
        "My name is Mary",
        "Python",
        "Python",
        "Java",
    ],
    "rejected": [
        "leave me alone",
        "I am not fine",
        "Whats it to you?",
        "I dont have a name",
        "Javascript",
        "C++",
        "C++",
    ],
}
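If the repository's loader expects one record per prompt with a list of responses and a list of preference pairs (as the `pairs` field in its dataloader suggests), a dict in the format above could be regrouped along these lines. The target layout and the `group_by_prompt` helper are hypothetical illustrations, not the format HALOs documents.

```python
def group_by_prompt(dpo_dict):
    """Collapse a prompt/chosen/rejected dict into one record per prompt.

    Each record holds the distinct responses seen for that prompt and a list of
    (preferred index, dispreferred index) pairs into that response list.
    """
    grouped = {}
    rows = zip(dpo_dict["prompt"], dpo_dict["chosen"], dpo_dict["rejected"])
    for prompt, chosen, rejected in rows:
        entry = grouped.setdefault(prompt, {"responses": [], "pairs": []})
        for text in (chosen, rejected):
            if text not in entry["responses"]:
                entry["responses"].append(text)
        i = entry["responses"].index(chosen)
        j = entry["responses"].index(rejected)
        entry["pairs"].append((i, j))
    return grouped

small = {
    "prompt": ["Which is the best programming language?"] * 3,
    "chosen": ["Python", "Python", "Java"],
    "rejected": ["Javascript", "C++", "C++"],
}
result = group_by_prompt(small)
record = result["Which is the best programming language?"]
print(record["responses"])  # ['Python', 'Javascript', 'C++', 'Java']
print(record["pairs"])      # [(0, 1), (0, 2), (3, 2)]
```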

Question about the KL term in the loss function

For y ~ p_chosen, the KL term in the loss is derived as KL(p_policy(y_rejected|x') || p_reference(y_rejected|x')). However, according to the technical report, the KL divergence should be calculated over the entire dataset. Is there a reason that x' above includes only rejected inputs and not the inputs x from the entire batch (i.e., including y_chosen too)? My guess is that this is for numerical stability, to make sure the two terms of the loss aren't correlated, but I want to make sure I am not missing something here.
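For reference while discussing this, a batch-level KL estimate of the kind described by the KL_logps comment in trainers.py (log probabilities of unmatched y'|x, one per batch element) could look like the sketch below. It assumes the estimate is the mean log-ratio over the mismatched pairs, clamped at zero; this is an illustration, not the repository's code, and `kl_estimate` is a hypothetical name.

```python
def kl_estimate(policy_kl_logps, reference_kl_logps):
    """Mean policy/reference log-ratio over mismatched (x, y') pairs, clamped at zero."""
    ratios = [p - r for p, r in zip(policy_kl_logps, reference_kl_logps)]
    return max(0.0, sum(ratios) / len(ratios))

# Made-up sequence log probabilities for three mismatched pairs.
policy_logps = [-10.0, -12.0, -8.0]
reference_logps = [-11.0, -12.5, -9.5]

print(kl_estimate(policy_logps, reference_logps))   # 1.0
print(kl_estimate(reference_logps, policy_logps))   # 0.0 (negative mean is clamped)
```

Under this assumption the estimate is a single scalar shared by every example in the batch, so which examples contribute to it is exactly the choice the question asks about.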

Losses list appears to be empty for loss=DPO

Hi,

I got this to run using FSDP. When I print the metrics, they are all strangely empty. I followed the PairedPreferenceTrainer class, which calls self.loss; the loss itself is defined in DPOTrainer. I added a print right after line 699, and oddly the result is an empty tensor. Any ideas?

Thanks!


ERROR: None of the inputs have requires_grad=True. Gradients will be None

Computing eval metrics: 0%| | 0/86 [00:00<?, ?it/s]
/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")

a few queries

Hello Kawin and authors,

Thanks for sharing the excellent work. I enjoyed reading the technical report; it was crisp and concise, and I appreciate the effort put in to achieve this clarity.

I have a couple of questions:

  1. In the KTO loss function, the second term of expected KL divergence can be interpreted as the average reward score obtained over 'rejected' responses over all input prompts (in desired response case). In this sense, it is similar to DPO where the second term is 'reward' for the rejected response of that particular x but here we are replacing it with the average reward for 'rejected' response from all the xs.
    Is this understanding correct?
  2. I saw in the code that there is a KTOZero implementation where the expected KL divergence is replaced with 0. I am curious what you found with it.
  3. On a lighter note, I would like to know what is the font used in the technical report.

thanks,
Onkar
