d3po's Introduction

  • 👋 Hi, I’m Kai Yang (杨恺)
  • 👀 I’m interested in AI
  • 🌱 I’m currently learning Reinforcement Learning
  • 💞️ I’m a master's student currently studying at THU SIGS
  • 📫 You can contact me through [email protected]

d3po's People

Contributors

yk7333

d3po's Issues

LICENSE

Thanks a lot for sharing your code! Could you also add a LICENSE file to the repo so that the usage conditions and restrictions are clear?

assert num_timesteps == config.sample.num_steps AssertionError

Thank you for your excellent work!
I tried to apply RLHF to anything-v4.5 to reduce human-body defects, using the dataset and prompts you published. However, training failed with the error in the title.
The log from the training run on RunPod is below.
Additionally, I have uploaded base.py and train.py to https://huggingface.co/datasets/sdtana/anything-v4.5_dpo.
I am not a specialist, so I apologize for asking before testing more extensively.
Could you advise me on where I went wrong?

I1202 14:25:55.659013 140285781463488 logging.py:60]
allow_tf32: true
logdir: logs
mixed_precision: fp16
num_checkpoint_limit: 10
num_epochs: 5
pretrained:
  model: /workspace/anything-v4.5
  revision: main
prompt_fn: anything_prompt
prompt_fn_kwargs: {}
resume_from: ''
reward_fn: jpeg_compressibility
run_name: anythingdpo_2023.12.02_14.25.53
sample:
  batch_size: 2
  eta: 1.0
  guidance_scale: 5.0
  num_batches_per_epoch: 100
  num_steps: 20
  save_interval: 100
save_freq: 1
seed: 42
train:
  adam_beta1: 0.9
  adam_beta2: 0.999
  adam_epsilon: 1.0e-08
  adam_weight_decay: 0.0001
  adv_clip_max: 5
  batch_size: 1
  beta: 0.1
  cfg: true
  gradient_accumulation_steps: 1
  json_path: /workspace/data/epoch1/json
  learning_rate: 1.0e-05
  max_grad_norm: 1.0
  num_inner_epochs: 1
  sample_path: /workspace/data/epoch1
  save_interval: 50
  timestep_fraction: 1.0
  use_8bit_adam: false
use_lora: true
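
For reference, the derived counts that appear in the log below (200 samples per epoch, total train batch size 1, 200 gradient updates per inner epoch) can be reproduced from the config values above. This is only a sketch: the function name `derived_counts` is hypothetical, the formulas are the ones commonly used in DDPO-style trainers rather than code taken from this repository, and `num_processes=1` is an assumption inferred from the logged total train batch size.

```python
def derived_counts(sample_batch_size, num_batches_per_epoch,
                   train_batch_size, grad_accum_steps, num_processes=1):
    """Recompute the per-epoch counts reported in the training log.

    Hypothetical helper; num_processes=1 is an assumption (single GPU).
    """
    # Samples generated per epoch across all processes.
    samples_per_epoch = sample_batch_size * num_processes * num_batches_per_epoch
    # Effective batch size of one optimizer step.
    total_train_batch_size = train_batch_size * num_processes * grad_accum_steps
    # Optimizer steps needed to pass over the epoch's samples once.
    updates_per_inner_epoch = samples_per_epoch // total_train_batch_size
    return samples_per_epoch, total_train_batch_size, updates_per_inner_epoch


# Values from the config above: sample.batch_size=2,
# sample.num_batches_per_epoch=100, train.batch_size=1,
# train.gradient_accumulation_steps=1.
print(derived_counts(2, 100, 1, 1))
```

With these inputs the three numbers agree with the "Total number of samples per epoch", "Total train batch size" and "Number of gradient updates per inner epoch" lines in the log.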

text_config_dict is provided which will be used to initialize CLIPTextConfig. The value text_config["id2label"] will be overriden.
I1202 14:26:06.847579 140285781463488 logging.py:60] ***** Running training *****
I1202 14:26:06.847816 140285781463488 logging.py:60] Num Epochs = 5
I1202 14:26:06.847866 140285781463488 logging.py:60] Sample batch size per device = 2
I1202 14:26:06.847907 140285781463488 logging.py:60] Train batch size per device = 1
I1202 14:26:06.847945 140285781463488 logging.py:60] Gradient Accumulation steps = 1
I1202 14:26:06.847987 140285781463488 logging.py:60]
I1202 14:26:06.848022 140285781463488 logging.py:60] Total number of samples per epoch = 200
I1202 14:26:06.848079 140285781463488 logging.py:60] Total train batch size (w. parallel, distributed & accumulation) = 1
I1202 14:26:06.848118 140285781463488 logging.py:60] Number of gradient updates per inner epoch = 200
I1202 14:26:06.848156 140285781463488 logging.py:60] Number of inner epochs = 1
I1202 15:19:19.564940 140285781463488 logging.py:60] Saving current state to logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_0
I1202 15:19:19.665072 140285781463488 logging.py:60] Optimizer state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_0/optimizer.bin | 48/200 [53:07<2:30:54, 59.57s/it]
I1202 15:19:19.670379 140285781463488 logging.py:60] Gradient scaler state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_0/scaler.pt
I1202 15:19:19.682694 140285781463488 logging.py:60] Random states saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_0/random_states_0.pkl
I1202 16:18:57.735666 140285781463488 logging.py:60] Saving current state to logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_1
I1202 16:18:57.821295 140285781463488 logging.py:60] Optimizer state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_1/optimizer.bin | 99/200 [1:52:45<2:39:23, 94.68s/it]
I1202 16:18:57.827801 140285781463488 logging.py:60] Gradient scaler state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_1/scaler.pt
I1202 16:18:57.833970 140285781463488 logging.py:60] Random states saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_1/random_states_0.pkl
I1202 17:20:06.462761 140285781463488 logging.py:60] Saving current state to logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_2
I1202 17:20:06.565670 140285781463488 logging.py:60] Optimizer state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_2/optimizer.bin | 148/200 [2:53:54<1:34:02, 108.51s/it]
I1202 17:20:06.572802 140285781463488 logging.py:60] Gradient scaler state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_2/scaler.pt
I1202 17:20:06.578633 140285781463488 logging.py:60] Random states saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_2/random_states_0.pkl
I1202 18:30:18.542086 140285781463488 logging.py:60] Saving current state to logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_3
I1202 18:30:18.639742 140285781463488 logging.py:60] Optimizer state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_3/optimizer.bin | 198/200 [4:04:06<01:59, 59.57s/it]
I1202 18:30:18.645961 140285781463488 logging.py:60] Gradient scaler state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_3/scaler.pt
I1202 18:30:18.650440 140285781463488 logging.py:60] Random states saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_3/random_states_0.pkl
Traceback (most recent call last):
  File "/workspace/d3po/scripts/train.py", line 424, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/workspace/d3po/scripts/train.py", line 266, in main
    assert num_timesteps == config.sample.num_steps
AssertionError
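
One way to narrow this down is to compare the number of timesteps stored with the saved samples against `config.sample.num_steps` (20 in the run above) before training starts. The sketch below is an assumption-laden illustration, not the repository's code: it assumes the samples carry a two-dimensional timesteps array of shape (batch, num_timesteps), as in DDPO-style trainers, and the helper name `describe_timestep_mismatch` is hypothetical.

```python
def describe_timestep_mismatch(timesteps_shape, expected_num_steps):
    """Return None if the shapes agree, else a readable diagnostic string.

    timesteps_shape: shape tuple of the stored timesteps, assumed to be
    (batch, num_timesteps). This layout is an assumption.
    expected_num_steps: the value of config.sample.num_steps.
    """
    if len(timesteps_shape) != 2:
        return f"expected a 2-D (batch, steps) shape, got {timesteps_shape}"
    _, num_timesteps = timesteps_shape
    if num_timesteps != expected_num_steps:
        return (
            f"samples contain {num_timesteps} timesteps but "
            f"config.sample.num_steps is {expected_num_steps}"
        )
    return None


# Example: samples generated with 50 denoising steps, trained with num_steps=20.
print(describe_timestep_mismatch((200, 50), 20))
```

If such a check reports a mismatch, one possible explanation is that the data under `sample_path` was generated with a different number of sampling steps than the training config specifies; the log alone does not confirm this.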
