
easydel's Introduction

EasyDeL 🔮

Key Features | Latest Updates | Vision | Quick Start | Reference docs | License

EasyDeL is an open-source framework designed to enhance and streamline the training process of machine learning models, with a primary focus on Jax/Flax. It provides convenient and effective solutions for training and serving Flax/Jax models on TPU/GPU at scale.

Key Features

  • Diverse Architecture Support: Seamlessly work with various model architectures including Transformers, Mamba, RWKV, and more.
  • Diverse Model Support: Implements a wide range of models in JAX, including Falcon, Qwen2, Phi2, Mixtral, Qwen2Moe, Cohere, Dbrx, Phi3, and MPT.
  • Advanced Trainers: Offers specialized trainers like DPOTrainer, ORPOTrainer, SFTTrainer, and VideoCLM Trainer.
  • Serving and API Engines: Provides engines for efficiently serving large language models (LLMs) in JAX.
  • Quantization and Bit Operations: Supports various quantization methods and 8, 6, and 4-bit operations for optimized inference and training.
  • Performance Optimization: Integrates FlashAttention, RingAttention, and other performance-enhancing features.
  • Model Conversion: Supports automatic conversion between JAX-EasyDeL and PyTorch-HF models (see the loading sketch below).
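
For example, loading a PyTorch Hugging Face checkpoint directly into a JAX-EasyDeL model is a single call. A minimal sketch using the AutoEasyDeLModelForCausalLM API shown throughout this README (the repo id is illustrative):

import easydel as ed

# Convert a PyTorch-HF checkpoint into a JAX-EasyDeL module plus a
# parameter pytree; any causal-LM repo id works here.
model, params = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1"
)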

Fully Customizable and Hackable 🛠️

EasyDeL stands out by providing unparalleled flexibility and transparency:

  • Open Architecture: Every single component of EasyDeL is open for inspection, modification, and customization. There are no black boxes here.

  • Hackability at Its Core: We believe in giving you full control. Whether you want to tweak a small function or completely overhaul a training loop, EasyDeL lets you do it.

  • Custom Code Access: All custom implementations are readily available and well-documented, allowing you to understand, learn from, and modify the internals as needed.

  • Encourage Experimentation: We actively encourage users to experiment, extend, and improve upon the existing codebase. Your innovations could become the next big feature!

  • Community-Driven Development: Share your custom implementations and improvements with the community, fostering a collaborative environment for advancing ML research and development.

With EasyDeL, you're not constrained by rigid frameworks. Instead, you have a flexible, powerful toolkit that adapts to your needs, no matter how unique or specialized they may be. Whether you're conducting cutting-edge research or building production-ready ML systems, EasyDeL provides the freedom to innovate without limitations.

Advanced Customization and Optimization 🔧

EasyDeL provides unparalleled flexibility in customizing and optimizing your models:

  • Sharding Strategies: Easily customize and experiment with different sharding strategies to optimize performance across multiple devices.

  • Algorithm Customization: Modify and fine-tune algorithms to suit your specific needs and hardware configurations.

  • Attention Mechanisms: Choose from over 10 types of attention mechanisms optimized for GPU/TPU/CPU, including:

    • Flash Attention
    • Blockwise Attention
    • Ring Attention
    • Splash Attention
    • And many more!

This level of customization allows you to squeeze every ounce of performance from your hardware while tailoring the model behavior to your exact requirements.
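
As a rough sketch of how backend selection looks in practice: recent versions let you pick the attention implementation when loading a model. The attn_mechanism keyword and the value strings below are assumptions that may differ between releases, so treat this as illustrative rather than the definitive API:

import easydel as ed

# Assumed keyword: the attention backend is exposed as a config option;
# check your installed release for the exact name and accepted values
# (FlexibleAttentionModule.test_attentions() exercises the backends).
model, params = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    "model_name",
    attn_mechanism="flash",  # e.g. "flash", "ring", "splash", "blockwise"
)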

Future Updates and Vision 🚀

EasyDeL is constantly evolving to meet the needs of the machine learning community. Our ongoing commitment and upcoming plans include:

  • Continuous Improvement: EasyDeL is committed to long-term maintenance, with frequent (often daily) updates that introduce new features, optimizations, and bug fixes, keeping researchers and developers at the cutting edge of machine learning tooling.
  • Ready-to-Use Blocks: Pre-configured, optimized building blocks for quick model assembly and experimentation.
  • Enhanced Scalability: Improved tools and methods for effortlessly scaling LLMs to handle larger datasets and more complex tasks.
  • Advanced Customization Options: More flexibility in model architecture and training pipeline customization.

Why Choose EasyDeL?

  1. Flexibility: EasyDeL offers a modular design that allows researchers and developers to easily mix and match components, experiment with different architectures (including Transformers, Mamba, RWKV, and more), and adapt models to specific use cases.

  2. Performance: Leveraging the power of JAX and Flax, EasyDeL provides high-performance implementations of state-of-the-art models and training techniques, optimized for both TPUs and GPUs.

  3. Scalability: From small experiments to large-scale model training, EasyDeL provides tools and optimizations to efficiently scale your models and workflows.

  4. Ease of Use: Despite its powerful features, EasyDeL maintains an intuitive API, making it accessible for both beginners and experienced practitioners.

  5. Cutting-Edge Research: EasyDeL quickly implements the latest advancements in model architectures, training techniques, and optimization methods.

Quick Start

Installation

pip install easydel

Testing Attention Mechanisms

import easydel as ed
ed.FlexibleAttentionModule.test_attentions()

Documentation 💫

Comprehensive documentation and examples are available at EasyDeL Documentation.

Latest Updates 🔥

  • Added support for EXAONE models.
  • Architecture improvements across all models.
  • Optimized KeyValueCache:
    • Improved performance for inference
    • Added support for 8bit_cache
  • GenerationPipeline Enhancements:
    • Now supports int8 and nf4 for generation tasks.
  • Enhanced Trainers: Both DPO and ORPO trainers have been upgraded.
  • Simplified Parameter Sharding: You can now shard parameters directly with the model using:
    params = model.shard_params(params)
    params = model.gather_params(params)
  • Training Argument Change: do_shard_params has been removed from TrainArguments; parameters must now be sharded manually before training (see the sketch after this list).
  • DPOTrainer Improvement: Added support for int8 training for reference models.
  • Added ApiEngine and engine_client
  • Improved SFT, DPO, ORPO, CLM Trainers
  • Added support for Gemma2, OLMo models
  • Fixed GPU Flash Attention bugs
  • Improved KV cache quantization accuracy
  • Enhanced memory efficiency for multi-GPU setups
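
Since do_shard_params is gone from TrainArguments, sharding now happens explicitly before a trainer sees the parameters. A minimal sketch of the flow, assuming model and params have been loaded as in the examples below:

# Shard the parameter pytree across devices before handing it to a trainer,
# then gather it back onto the host, e.g. before saving or converting.
params = model.shard_params(params)
# ... run training on the sharded tree ...
params = model.gather_params(params)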

Key Components

GenerationPipeline

The GenerationPipeline class provides a streamlined interface for text generation using pre-trained language models within the JAX framework.

import easydel as ed
from transformers import AutoTokenizer

model, params = ed.AutoEasyDeLModelForCausalLM.from_pretrained(...)
tokenizer = AutoTokenizer.from_pretrained(...)

pipeline = ed.GenerationPipeline(model=model, params=params, tokenizer=tokenizer)
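
A streaming generation call then looks roughly like the sketch below. The cumulative-string pattern mirrors the server.process example used elsewhere in this project, but the exact generate signature and yield type are assumptions that may differ between versions:

ids = tokenizer("Write a short poem about JAX.", return_tensors="np")

# Assumed streaming API: the pipeline is taken here to yield the decoded
# text generated so far, so only the new suffix is printed each step.
printed = 0
for text in pipeline.generate(ids.input_ids, ids.attention_mask):
    print(text[printed:], end="")
    printed = len(text)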

ApiEngine

ApiEngine is a serving engine intended for production use, providing a stable and efficient API.

import easydel as ed

pipeline = ed.ChatPipeline(...)
engine = ed.ApiEngine(pipeline=pipeline, hostname="0.0.0.0", port=11550)
engine.fire()

EasyDeLState

EasyDeLState acts as a comprehensive container for your EasyDeL model, including training progress, model parameters, and optimizer information.

from jax import numpy as jnp
from easydel import EasyDeLState

state = EasyDeLState.from_pretrained(
    pretrained_model_name_or_path="model_name",
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    sharding_axis_dims=(1, -1, 1, 1)
)

Training Examples

Supervised Fine-Tuning

import flax
from easydel import SFTTrainer, TrainArguments

# train_arguments, train_dataset, eval_dataset, tokenizer, prompter,
# max_length and params are assumed to be defined beforehand.

trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    formatting_func=prompter,
    packing=True,
    num_of_sequences=max_length,
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
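
The train_arguments object above is a TrainArguments instance. A minimal, illustrative sketch showing only a few of the many available fields (exact names can vary between versions; model and max_length come from earlier setup):

from jax import numpy as jnp
from easydel import TrainArguments

# Illustrative values only; not a recommended configuration.
train_arguments = TrainArguments(
    model_class=type(model),
    model_name="sft_example",
    num_train_epochs=3,
    learning_rate=5e-5,
    total_batch_size=8,
    gradient_accumulation_steps=1,
    max_sequence_length=max_length,
    sharding_array=(1, -1, 1, 1),
    dtype=jnp.bfloat16,
)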

DPO Fine-tuning

from easydel import DPOTrainer

# state, ref_state, train_dataset, eval_dataset, tokenizer, arguments and
# the length limits are assumed to be defined beforehand.

dpo_trainer = DPOTrainer(
    model_state=state,
    ref_model_state=ref_state,
    beta=0.1,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    arguments=arguments,
    max_length=max_length,
    max_target_length=max_target_length,
    max_prompt_length=max_prompt_length,
)

output = dpo_trainer.train()
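
Here model_state and ref_model_state are EasyDeLState containers, with the reference model typically a frozen copy of the starting policy. A sketch of loading both (the repo id is illustrative):

from jax import numpy as jnp
from easydel import EasyDeLState

# DPO compares the trained policy against a frozen reference model, so two
# states are loaded from the same starting checkpoint.
state = EasyDeLState.from_pretrained("model_name", dtype=jnp.bfloat16)
ref_state = EasyDeLState.from_pretrained("model_name", dtype=jnp.bfloat16)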

Contributing

Contributions to EasyDeL are welcome! Please fork the repository, make your changes, and submit a pull request.

License 📜

EasyDeL is released under the Apache v2 license. See the LICENSE file for more details.

Contact

If you have any questions or comments about EasyDeL, you can reach out to me at [email protected].

Citation

To cite EasyDeL in your work:

@misc{ZareChavoshi_2023,
    title={EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models},
    url={https://github.com/erfanzar/EasyDeL},
    author={Zare Chavoshi, Erfan},
    year={2023}
}

easydel's People

Contributors

erfanzar, s-smits, sparsh35, w11wo, yhavinga


easydel's Issues

How to run llama in the examples

Hi, I am trying to use EasyDel to democratize LLaMA, but I cannot find a guide on how to run this model.
Could you please give me a hint about the EasyDel LLaMA launch script, such as repo_id and dataset_name? Thanks!

TypeError: __call__() takes from 2 to 9 positional arguments but 10 were given

Below is the code I used for making train_data

train_data = dataset.map(
    lambda x: tokenizer(generate_prompt(x), max_length=4096, padding='max_length', add_special_tokens=False),
    remove_columns=dataset.column_names,
)

Error:

  File "/usr/local/lib/python3.8/dist-packages/EasyDel/trainer/fsdp_train.py", line 409, in train
    sharded_train_state_, loss, accuracy = self.sharded_train_step_fn(sharded_train_state_,
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/trainer/fsdp_train.py", line 312, in fsdp_train_step_
    (loss__, accuracy__), grad = grad_fn(state.params)
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/trainer/fsdp_train.py", line 303, in calculate_loss
    logits = state.apply_fn(params=params, **batch,
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 565, in __call__
    outputs = self.module.apply(
TypeError: __call__() takes from 2 to 9 positional arguments but 10 were given

There is also an o_proj error in Mistral.

Resume from checkpoint

When I load the model from a checkpoint and try to continue training at the previous learning rate, the loss increases sharply after a small number of steps, so I have to warm up again. Is this normal?

What hardware spec are you using to train a LLaMA model with 7B params?

I am frequently getting an Out of Memory error. I am using a TPU v2-8. TPU v3-8 is generally not available; however, both have similar specs: 8 cores and 8 GB of memory.

Total hbm usage >= 9.50G:
    reserved        530.00M 
    program           8.98G 
    arguments            0B 

Output size 0B; shares 0B with arguments.

Program hbm requirement 8.98G:
    global           241.0K
    scoped            9.56M
    HLO temp          8.97G (100.0% utilization: Unpadded (8.94G) Padded (8.94G), 0.4% fragmentation (36.60M))

  Largest program allocations in hbm:


ValueError: `params` cannot be accessed from model when the model is created with `_do_init=False`.

Describe the bug
Getting the following error while running a llama model after training with EasyDel and converting to Hugging Face.

python serve_llama_tpu_easydel.py 
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.14s/it]
Traceback (most recent call last):
  File "/home/XXX/research/EasyDeL/serve_llama_tpu_easydel.py", line 26, in <module>
    params=model.params,
  File "/home/XXX/research/EasyDeL/.venv/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 271, in params
    raise ValueError(
ValueError: `params` cannot be accessed from model when the model is created with `_do_init=False`. You must call `init_weights` manually and store the params outside of the model and pass it explicitly where needed.

To Reproduce

  • Use Kaggle to train the llama-2-7b-chat-hf model on certain data
  • Convert and publish to Hugging Face
  • Load the model using the EasyDel example

Help train on tpu v3-32

#34
I read this issue and tried it, but couldn't make it work :(

Hi, Thank you for your amazing work.

I've been trying for a few days to make TPU v3-32 work.

I used the 'tpu-ubuntu2204-base' TPU VM image and tried the following code after installing JAX and the other dependencies on each TPU host:

train.py


import jax.numpy
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

huggingface_repo_id_or_path = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path, )

max_length = 2048
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_init_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="my_first_model_to_train_using_easydel",
    num_train_epochs=3,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear","cosine", "none" ,"warm_up_cosine" and "warm_up_linear"  are supported
    weight_decay=0.01,
    total_batch_size=64,
    max_steps=None,  # None to let trainer Decide
    do_train=True,
    do_eval=False,  # it's optional but supported 
    backend="tpu",  # default backed is set to cpu, so you must define you want to use tpu cpu or gpu
    max_length=max_length,  # Note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # the way to shard model across gpu,cpu or TPUs using sharding array (1, -1, 1, 1)
    # everything training will be in fully FSDP automatic and share data between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_remat="",
    dtype=jnp.bfloat16
)


def ultra_chat_prompting_process(
        data_chunk
):
    user_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "user"
    ]
    assistant_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "assistant"
    ]

    prompt = ""

    for uc, ac in zip(user_part, assistant_part):
        prompt += f"<|user|>\n{uc}</s>\n<|assistant|>\n{ac}</s>\n"

    return {"prompt": prompt}


tokenization_process = lambda data_chunk: tokenizer(
    data_chunk["prompt"],
    add_special_tokens=False,
    max_length=max_length,
    padding="max_length"
)

dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
dataset_train = dataset["train_gen"].map(ultra_chat_prompting_process, num_proc=12)
dataset_train = dataset_train.map(
    tokenization_process,
    num_proc=12,
    remove_columns=dataset_train.column_names
)

# you can do the same for evaluation process dataset

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset,
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.last_save_file_name}")

Then I sent it to the TPUs with
sudo gcloud compute tpus tpu-vm scp train.py node-1: --worker=all --zone=europe-west4-a

and ran it
sudo gcloud compute tpus tpu-vm ssh node-1 --zone=europe-west4-a --worker=all --command="python3 train.py"

and got error

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
/usr/local/lib/python3.10/dist-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.
  table = cls._concat_blocks(blocks, axis=0)
/usr/local/lib/python3.10/dist-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.
  table = cls._concat_blocks(blocks, axis=0)
/usr/local/lib/python3.10/dist-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.
  table = cls._concat_blocks(blocks, axis=0)
/usr/local/lib/python3.10/dist-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.
  table = cls._concat_blocks(blocks, axis=0)
wandb: Currently logged in as: cyine. Use `wandb login --relogin` to force relogin
wandb: Currently logged in as: cyine. Use `wandb login --relogin` to force relogin
wandb: Currently logged in as: cyine. Use `wandb login --relogin` to force relogin
wandb: Currently logged in as: cyine. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.1
wandb: Run data is saved locally in /root/wandb/run-20240104_234925-ri9ge2oc
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run vibrant-sound-5
wandb: ⭐️ View project at https://wandb.ai/cyine/easydel-my_first_model_to_train_using_easydel
wandb: 🚀 View run at https://wandb.ai/cyine/easydel-my_first_model_to_train_using_easydel/runs/ri9ge2oc
Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parametersin train function
Time For configure dataloaders (ms) : 0.2741813659667969
Traceback (most recent call last):
  File "/root/train.py", line 94, in <module>
    trainer = CausalLanguageModelTrainer(
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 233, in __init__
    self.init_functions()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 305, in init_functions
    self.model, self.tx, self.scheduler, self.config = self.configure_model()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 381, in configure_model
    if not hasattr(self.arguments.configs_to_init_model_class["config"], "get_partition_rules"):
TypeError: 'NoneType' object is not subscriptable
Traceback (most recent call last):
  File "/root/train.py", line 94, in <module>
    trainer = CausalLanguageModelTrainer(
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 233, in __init__
    self.init_functions()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 305, in init_functions
    self.model, self.tx, self.scheduler, self.config = self.configure_model()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 381, in configure_model
    if not hasattr(self.arguments.configs_to_init_model_class["config"], "get_partition_rules"):
TypeError: 'NoneType' object is not subscriptable
wandb: Tracking run with wandb version 0.16.1
wandb: Run data is saved locally in /root/wandb/run-20240104_234925-enyk3wwr
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run vital-sun-5
wandb: ⭐️ View project at https://wandb.ai/cyine/easydel-my_first_model_to_train_using_easydel
wandb: 🚀 View run at https://wandb.ai/cyine/easydel-my_first_model_to_train_using_easydel/runs/enyk3wwr
Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parametersin train function
Time For configure dataloaders (ms) : 0.27751922607421875
Traceback (most recent call last):
  File "/root/train.py", line 94, in <module>
    trainer = CausalLanguageModelTrainer(
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 233, in __init__
    self.init_functions()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 305, in init_functions
    self.model, self.tx, self.scheduler, self.config = self.configure_model()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 381, in configure_model
    if not hasattr(self.arguments.configs_to_init_model_class["config"], "get_partition_rules"):
TypeError: 'NoneType' object is not subscriptable
Traceback (most recent call last):
  File "/root/train.py", line 94, in <module>
    trainer = CausalLanguageModelTrainer(
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 233, in __init__
    self.init_functions()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 305, in init_functions
    self.model, self.tx, self.scheduler, self.config = self.configure_model()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 381, in configure_model
    if not hasattr(self.arguments.configs_to_init_model_class["config"], "get_partition_rules"):
TypeError: 'NoneType' object is not subscriptable
wandb: Tracking run with wandb version 0.16.1
wandb: Run data is saved locally in /root/wandb/run-20240104_234925-wvn0sguc
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run northern-vortex-5
wandb: ⭐️ View project at https://wandb.ai/cyine/easydel-my_first_model_to_train_using_easydel
wandb: 🚀 View run at https://wandb.ai/cyine/easydel-my_first_model_to_train_using_easydel/runs/wvn0sguc
Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parametersin train function
Time For configure dataloaders (ms) : 0.29778480529785156
Traceback (most recent call last):
  File "/root/train.py", line 94, in <module>
    trainer = CausalLanguageModelTrainer(
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 233, in __init__
    self.init_functions()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 305, in init_functions
    self.model, self.tx, self.scheduler, self.config = self.configure_model()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 381, in configure_model
    if not hasattr(self.arguments.configs_to_init_model_class["config"], "get_partition_rules"):
TypeError: 'NoneType' object is not subscriptable
Traceback (most recent call last):
  File "/root/train.py", line 94, in <module>
    trainer = CausalLanguageModelTrainer(
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 233, in __init__
    self.init_functions()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 305, in init_functions
    self.model, self.tx, self.scheduler, self.config = self.configure_model()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 381, in configure_model
    if not hasattr(self.arguments.configs_to_init_model_class["config"], "get_partition_rules"):
TypeError: 'NoneType' object is not subscriptable
wandb: Tracking run with wandb version 0.16.1
wandb: Run data is saved locally in /root/wandb/run-20240104_234925-cqws9qy2
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run dazzling-sea-8
wandb: ⭐️ View project at https://wandb.ai/cyine/easydel-my_first_model_to_train_using_easydel
wandb: 🚀 View run at https://wandb.ai/cyine/easydel-my_first_model_to_train_using_easydel/runs/cqws9qy2
Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parametersin train function
Time For configure dataloaders (ms) : 0.278472900390625
Traceback (most recent call last):
  File "/root/train.py", line 94, in <module>
    trainer = CausalLanguageModelTrainer(
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 233, in __init__
    self.init_functions()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 305, in init_functions
    self.model, self.tx, self.scheduler, self.config = self.configure_model()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 381, in configure_model
    if not hasattr(self.arguments.configs_to_init_model_class["config"], "get_partition_rules"):
TypeError: 'NoneType' object is not subscriptable
Traceback (most recent call last):
  File "/root/train.py", line 94, in <module>
    trainer = CausalLanguageModelTrainer(
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 233, in __init__
    self.init_functions()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 305, in init_functions
    self.model, self.tx, self.scheduler, self.config = self.configure_model()
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/causal_language_model_trainer.py", line 381, in configure_model
    if not hasattr(self.arguments.configs_to_init_model_class["config"], "get_partition_rules"):
TypeError: 'NoneType' object is not subscriptable
wandb: 🚀 View run vibrant-sound-5 at: https://wandb.ai/cyine/easydel-my_first_model_to_train_using_easydel/runs/ri9ge2oc
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20240104_234925-ri9ge2oc/logs
wandb: 🚀 View run northern-vortex-5 at: https://wandb.ai/cyine/easydel-my_first_model_to_train_using_easydel/runs/wvn0sguc
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20240104_234925-wvn0sguc/logs
wandb: 🚀 View run vital-sun-5 at: https://wandb.ai/cyine/easydel-my_first_model_to_train_using_easydel/runs/enyk3wwr
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20240104_234925-enyk3wwr/logs
wandb: 🚀 View run dazzling-sea-8 at: https://wandb.ai/cyine/easydel-my_first_model_to_train_using_easydel/runs/cqws9qy2
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20240104_234925-cqws9qy2/logs
##### Command execution on worker 0 failed with exit status 1. Continuing.
##### Command execution on worker 3 failed with exit status 1. Continuing.
##### Command execution on worker 1 failed with exit status 1. Continuing.
##### Command execution on worker 2 failed with exit status 1. Continuing.

The issue is

  1. The training failed.
  2. It seems to run not in parallel but independently on each pod.

Thank you,

GPT2 (150M model) support on TPU v2-8. Example script goes out of memory

Describe the bug
Out of memory for a small GPT2 model with 150M params.

jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 9.35G of 7.48G hbm. Exceeded hbm capacity by 1.86G.

Total hbm usage >= 9.86G:
    reserved        530.00M 
    program           9.35G 
    arguments            0B 

Output size 0B; shares 0B with arguments.

Program hbm requirement 9.35G:
    HLO temp          9.35G (3.1% utilization: Unpadded (294.48M) Padded (9.35G), 0.0% fragmentation (30.0K))

  Largest program allocations in hbm:

  1. Size: 9.20G
     Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/erf_inv" source_file="/home/neo/research/easydel/.venv/lib/python3.10/site-packages/flax/core/scope.py" source_line=979
     Shape: f32[19298688,2]{1,0:T(8,128)}
     Unpadded size: 147.24M
     Extra memory due to padding: 9.06G (64.0x expansion)
     XLA label: copy.1 = copy(fusion.7)
     Allocation type: HLO temp
     ==========================

  2. Size: 147.38M
     Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/erf_inv" source_file="/home/neo/research/easydel/.venv/lib/python3.10/site-packages/flax/core/scope.py" source_line=979
     Shape: f32[768,50257]{1,0:T(8,128)}
     Unpadded size: 147.24M
     Extra memory due to padding: 141.0K (1.0x expansion)
     XLA label: reshape.84 = reshape(copy.1)
     Allocation type: HLO temp
     ==========================

  3. Size: 1.0K
     Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/jit(_uniform)/slice[start_indices=(1,) limit_indices=(2,) strides=(1,)]" source_file="/home/neo/research/easydel/.venv/lib/python3.10/site-packages/flax/core/scope.py" source_line=979
     Shape: (u32[1]{0:T(256)}, u32[1]{0:T(256)})
     Unpadded size: 1.0K
     XLA label: fusion.53 = fusion(Arg_0.1), kind=kLoop, calls=fused_computation.52
     Allocation type: HLO temp
     ==========================

  4. Size: 1.0K
     Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/jit(_uniform)/add" source_file="/home/neo/research/easydel/.venv/lib/python3.10/site-packages/flax/core/scope.py" source_line=979
     Shape: (u32[9649344]{0:T(1024)}, u32[9649344]{0:T(1024)})
     Unpadded size: 1.0K
     XLA label: fusion.50 = fusion(xor.27, bitcast.1, bitcast), kind=kLoop, calls=fused_computation.50
     Allocation type: HLO temp
     ==========================

To Reproduce

Install deps

 pip install EasyDeL@git+https://github.com/erfanzar/EasyDeL.git@main
 pip install jax[tpu]==0.4.22 -f https://storage.googleapis.com/libtpu-releases/index.html

Use the following code

from EasyDel.modules import AutoEasyDelModelForCausalLM
from EasyDel.serve import JAXServer
from transformers import AutoTokenizer
import jax

model_huggingface_repo_id = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_huggingface_repo_id, trust_remote_code=True)
model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    model_huggingface_repo_id,
    jax.devices("cpu")[0],
    jax.numpy.float16,
    jax.numpy.float16,
    jax.lax.Precision("fastest"),
    (1, -1, 1, 1),
    device_map="auto"
)

params = model.init_weights(jax.random.PRNGKey(0), (1, 1))
server = JAXServer.from_parameters(
    model=model,
    config_model=model.config,
    tokenizer=tokenizer,
    params=params,
    add_params_field=True
)

response_printed = 0
for response, tokens_used in server.process(
        "String To The Model", stream=True
):
    print(response[response_printed:], end="")
    response_printed = len(response)

Segmentation fault (core dumped)

When trying to train Mistral on a v2-8 TPU, the

Segmentation fault (core dumped)

error occurs. Do you have any idea how to resolve this?

While training Gpt2 model - Exception - TypeError: in_shardings leaf specifications are expected to be PartitionSpec instances or None, but got *

Describe the bug
While I was trying to finetune a GPT2 model and even another non-Llama model, I get the following exception. Am I missing something?

To Reproduce
Take a GPT2 model and train on some data
The following exception will be raised

Downloading data files: 100%|██████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 9868.95it/s]
Extracting data files: 100%|███████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1449.81it/s]
Generating train split: 2942 examples [00:00, 241149.94 examples/s]
Map: 100%|███████████████████████████████████████████████████████████| 2942/2942 [00:00<00:00, 20119.71 examples/s]
Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parametersin train function
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3
wandb: You chose "Don't visualize my results"
wandb: Tracking run with wandb version 0.16.2
wandb: W&B syncing is set to `offline` in this directory.  
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
Time Took to Complete Task configure dataloaders (microseconds) : 0.20360946655273438
Time Took to Complete Task configure Model ,Optimizer ,Scheduler and Config (microseconds) : 654.585599899292
Traceback (most recent call last):
  File "/home/***/research/EasyDeL/train_ravengpt.py", line 74, in <module>
    trainer = CausalLanguageModelTrainer(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 244, in __init__
    self.init_functions()
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 310, in init_functions
    funcs = self.configure_functions()
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 477, in configure_functions
    create_sharded_state_from_params_fn = pjit(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 762, in pjit
    static_argnames) = pre_infer_params(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 309, in pre_infer_params
    in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings')
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 1093, in prepare_axis_resources
    new_entries.append(ParsedPartitionSpec.from_user_input(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 985, in from_user_input
    raise TypeError(f"{arg_name} are expected to be "
TypeError: in_shardings leaf specifications are expected to be PartitionSpec instances or None, but got *
Traceback (most recent call last):
  File "/home/***/research/EasyDeL/train_ravengpt.py", line 74, in <module>
    trainer = CausalLanguageModelTrainer(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 244, in __init__
    self.init_functions()
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 310, in init_functions
    funcs = self.configure_functions()
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 477, in configure_functions
    create_sharded_state_from_params_fn = pjit(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 762, in pjit
    static_argnames) = pre_infer_params(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 309, in pre_infer_params
    in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings')
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 1093, in prepare_axis_resources
    new_entries.append(ParsedPartitionSpec.from_user_input(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 985, in from_user_input
    raise TypeError(f"{arg_name} are expected to be "
TypeError: in_shardings leaf specifications are expected to be PartitionSpec instances or None, but got *
wandb: You can sync this run to the cloud by running:

Test Code

import json
import jax.numpy
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

huggingface_repo_id_or_path = "gpt2"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path, )

max_length = 256
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_init_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="raven_gpt_using_easydel",
    num_train_epochs=3,
    configs_to_init_model_class=configs_to_init_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear","cosine", "none" ,"warm_up_cosine" and "warm_up_linear"  are supported
    weight_decay=0.01,
    total_batch_size=2,
    max_steps=100,  # None to let trainer Decide
    do_train=True,
    do_eval=False,  # it's optional but supported 
    backend="tpu",  # default backed is set to cpu, so you must define you want to use tpu cpu or gpu
    max_length=max_length,  # Note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # the way to shard model across gpu,cpu or TPUs using sharding array (1, -1, 1, 1)
    # everything training will be in fully FSDP automatic and share data between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)


with open("tmp_json.jsonl", "w") as tmp_json:
    for line in open("data.jsonl", "r"):
        record = json.loads(line.strip())
        turn_resp = "<|input|>" + record["prompt"] + "<|response|>" + record["response"]
        tmp_json.write(json.dumps({"turn_resp": turn_resp}))

data = load_dataset("json", data_files="tmp_json.jsonl", split="train")
data = data.map(lambda samples: tokenizer(samples["turn_resp"]), batched=True)
data = data.train_test_split(test_size=0.001)

# you can do the same for evaluation process dataset

trainer = CausalLanguageModelTrainer(
    train_arguments,
    data,
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")

AMD Hardware Support

Not exactly a bug, but I'm about to try to get EasyDel to work on some AMD GPU servers I've got, and might need some help. Would it be possible to pay for support to get EasyDel working on these servers?

Potential regression causing resource exhausted after recent commit

I noticed that the latest HEAD could no longer fit a Mistral 7B training run in memory on a TPU v3-8; sequence length was set to 2048.
Git bisect indicated that the commit below was the first to hit RESOURCE_EXHAUSTED: XLA:TPU.

Author: erfanzar <[email protected]>
Date:   Wed Jan 31 12:02:38 2024 -0800

    Working on `LoRA`

The preceding commit does not exhibit the resource exhausted issue.

Author: Erfan Zare Chavoshi <[email protected]>
Date:   Wed Jan 31 18:33:47 2024 +0330

    Delete python_test/EasyDel-Checkpoints/Lora-Test directory

I instantiate the TrainArguments like this (I've incorporated the change from max_length to max_sequence_length):

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model, params = AutoEasyDelModelForCausalLM.from_pretrained(model_id)
config = model.config
config.freq_max_position_embeddings = config.max_position_embeddings  # 32768
config.max_position_embeddings = 2048  # Let use context length of ... for training
config.c_max_position_embeddings = config.max_position_embeddings

import pprint
pprint.pprint(config.get_partition_rules(True))

max_sequence_length = config.max_position_embeddings

train_args = TrainArguments(
    model_class=EasyDel.FlaxMistralForCausalLM,
    configs_to_initialize_model_class={
        'config': config,
        'dtype': jnp.bfloat16,
        'param_dtype': jnp.bfloat16,
        'input_shape': (1, 1)
    },
    custom_rule=config.get_partition_rules(True),
    model_name='MistralBoreas',
    num_train_epochs=1,
    learning_rate=1e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,
    scheduler=EasyDelSchedulers.WARM_UP_LINEAR,
    warmup_steps=500,
    weight_decay=0.01,
    total_batch_size=1,
    save_steps=2000,
    do_train=True,
    do_eval=False,
    backend='tpu',
    max_sequence_length=max_sequence_length,
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),
    use_pjit_attention_force=False,
    use_flash_attention=False,
    gradient_accumulation_steps=8,
    remove_ckpt_after_load=True,
    ids_to_pop_from_dataset=['token_type_ids'],
    loss_remat="",
    dtype=jnp.bfloat16
)

trainer = CausalLanguageModelTrainer(
    train_args,
    dataset_train,
    checkpoint_path=None
)

Error while running a GPT2 model

An exception is generated while running a gpt2-format model, as shown below.

To Reproduce

Prepare to serve a model

run examples/serving/causal-lm/artgpt2tox-chat.py --pretrained_model_name_or_path="leondz/artgpt2tox" --max_length=2048   --max_new_tokens=256 --max_compile_tokens=32 --temperature=0.6   --top_p=0.95 --top_k=50   --dtype="fp16" --use_prefix_tokenizer --mesh_axes_shape 1 -1 1 1

Getting the following exception

File ~/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py:428, in JAXServer.from_torch_pretrained(cls, server_config, pretrained_model_name_or_path, device, dtype, param_dtype, precision, sharding_axis_dims, sharding_axis_names, q_ps, k_ps, v_ps, b_ps, a_ps, use_shard_map, input_shape, shard_fns, backend, add_params_field, do_memory_log, verbose, **kwargs)
    382 @classmethod
    383 def from_torch_pretrained(
    384         cls,
   (...)
    405         **kwargs
    406 ):
    408     model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    409         pretrained_model_name_or_path=pretrained_model_name_or_path,
    410         device=device,
   (...)
    425         **kwargs
    426     )
--> 428     return cls.from_parameters(
    429         model=model,
    430         config_model=model.config,
    431         tokenizer=transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path),
    432         params=params,
    433         config=server_config,
    434         verbose=verbose,
    435         do_memory_log=do_memory_log,
    436         add_params_field=add_params_field
    437     )

File ~/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py:486, in JAXServer.from_parameters(cls, model, config_model, tokenizer, params, config, add_params_field, do_memory_log, verbose)
    482 logging.info(
    483     "matching partition rules"
    484 )
    485 rules = match_partition_rules(params=params, rules=config_model.get_partition_rules(True))
--> 486 shard_fns, _ = make_shard_and_gather_fns(rules, get_dtype(server.config.dtype))
    487 logging.info(
    488     "sharding parameters across all of the chosen backend(tpu/gpu/cpu)s"
    489 )
    490 params = flax.traverse_util.flatten_dict(params)

File ~/research/EasyDeL/.venv/lib/python3.10/site-packages/fjformer/partition_utils/mesh_utils.py:66, in make_shard_and_gather_fns(partition_specs, dtype_specs)
     63     return gather_fn
     65 if dtype_specs is None or dtype_specs in float_dtypes:
---> 66     shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
     67     gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs)
     68 else:

File ~/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/tree_util.py:244, in tree_map(f, tree, is_leaf, *rest)
    242 leaves, treedef = tree_flatten(tree, is_leaf)
    243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/tree_util.py:244, in <genexpr>(.0)
    242 leaves, treedef = tree_flatten(tree, is_leaf)
    243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/research/EasyDeL/.venv/lib/python3.10/site-packages/fjformer/partition_utils/mesh_utils.py:42, in make_shard_and_gather_fns.<locals>.make_shard_fn(partition_spec, dtype_spec)
     41 def make_shard_fn(partition_spec, dtype_spec=None):
---> 42     jax_shard_function = pjit(
     43         make_to_dtype_fn(dtype_spec),
     44         in_shardings=None,
     45         out_shardings=partition_spec
     46     )
     48     def shard_fn(tensor):
     49         return jax_shard_function(tensor).block_until_ready()

File ~/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:762, in pjit(fun, in_shardings, out_shardings, static_argnums, static_argnames, donate_argnums, donate_argnames, keep_unused, device, backend, inline, abstracted_axes)
    570 def pjit(
    571     fun: Callable,
    572     in_shardings=UNSPECIFIED,
   (...)
    582     abstracted_axes: Any | None = None,
    583 ) -> stages.Wrapped:
    584   """Makes ``fun`` compiled and automatically partitioned across multiple devices.
    585 
    586   NOTE: This function is now equivalent to jax.jit please use that instead.
   (...)
    759   [ 0.5  2.   4.   6.   8.  10.  12.  10. ]
    760   """
    761   (in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums,
--> 762    static_argnames) = pre_infer_params(
    763        fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
    764        static_argnums, static_argnames, device, backend, abstracted_axes)
    766   def infer_params(*args, **kwargs):
    767     # Putting this outside of wrapped would make resources lexically scoped
    768     resource_env = mesh_lib.thread_resources.env

File ~/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:310, in pre_infer_params(fun, in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames, device, backend, abstracted_axes)
    307   in_shardings = tuple(in_shardings)
    309 in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings')
--> 310 out_shardings, _, _ = prepare_axis_resources(out_shardings, 'out_shardings')
    312 donate_argnums, donate_argnames, static_argnums, static_argnames = resolve_argnums(
    313     fun, donate_argnums, donate_argnames, static_argnums, static_argnames)
    315 return (in_shardings, out_shardings, donate_argnums, donate_argnames,
    316         static_argnums, static_argnames)

File ~/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py:1093, in prepare_axis_resources(axis_resources, arg_name, allow_unconstrained_dims)
   1091     new_entries.append(entry)
   1092   else:
-> 1093     new_entries.append(ParsedPartitionSpec.from_user_input(
   1094         entry, what, allow_unconstrained_dims=allow_unconstrained_dims))
   1096 _check_unique_resources(new_entries, arg_name)
   1097 return tree_util.tree_unflatten(treedef, new_entries), new_entries, treedef

File ~/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py:985, in ParsedPartitionSpec.from_user_input(cls, entry, arg_name, allow_unconstrained_dims)
    983   return cls(entry, ())
    984 if not isinstance(entry, PartitionSpec):
--> 985   raise TypeError(f"{arg_name} are expected to be "
    986                   f"PartitionSpec instances or None, but got {entry}")
    987 axis_specs = []
    988 for axis_spec in entry:

TypeError: out_shardings leaf specifications are expected to be PartitionSpec instances or None, but got *

artgpt2tox-chat.py code

import typing

import termcolor

import EasyDel
import jax.lax
from EasyDel.serve import JAXServer, JAXServerConfig
from fjformer.checkpoint import get_dtype
from transformers import AutoTokenizer
import argparse

DEFAULT_SYSTEM_PROMPT = ""


def get_prompt_llama2_format(message: str, chat_history,
                             system_prompt: str) -> str:
    texts = [f'<|input|>{message}<|response|>']
    return "".join(texts)


class Attack2Host(JAXServer):
    def __init__(self, config=None):
        super().__init__(config=config)

    @staticmethod
    def format_instruct(system: str, instruction: str) -> str:
        return get_prompt_llama2_format(instruction, [], system)

    @staticmethod
    def format_chat(history: typing.List[str], prompt: str, system: typing.Union[str, None]) -> str:
        return get_prompt_llama2_format(prompt, history, system)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Argument parser for Llama2.")
    parser.add_argument(
        '--pretrained_model_name_or_path',
        default='meta-llama/Llama-2-7b-chat-hf',
        help='HuggingFace Repo to load model From'
    )
    parser.add_argument(
        "--contains_auto_format",
        default=False,
        action="store_true",
        help="Whether the input text contains auto-format tokens.",
    )
    parser.add_argument(
        "--max_length",
        default=4096,
        type=int,
        help="The maximum length of the input text.",
    )
    parser.add_argument(
        "--max_new_tokens",
        default=2048,
        type=int,
        help="The maximum number of new tokens to generate.",
    )
    parser.add_argument(
        "--max_compile_tokens",
        default=32,
        type=int,
        help="The maximum number of tokens to generate per stream.",
    )
    parser.add_argument(
        "--temperature",
        default=0.6,
        type=float,
        help="The temperature of the sampling distribution.",
    )
    parser.add_argument(
        "--top_p",
        default=0.95,
        type=float,
        help="The top-p probability cutoff for the sampling distribution.",
    )
    parser.add_argument(
        "--top_k",
        default=50,
        type=int,
        help="The top-k number of tokens to keep for the sampling distribution.",
    )
    parser.add_argument(
        "--logging",
        default=False,
        action="store_true",
        help="Whether to log the generation process.",
    )
    parser.add_argument(
        "--mesh_axes_names",
        default=["dp", "fsdp", "tp", "sp"],
        nargs="+",
        help="The names of the mesh axes.",
    )
    parser.add_argument(
        "--mesh_axes_shape",
        default=[1, -1, 1, 1],
        nargs="+",
        type=int,
        help="The shapes of the mesh axes.",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        help="The data type to use for the generation.",
    )
    parser.add_argument(
        "--use_prefix_tokenizer",
        default=False,
        action="store_true",
        help="Whether to use a prefix tokenizer.",
    )
    args = parser.parse_args()
    configs = JAXServerConfig(
        contains_auto_format=args.contains_auto_format,
        max_length=args.max_length,
        max_new_tokens=args.max_new_tokens,
        max_compile_tokens=args.max_compile_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        logging=args.logging,
        mesh_axes_names=args.mesh_axes_names,
        mesh_axes_shape=args.mesh_axes_shape,
        dtype=args.dtype,
        use_prefix_tokenizer=args.use_prefix_tokenizer
    )

    server = Attack2Host.from_torch_pretrained(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        server_config=configs
    )
    try:
        termcolor.cprint(
            'Launching App ...',
            color="cyan",
            force_color=True
        )
        server.gradio_inference().launch(share=True)
        termcolor.cprint(
            'Launching Server APIS (Fire) ...',
            color="cyan",
            force_color=True
        )
        server.fire()
    except KeyboardInterrupt:
        print('Exiting ...')
        server.end()
        exit(0)

Error running remote model that has custom code

ValueError: Loading this model requires you to execute custom code contained in the model repository on your local machine. Please set the option `trust_remote_code=True` to permit loading of this model.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/xxx/research/transformers/train_phi2_easydel.py", line 21, in <module>
    model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path, trust_remote_code=True)
  File "/home/xxx/research/EasyDeL/lib/python/EasyDel/modules/auto_easydel_model.py", line 221, in from_pretrained
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
  File "/home/xxx/research/EasyDeL/.venv/lib/python3.10/site-packages/transformers/models/auto/configuration_auto.py", line 1085, in from_pretrained
    trust_remote_code = resolve_trust_remote_code(
  File "/home/xxx/research/EasyDeL/.venv/lib/python3.10/site-packages/transformers/dynamic_module_utils.py", line 611, in resolve_trust_remote_code
    raise ValueError(
ValueError: The repository for stabilityai/stable-code-3b contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/stabilityai/stable-code-3b.


Llama RoPE bug

In your code, the default rotary type is 'lm2', but the generation output is strange.

model is meta-llama/Llama-2-7b-hf

rope type = 'lm2'

<s>Hi, how are you doing? ๐Ÿ™‚๏ฟฝ\nI'm from South Korea, ๐Ÿ˜๐Ÿ˜๏ฟฝ๏ฟฝ\nHi,\nI am a ๏ฟฝ๏ฟฝฬ…๏ฟฝ\nI am a friendly and I'm new to you?\nI'm looking

rope type = 'complex'

<s>Hi, how are you doing? ๐Ÿ™‚\nMy name is Jasmine and I am the creator of this blog.\nMy name is Jasmine, Iโ€™m 23319 years old, Iโ€™m from Germany.Hi, my name is Jasmine

This is my test code. I changed the parameter names of the llama model so I could use huggingface from_pt=True.
When I used weights converted with converter.py, the generation output was also strange.

from EasyDel.modules.llama.modelling_llama_flax import (
    FlaxLlamaForCausalLM
)
from transformers import LlamaTokenizer, LlamaForCausalLM
import jax
import torch

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

device = jax.devices('cpu')[0]
with jax.default_device(device):
    print("loading")
    model = FlaxLlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", from_pt=True)

    print("generating")
    ids = tokenizer(["Hi, how are you doing? "], return_tensors="np")
    outs = model.generate(
        ids["input_ids"],
        do_sample=True,
        max_length=64,
        params={"params": model.params}
    )

    print(outs)
    print(tokenizer.batch_decode(outs.sequences))

The model generates repeated words.

Describe the bug

I use the example code provided on the documentation (https://erfanzar.github.io/EasyDeL/Llama2/).
But I changed the model to 'hfl/chinese-alpaca-2-7b-16k'. (The model is based on llama-2-chat and has a linear rope scale in config.json)

Then when I run the code, the model repeatedly generates:
"Sharding Params: 100%|████████████████████████| 291/291 [02:03<00:00, 2.36it/s]
Compiling NonGreedy(Generate) Funcs
. Beijing. Beijing in English.ing in Chinese.ing Beijing Beijing Beijing Beijing Beijing Beijing Beijing Beijing Beijing Beijing Beijing, China Daily Beijing Beijing. Beijing Beijinginginginginginginginginginginginginginginginging Beijinginging Beijing Beijing Beijing Beijing Beijing Beijinging Beijingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingingi"
How can I solve the above issue? Thank you so much; this is really amazing work.

To Reproduce
Below is my full code.

from EasyDel.modules import FlaxLlamaForCausalLM
from EasyDel.serve import JAXServer
import jax
from EasyDel.transform import llama_from_pretrained
from transformers import AutoTokenizer

model_id = 'hfl/chinese-alpaca-2-7b-16k'

params, config = llama_from_pretrained(
    model_id,
    device=jax.devices('cpu')[0]  # Offload on CPU
)

server = JAXServer.load_from_params(
    params=params,
    model=FlaxLlamaForCausalLM(
        config=config,
        dtype=jax.numpy.bfloat16,  # I'm on TPUs
        param_dtype=jax.numpy.bfloat16,  # I'm on TPUs
        precision=jax.lax.Precision('fastest'),
        _do_init=False,
        input_shape=(1, 1024)
    ),
    config_model=config,
    add_params_field=True,
    tokenizer=AutoTokenizer.from_pretrained(model_id),
    verbose=False,
    do_memory_log=True,
)

response_printed = 0
for response, tokens_used in server.process(
    'Please introduce Beijing in detail', stream=True
):
    print(response[response_printed:], end='')
    response_printed = len(response)
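One thing worth checking for this particular model: the degenerate output together with the 16k context suggests the linear RoPE scaling may not survive the conversion. The `rope_scaling` field below is the standard Hugging Face Llama config entry; copying it onto the EasyDel-side config before building FlaxLlamaForCausalLM is a hedged suggestion, not a confirmed fix:

from transformers import AutoConfig

# For hfl/chinese-alpaca-2-7b-16k this should be {"type": "linear", "factor": 4.0}.
hf_config = AutoConfig.from_pretrained('hfl/chinese-alpaca-2-7b-16k')
print(hf_config.rope_scaling)

# If the config returned by llama_from_pretrained lacks the field, try copying
# it over before constructing the Flax model:
config.rope_scaling = hf_config.rope_scaling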

Question: Does low-bit config reduce TPU HBM memory usage when training?

Hello,

Firstly, I'd like to express my appreciation for your work on this repository. I noticed that it supports low-bit (4 or 8 bits) formats during training, which is quite intriguing.

I have a query regarding TPU compatibility, particularly for generations before TPUv4. Since TPUs typically don't support low-bit formats like 4-bit or 8-bit until TPUv4 (which supports int8), I'm curious how this implementation works. My current understanding is that the code might be converting the 4- or 8-bit formats into bfloat16 or float16. If that is the case, would it imply that the memory-usage reduction typically expected from lower-bit formats is not realized?

Could you please clarify if my understanding is correct? Thanks for your time and effort in developing and maintaining this repository.
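For what it's worth, the questioner's mental model can be sketched in plain JAX (an illustration, not EasyDeL's actual implementation): if the weights live in HBM as int8 buffers with a scale, memory holds roughly one byte per parameter, and a bfloat16 copy is only materialized transiently at the point of use:

import jax
import jax.numpy as jnp

def quantize_int8(w):
    """Symmetric per-tensor int8 quantization: int8 values plus one fp scale."""
    scale = jnp.max(jnp.abs(w)) / 127.0
    q = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
    return q, scale

def dequantize_bf16(q, scale):
    """Materialize a bfloat16 copy only when the weight is actually used."""
    return q.astype(jnp.bfloat16) * scale.astype(jnp.bfloat16)

w = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024))
q, scale = quantize_int8(w)  # stored at ~1 byte/param
y = dequantize_bf16(q, scale) @ jnp.ones((1024,), jnp.bfloat16)  # transient bf16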

Error while training GPT2 on Kaggle

Describe the bug
Error while training GPT2 on Kaggle.

/root
Downloading base model...
<class 'EasyDel.modules.gpt2.modelling_gpt2_flax.FlaxGPT2LMHeadModel'>
<class 'EasyDel.modules.gpt2.gpt2_configuration.GPT2Config'>
Downloading data files: 100%|██████████████████| 1/1 [00:00<00:00, 11214.72it/s]
Extracting data files: 100%|████████████████████| 1/1 [00:00<00:00, 1438.38it/s]
Generating train split: 186074 examples [00:00, 353355.15 examples/s]
Map (num_proc=12): 100%|██████| 186074/186074 [00:03<00:00, 47310.49 examples/s]
/usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
  table = cls._concat_blocks(blocks, axis=0)
Map (num_proc=12):   0%|                      | 0/186074 [00:00<?, ? examples/s]/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
(the FutureWarning above is repeated once per worker process)
Map (num_proc=12): 100%|███████| 186074/186074 [00:21<00:00, 8523.98 examples/s]
Warning : In case of using `finetune = True` and Passing `checkpoint_path = None` you should pass parameters in train function
wandb: Currently logged in as: jchauhan (safedep). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.2
wandb: Run data is saved locally in /root/wandb/run-20240201_154611-g3pguwrw
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run avid-sky-14
wandb: ⭐️ View project at https://wandb.ai/safedep/EasyDeL-raven_gpt2.easydel
wandb: 🚀 View run at https://wandb.ai/safedep/EasyDeL-raven_gpt2.easydel/runs/g3pguwrw
Time Took to Complete Task configure dataloaders (microseconds) : 0.4191398620605469
Time Took to Complete Task configure Model ,Optimizer ,Scheduler and Config (microseconds) : 597.6324081420898
Time Took to Complete Task configure functions and sharding them (microseconds) : 745.0320720672607
Action : Sharding Passed Parameters
Traceback (most recent call last):
  File "/root/train.py", line 123, in <module>
    output = trainer.train(flax.core.FrozenDict({"params": params}))
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 478, in train
    sharded_state, shard_fns, gather_fns = self.initialize_state(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 393, in initialize_state
    params = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 243, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 243, in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Dict key mismatch; expected keys: ['transformer']; dict: {'transformer': {'wte': {'embedding':

To Reproduce

%%writefile /root/train.py

import os
import jax.numpy
import EasyDel
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wand_key = user_secrets.get_secret("WAND_KEY")
# os.environ["WANDB_DISABLED"] = "false"
os.environ["WANDB_API_KEY"] = wand_key

base_model_hf_repo_id_or_path = "gpt2"
max_length = 1024
trained_model_name = "****"
trained_model_hf_repo_id = f"****/{trained_model_name}"
easydel_trained_model_name = f"{trained_model_name}.easydel"
training_data_files="****"


import json

jcdataset = load_dataset('****', split='train')
# Write the conversations out as JSONL; close the file before it is read back below.
with open("./lmsys-toxic-gpt.json", "w") as f:
    for conversation in jcdataset['chunks']:
        out = "<|input|><|response|>"
        for req_res in conversation:
            out = out + req_res['prompt']
            f.write(json.dumps({'train': out}))
            f.write("\n")
            out = "<|input|>" + req_res['response'] + "<|response|>"


print("Downloading base model...")
model, params = AutoEasyDelModelForCausalLM.from_pretrained(base_model_hf_repo_id_or_path, trust_remote_code=True)

tokenizer = AutoTokenizer.from_pretrained(
    base_model_hf_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_init_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

model.config.use_sacn_mlp = False

print(type(model))
print(type(model.config))

train_arguments = TrainArguments(
    model_class=type(model),
    model_name=easydel_trained_model_name,
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_init_model_class,
    custom_rule=model.config.get_partition_rules(True),
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear","cosine", "none" ,"warm_up_cosine" and "warm_up_linear"  are supported
    weight_decay=0.01,
    total_batch_size=8,
    max_steps=None,  # None to let the trainer decide
    do_train=True,
    do_eval=False,  # optional but supported
    backend="tpu",  # the default backend is cpu, so you must specify tpu, cpu, or gpu
    max_length=max_length,  # Note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # how to shard the model across CPUs, GPUs, or TPUs; with (1, -1, 1, 1)
    # training is fully automatic FSDP, sharing data between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)

def ultra_chat_prompting_process(
        data_chunk
):
    return {"prompt": data_chunk['train']}


tokenization_process = lambda data_chunk: tokenizer(
    data_chunk["prompt"],
    add_special_tokens=False,
    max_length=max_length,
    padding="max_length"
)

dataset = load_dataset("json", data_files=training_data_files)
dataset_train = dataset["train"].map(ultra_chat_prompting_process, num_proc=12)
dataset_train = dataset_train.map(
    tokenization_process,
    num_proc=12,
    remove_columns=dataset_train.column_names
)

# you can do the same for evaluation process dataset

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")


import tempfile
import os
from huggingface_hub import Repository, create_repo
from transformers import LlamaForCausalLM, LlamaTokenizer
import jax

from EasyDel import (
    AutoEasyDelConfig,
    EasyDelState,
    easystate_to_huggingface_model
)


# Function to create a Hugging Face repository
def create_hf_repo(repo_name, hub_token=None):
    tmp_dir = tempfile.TemporaryDirectory()
    tmp_output_dir = tmp_dir.name

    if repo_name is None:
        repo_name = os.path.basename(tmp_output_dir)
    
    # Create repo and retrieve repo_id
    repo_id = create_repo(repo_name, exist_ok=True, token=hub_token).repo_id
    
    # Clone repo locally
    repo = Repository(tmp_output_dir, clone_from=repo_id, token=hub_token)
    tmp_dir.cleanup()

    return repo

# Define the base model ID, checkpoint path, and target Hugging Face repo ID
chkpoint_path = output.checkpoint_path

# Load configuration for the custom model
config = AutoEasyDelConfig.from_pretrained(base_model_hf_repo_id_or_path)

# Create the custom model using EasyDel
with jax.default_device(jax.devices("cpu")[0]):
    model = easystate_to_huggingface_model(
        state=EasyDelState.load_state(chkpoint_path),
        base_huggingface_module=LlamaForCausalLM,
        config=config
    )
# 
model = model.half()  # Cast the weights to fp16 for the Hugging Face export

# Check if the target Hugging Face repo exists, and create it if not
hub_token = None # login is already done
# repo = create_hf_repo(trained_model_hf_repo_id, hub_token)


# Optionally, you can push the base model to the target repo as well
base_model = LlamaForCausalLM.from_pretrained(base_model_hf_repo_id_or_path)
# base_model.push_to_hub(trained_model_hf_repo_id, token=hub_token)
tokenizer.push_to_hub(trained_model_hf_repo_id, token=hub_token)

# Push the custom model to the target Hugging Face repo
model.push_to_hub(trained_model_hf_repo_id, token=hub_token)
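Regarding the Dict key mismatch traceback above: jax.tree_util.tree_map requires the shard-function tree and the parameter tree to have identical structures, so a quick structural look at what is being passed (a generic JAX check, not EasyDeL internals) can show whether an extra {"params": ...} nesting level is the culprit:

import jax
import flax

# Print the structure of the tree handed to trainer.train(); the trainer's
# shard functions are built for the model's expected parameter tree, so the
# two must match level-for-level.
passed = flax.core.FrozenDict({"params": params})
print(list(passed.keys()))            # ['params']
print(list(passed["params"].keys()))  # e.g. ['transformer'] for GPT-2
print(jax.tree_util.tree_structure(passed))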


Kaggle issue

Hello, I am having an issue running EasyDeL on Kaggle TPUs: the Jupyter kernel dies while importing the libraries.

QLoRA Finetune Example

Can you post an example of how to fine-tune an LLM with 4-bit QLoRA on a TPU? Thanks for any help you can provide.
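While waiting for an official example, the core QLoRA idea is framework-independent: keep the (quantized) base weight frozen and train only a low-rank update. A minimal JAX sketch of that idea, not EasyDeL's actual API:

import jax
import jax.numpy as jnp

def lora_linear(x, W_frozen, A, B, scaling=1.0):
    """y = x @ W + scaling * (x @ A) @ B, with W frozen and A, B trainable."""
    return x @ W_frozen + (x @ A) @ B * scaling

d, r = 512, 8
k_w, k_a = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(k_w, (d, d))         # frozen base weight (quantized in real QLoRA)
A = jax.random.normal(k_a, (d, r)) * 0.01  # trainable
B = jnp.zeros((r, d))                      # trainable, zero-init so the update starts at 0
x = jnp.ones((1, d))
y = lora_linear(x, jax.lax.stop_gradient(W), A, B)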

Loss increases randomly

Hello once again. I am seeing some weird behavior with my loss whenever I use EasyDeL for fine-tuning, no matter the dataset.
(screenshot of the training-loss curve omitted)

These are my training args:

train_args = TrainArguments(
    model_class=type(model),
    configs_to_init_model_class=configs_to_init_model_class,
    custom_rule=config.get_partition_rules(True),
    model_name='EasyDelLLama2',
    num_train_epochs=1,
    learning_rate=4e-05,
    learning_rate_end=1.5e-06,
    warmup_steps=156,
    optimizer='adamw',
    scheduler='warm_up_linear',
    weight_decay=0.01,
    total_batch_size=32,
    max_steps=None,
    do_train=True,
    do_eval=False,
    backend='tpu',
    max_length=max_length,
    gradient_checkpointing='nothing_saveable',
    sharding_array=(1, -1, 1),
    use_pjit_attention_force=False,
    gradient_accumulation_steps=1,
    remove_ckpt_after_load=True,
    ids_to_pop_from_dataset=['token_type_ids'],
    loss_remat='',
    is_left_padded=True,
    dtype=jax.numpy.bfloat16
)

trainer = CausalLMTrainer(
    train_args,
    dataset_train,
    ckpt_path=None
)

Training dataset is Open-Platypus (left-padded) and the model is Sheared-Llama-2.7B.

Error while serving model as per documentation

Describe the bug
Error while serving a toy model using EasyDeL, with the following exception.

To Reproduce

Context

  • Using Google TPU VM
  • Follow the instructions as suggested here
    huggingface_repo_id_or_path = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
  • Use the TinyLlama model TinyLlama/TinyLlama-1.1B-Chat-v1.0
python sev-tiny.py
Traceback (most recent call last):
  File "/home/***/research/EasyDeL/sev-tiny.py", line 23, in <module>
    params=model.params,
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 271, in params
    raise ValueError(
ValueError: `params` cannot be accessed from model when the model is created with `_do_init=False`. You must call `init_weights` manually and store the params outside of the model and pass it explicitly where needed.
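The exception message itself points at the fix: with `_do_init=False`, the parameters must be created explicitly and passed around, rather than read from `model.params`. A minimal sketch, assuming `model` is the Flax model constructed with `_do_init=False` and the input shape used elsewhere in these examples:

import jax

# `init_weights(rng, input_shape)` is the standard initializer on transformers
# Flax models; it returns the parameter tree without storing it on the model.
rng = jax.random.PRNGKey(0)
params = model.init_weights(rng, input_shape=(1, 1024))
# ...then pass `params=params` explicitly (for example to
# JAXServer.load_from_params) instead of accessing `model.params`.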

Distillation

I would like to try distillation: compute the softmax probabilities of the teacher model's logits over my input dataset, save them, and use them as soft targets to train the student model. Is this possible in EasyDeL?
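As a framework-agnostic sketch of the computation involved (not an EasyDeL API): precompute the teacher's temperature-softened probabilities over the dataset, save them, and train the student with a soft-target cross-entropy:

import jax
import jax.numpy as jnp

def soft_target_loss(student_logits, teacher_probs, temperature=2.0):
    """Cross-entropy against soft targets, scaled by T^2 as in Hinton et al."""
    log_p_student = jax.nn.log_softmax(student_logits / temperature, axis=-1)
    return -jnp.sum(teacher_probs * log_p_student, axis=-1).mean() * temperature ** 2

teacher_logits = jnp.array([[2.0, 0.5, -1.0]])
teacher_probs = jax.nn.softmax(teacher_logits / 2.0, axis=-1)  # save these to disk
student_logits = jnp.array([[1.0, 0.2, -0.5]])
loss = soft_target_loss(student_logits, teacher_probs)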

Training on TPU Using Flash Attention

Hi, I tried to fine-tune a model on a TPU VM v3-8. When not using flash attention, it works. However, when I set config.use_flash_attention = True, an error occurs: block_q=1024 should be smaller or equal to q_seq_len=1.

When I tried to set config.q_seq_len = 4096, it doesn't work; it still reports the same error: block_q=1024 should be smaller or equal to q_seq_len=1.

Below is my Code:

def main(argv):
    dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
    dataset_train = dataset['test_sft'].map(formatting_func, num_proc=12)
    dataset_train = dataset_train.remove_columns(['prompt', 'prompt_id', 'messages'])

    params, config = llama_from_pretrained(FLAGS.pretrained_model_name_or_path, jax.devices("cpu")[0])
    config.use_flash_attention = True
    config.q_seq_len = 4096

    config.flash_attn_key_chunk_size = 1
    config.flash_attn_query_chunk_size = 1

    train_args = TrainArguments(
        model_class=EasyDel.modules.FlaxLlamaForCausalLM,
        configs_to_init_model_class={
            'config': config,
            'dtype': get_dtype(FLAGS.dtype),
            'param_dtype': get_dtype(FLAGS.dtype)
        },
        custom_rule=config.get_partition_rules(True),
        model_name=FLAGS.project_name,
        num_train_epochs=FLAGS.num_train_epochs,
        learning_rate=FLAGS.learning_rate,
        learning_rate_end=FLAGS.learning_rate_end,
        optimizer=FLAGS.optimizer,
        scheduler=FLAGS.scheduler,
        weight_decay=0.01,
        total_batch_size=1,
        gradient_accumulation_steps=32,
        max_steps=FLAGS.max_steps,
        do_train=FLAGS.do_train,
        do_eval=FLAGS.do_eval,
        do_test=FLAGS.do_test,
        backend=FLAGS.backend,
        max_length=FLAGS.max_sequence_length,
        gradient_checkpointing='nothing_saveable',
        sharding_array=(1, -1, 1, 1),
        use_pjit_attention_force=False,
        remove_ckpt_after_load=FLAGS.remove_ckpt_after_load,
    )

    trainer = CausalLanguageModelTrainer(train_args,
                                         dataset_train=dataset_train,
                                         dataset_eval=dataset_train['eval'] if FLAGS.do_eval else None,
                                         checkpoint_path=FLAGS.checkpoint_path)
    output = trainer.train(
        model_parameters=flax.core.FrozenDict({'params': params})
    )
    # Done. You can train any Llama LLM you want in less than 50 lines of code.

if __name__ == "__main__":
    app.run(main)

/root
/usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
table = cls._concat_blocks(blocks, axis=0)
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:14<00:00, 7.20s/it]
Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parameters in train function
wandb: Tracking run with wandb version 0.16.2
wandb: W&B syncing is set to offline in this directory.
wandb: Run wandb online or set WANDB_MODE=online to enable cloud syncing.
Time For configure dataloaders (ms) : 0.2694129943847656
I0119 04:13:23.447986 132476669524864 mesh_utils.py:71] Reordering mesh to physical ring order on single-tray TPU v2/v3.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/train.py", line 237, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/train.py", line 226, in main
    trainer = CausalLanguageModelTrainer(train_args,
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 243, in __init__
    self.init_functions()
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 300, in init_functions
    self.model, self.tx, self.scheduler, self.config = self.configure_model()
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 397, in configure_model
    model = self.arguments.model_class(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 657, in __init__
    super().__init__(config, module, input_shape=input_shape,
  File "/usr/local/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 223, in __init__
    params_shape_tree = jax.eval_shape(init_fn, self.key)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 692, in init_weights
    module_init_outputs = self.module.init(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1124, in __call__
    outputs = self.model(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1021, in __call__
    outputs = self.layers(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 911, in __call__
    layer_outputs = block(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 575, in __call__
    attn_outputs = self.self_attn(
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 348, in __call__
    attn_output = smart_flash_attention(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/flax_modelling_utils.py", line 455, in smart_flash_attention
    attn_output = fjformer.attention.jax_flash_attn_tpu.flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 198, in flash_attention
    return _flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 216, in _flash_attention
    return _flash_attention_impl(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 590, in _flash_attention_impl
    _verify_block("block_q", "q_seq_len", block_q, q_seq_len, should_divide=False)
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 1689, in _verify_block
    raise ValueError(
ValueError: block_q=1024 should be smaller or equal to q_seq_len=1
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /root/wandb/offline-run-20240119_041323-e4ftdqry
wandb: Find logs at: ./wandb/offline-run-20240119_041323-e4ftdqry/logs

Then I tried to set config.flash_attn_key_chunk_size = 1 and config.flash_attn_query_chunk_size = 1; another error occurred: TypeError: fori_loop() got an unexpected keyword argument 'unroll'

/root
/usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
table = cls._concat_blocks(blocks, axis=0)
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:14<00:00, 7.42s/it]
Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parameters in train function
wandb: Tracking run with wandb version 0.16.2
wandb: W&B syncing is set to offline in this directory.
wandb: Run wandb online or set WANDB_MODE=online to enable cloud syncing.
Time For configure dataloaders (ms) : 0.25153160095214844
I0119 04:18:41.780579 138762858945408 mesh_utils.py:71] Reordering mesh to physical ring order on single-tray TPU v2/v3.
(the mesh-reordering message above is repeated ~30 more times)
Time For configure Model ,Optimizer ,Scheduler and Config (ms) : 1676.0611534118652
I0119 04:18:43.363086 138762858945408 mesh_utils.py:71] Reordering mesh to physical ring order on single-tray TPU v2/v3.
(the mesh-reordering message above is repeated ~30 more times)
Time For configure functions and sharding them (ms) : 2244.947671890259
Action : Sharding Passed Parameters
Model Contain 6.929256448 Billion Parameters
0%| | 0/7220 [00:00<?, ?it/s]I0119 04:20:07.023622 138762858945408 mesh_utils.py:71] Reordering mesh to physical ring order on single-tray TPU v2/v3.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/train.py", line 223, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/train.py", line 216, in main
    output = trainer.train(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 708, in train
    sharded_state, loss, accuracy = self.sharded_train_step_fn(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 89, in casual_language_model_train_step
    (loss__, accuracy__), grad = grad_fn(state.params)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 80, in calculate_loss
    logits = state.apply_fn(params=params, **batch,
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 809, in __call__
    outputs = self.module.apply(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1124, in __call__
    outputs = self.model(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1021, in __call__
    outputs = self.layers(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 911, in __call__
    layer_outputs = block(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 575, in __call__
    attn_outputs = self.self_attn(
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 348, in __call__
    attn_output = smart_flash_attention(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/flax_modelling_utils.py", line 455, in smart_flash_attention
    attn_output = fjformer.attention.jax_flash_attn_tpu.flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 198, in flash_attention
    return _flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 216, in _flash_attention
    return _flash_attention_impl(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 746, in _flash_attention_impl
    o, *aux = pl.pallas_call(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 379, in wrapped
    grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 338, in _trace_to_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals,
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 337, in _flash_attention_kernel
    kernel((batch_idx, 0), q_tile_ref, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 384, in _flash_attention_kernel_single_batch
    def run():
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/utils.py", line 29, in _wrapped
    f()
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 388, in run
    def body(i, _):
TypeError: fori_loop() got an unexpected keyword argument 'unroll'
wandb:
wandb: Run history:
wandb: Number of Model Parameters (Billion) โ–
wandb:
wandb: Run summary:
wandb: Number of Model Parameters (Billion) 6.92926
wandb:
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /root/wandb/offline-run-20240119_041841-decs7x5r
wandb: Find logs at: ./wandb/offline-run-20240119_041841-decs7x5r/logs

Could you please guide me on how to use flash attention on TPU when training a model with EasyDeL?
Thank you so much!
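One observation, offered as a guess rather than a confirmed fix: the q_seq_len=1 in the error matches the dummy input used while the model is being initialized (the first failure happens inside jax.eval_shape during model construction), not the 4096-token training batch. So the flash-attention chunk sizes have to be compatible with whatever sequence length actually reaches the kernel. A hedged sketch, reusing only names that already appear in the script above, plus an 'input_shape' key of the kind used in other issues on this page (whether the trainer forwards it this way is an assumption):

# Hedged sketch, not a confirmed fix. Idea: (a) pick chunk sizes that divide
# the real sequence length, and (b) initialize the model with a full-length
# dummy input so the kernel never sees q_seq_len=1.
config.use_flash_attention = True
config.flash_attn_query_chunk_size = 512  # should divide max_sequence_length
config.flash_attn_key_chunk_size = 512

configs_to_init = {
    'config': config,
    'dtype': get_dtype(FLAGS.dtype),
    'param_dtype': get_dtype(FLAGS.dtype),
    # assumption: a full-length input_shape at init time
    'input_shape': (1, FLAGS.max_sequence_length),
}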

Exception while running any model - einops.EinopsError: Error while processing rearrange-reduction pattern "b (c n) d -> b c n d".

Describe the bug

python -m examples.serving.causal-lm.llama-2-chat   --pretrained_model_name_or_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0" --max_sequence_length=1024   --max_new_tokens=256  --max_compile_tokens=32 --temperature=0.6   --top_p=0.95 --top_k=50 --dtype="fp16"

   layer_outputs = block(
  File "/home/neo/research/easydel/.venv/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 525, in __call__
    feed_forward_hidden_states = block_wise_ffn(
  File "/home/neo/research/easydel/.venv/lib/python3.10/site-packages/EasyDel/modules/flax_modelling_utils.py", line 432, in block_wise_ffn
    inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
  File "/home/neo/research/easydel/.venv/lib/python3.10/site-packages/einops/einops.py", line 483, in rearrange
    return reduce(cast(Tensor, tensor), pattern, reduction='rearrange', **axes_lengths)
  File "/home/neo/research/easydel/.venv/lib/python3.10/site-packages/einops/einops.py", line 420, in reduce

I have installed the latest version of EasyDeL from the GitHub main branch.
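For context, the rearrange pattern in block_wise_ffn can only split the sequence axis when its length factors exactly into the chunk layout; a sequence length of 1 (for example, a single-token decode step) reproduces the same EinopsError in isolation:

import numpy as np
from einops import rearrange

# 1024 tokens split into a matching chunk layout works:
x_ok = np.zeros((1, 1024, 8))
print(rearrange(x_ok, 'b (c n) d -> b c n d', c=256).shape)  # (1, 256, 4, 8)

# ...but a length-1 sequence axis cannot be factored, which is the shape a
# single-token decoding step produces:
x_bad = np.zeros((1, 1, 8))
try:
    rearrange(x_bad, 'b (c n) d -> b c n d', c=256)
except Exception as e:
    print(type(e).__name__, e)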

Mixtral 8x7B support?

Hi Erfan, I tried testing the Mixtral module and it failed, returning "Returning Empty Params". Is this module working?

Issue with configs when running serve

I'm very interested in getting this working. I am trying to get Mistral running on a TPU VM at GCP.

Using your example command:
python3 -m examples.serving.causal-lm.llama-2-chat --repo_id='mistralai/Mistral-7B-v0.1' --max_length=4096 --max_new_tokens=2048 --max_stream_tokens=32 --temperature=0.6 --top_p=0.95 --top_k=50 --dtype='fp16' --use_prefix_tokenizer

I get this error after it downloads the model and loads the checkpoint shards:
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/app/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 210, in <module>
    server = Llama2Host.load_from_torch(
  File "/app/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 54, in load_from_torch
    return cls.load_from_params(
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/serve/jax_serve.py", line 328, in load_from_params
    server = cls(config=config)
  File "/app/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 38, in __init__
    super().__init__(config=config)
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/serve/jax_serve.py", line 109, in __init__
    assert config is None or isinstance(config, JaxServerConfig), 'config can be None or JaxServerConfig Type'
AssertionError: config can be None or JaxServerConfig Type

Use custom tokenizer

I want to add a token to the tokenizer of my Mistral model and train it. What settings should I change? Only the embedding size?
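This isn't EasyDeL-specific, but two things have to stay in sync: the tokenizer vocabulary and the embedding matrix (plus the LM head if it is not tied to the embeddings), together with config.vocab_size. A rough sketch in plain JAX; the parameter path in the comment is illustrative and depends on the model implementation:

import jax
import jax.numpy as jnp
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
num_added = tokenizer.add_tokens(["<|my_token|>"])  # vocab grows by num_added

def grow_embedding(embedding: jax.Array, num_new: int, rng) -> jax.Array:
    """Append freshly initialized rows for the new tokens."""
    new_rows = jax.random.normal(rng, (num_new, embedding.shape[1])) * 0.02
    return jnp.concatenate([embedding, new_rows.astype(embedding.dtype)], axis=0)

# Apply to the embedding table, e.g. something like
# params['params']['model']['embed_tokens']['embedding'] for Llama/Mistral-style
# trees (exact path depends on the implementation), then set
# config.vocab_size = len(tokenizer).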

Error while finetuning Tinyllama on Kaggle TPU

Describe the bug
An error while training tiny llama on kaggle

/root
/usr/local/lib/python3.10/site-packages/pydantic/_internal/_fields.py:149: UserWarning: Field "model_name" has conflict with protected namespace "model_".

You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ()`.
  warnings.warn(
Information :  track_memory is set to False by default inorder make make training faster. you can turn it on with just passing `track_memory=True` in TrainArguments
Downloading data files: 100%|███████████████████| 1/1 [00:00<00:00, 7319.90it/s]
Extracting data files: 100%|████████████████████| 1/1 [00:00<00:00, 1059.70it/s]
Generating train split: 176 examples [00:00, 37836.88 examples/s]
Map (num_proc=12): 100%|██████████████| 176/176 [00:00<00:00, 416.75 examples/s]
/usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
  table = cls._concat_blocks(blocks, axis=0)
Map (num_proc=12):   0%|                         | 0/176 [00:00<?, ? examples/s]/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
(the FutureWarning above is repeated once per worker process)
Map (num_proc=12): 100%|██████████████| 176/176 [00:00<00:00, 437.34 examples/s]
Warning : In case of using `finetune = True` and Passing `checkpoint_path = None` you should pass parameters in train function
wandb: Currently logged in as: jchauhan (safedep). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.3
wandb: Run data is saved locally in /root/wandb/run-20240216_181547-lgbd2mo5
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run twinkling-fish-7
wandb: ⭐️ View project at https://wandb.ai/safedep/EasyDeL-my_first_model_to_train_using_easydel
wandb: 🚀 View run at https://wandb.ai/safedep/EasyDeL-my_first_model_to_train_using_easydel/runs/lgbd2mo5
Time Took to Complete Task configure dataloaders (microseconds) : 0.4432201385498047
Time Took to Complete Task configure Model, Optimizer, Scheduler and Config (microseconds) : 1331.115484237671
Time Took to Complete Task configure functions and sharding them (microseconds) : 1449.2170810699463
Action : Sharding Passed Parameters
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/train.py", line 140, in <module>
    output = trainer.train(flax.core.FrozenDict({"params": params}))
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 488, in train
    sharded_state, shard_fns, gather_fns = self.initialize_state(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 402, in initialize_state
    model_parameters = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 403, in <lambda>
    lambda f, x: f(x),
  File "/usr/local/lib/python3.10/site-packages/fjformer/partition_utils/mesh_utils.py", line 50, in shard_fn
    return jax_shard_function(tensor).block_until_ready()
ValueError: Memory kinds passed to jax.jit does not match memory kind on the respective arg. Got pjit memory kind: tpu_hbm, arg memory kind: None for arg shape: float32[2048,32000]

To Reproduce

huggingface_repo_id_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path, )

max_length = 1024
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_init_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="my_first_model_to_train_using_easydel",
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_init_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear","cosine", "none" ,"warm_up_cosine" and "warm_up_linear"  are supported
    weight_decay=0.01,
    total_batch_size=1,
    max_steps=10000,  # None to let the trainer decide
    do_train=True,
    do_eval=False,  # optional but supported
    backend="tpu",  # the default backend is cpu, so you must specify tpu, cpu, or gpu
    max_length=max_length,  # Note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # how to shard the model across CPUs, GPUs, or TPUs; with (1, -1, 1, 1)
    # training is fully automatic FSDP, sharing data between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)

Text Generation with Mixtral fails

Throws this error when generate is called
TypeError: dynamic_update_slice update shape must be smaller than operand shape, got update shape (1, 4096) for operand shape (1, 20).

To Reproduce

import copy
import jax
from EasyDel import AutoEasyDelModelForCausalLM, AutoEasyDelConfig, get_modules_by_type
from transformers import AutoTokenizer
from transformers import GenerationConfig
from transformers import MixtralForCausalLM
from EasyDel import MixtralConfig, FlaxMixtralForCausalLM
from EasyDel.transform.easydel_transform import huggingface_to_easydel

pretrained_model_name_or_path = "/LLMs/Mixtral-8x7B-v0.1/"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model, params = AutoEasyDelModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        dtype=jax.numpy.bfloat16,
        param_dtype=jax.numpy.bfloat16,
        precision=jax.lax.Precision("fastest"),
        device=jax.devices('cpu')[0]
)

seq_len = 128
config = MixtralConfig(
        hidden_size=256,
        num_attention_heads=8,
        num_hidden_layers=1,
        num_key_value_heads=4,
        intermediate_size=512,
        num_local_experts=8,
        max_position_embeddings=seq_len
)

torch_model = MixtralForCausalLM(
        config=copy.deepcopy(config)
)

params = {"params":
        huggingface_to_easydel(
            torch_model.state_dict(),
            embedding_layer_names=["embed_tokens"],
            device=jax.devices("cpu")[0]
        )
}
    
tokenizer.pad_token = tokenizer.eos_token
tokens = tokenizer("Can you tell me who is the current president of the united states?", max_length=4096, padding='max_length', return_tensors='jax')
input_ids, attention_mask = tokens.input_ids, tokens.attention_mask 
predict = model.generate(
    input_ids,
    attention_mask=attention_mask,
    params=params
)
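A hedged reading of the error: the prompt is padded to max_length=4096 while generate's output buffer is only 20 positions long (the (1, 20) operand, matching Transformers' default generation length), so the update cannot fit. A sketch that pads less aggressively and states the generation budget explicitly; treating `max_position_embeddings` on the loaded config as the real limit is an assumption:

# Sketch: keep the padded prompt within the model's limit and give generate
# an explicit budget larger than the prompt.
max_len = model.config.max_position_embeddings  # assumed to carry the real limit
tokens = tokenizer(
    "Can you tell me who is the current president of the united states?",
    max_length=seq_len,
    padding="max_length",
    truncation=True,
    return_tensors="jax",
)
predict = model.generate(
    tokens.input_ids,
    attention_mask=tokens.attention_mask,
    params=params,
    max_length=max_len,  # output buffer must cover prompt + new tokens
)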

Example shown on https://pypi.org/project/EasyDeL/ to finetune TinyLlama raises an exception on Kaggle

Describe the bug

Action : Sharding Passed Parameters
Model Contain 1.100048384 Billion Parameters
  0%|          | 0/1500 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 101
     93 # you can do the same for evaluation process dataset
     95 trainer = CausalLanguageModelTrainer(
     96     train_arguments,
     97     dataset_train,
     98     checkpoint_path=None
     99 )
--> 101 output = trainer.train(flax.core.FrozenDict({"params": params}))
    102 print(f"Hey ! , here's where your model saved {output.checkpoint_path}")

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:509, in CausalLanguageModelTrainer.train(self, model_parameters, state)
    507 try:
    508     for epoch in range(self.arguments.num_train_epochs):
--> 509         for batch in self.dataloader_train:
    510             current_step += 1
    511             if (
    512                     self.arguments.step_start_point is not None
    513                     and
    514                     self.arguments.step_start_point > current_step
    515             ):

File /usr/local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
    627 if self._sampler_iter is None:
    628     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    629     self._reset()  # type: ignore[call-arg]
--> 630 data = self._next_data()
    631 self._num_yielded += 1
    632 if self._dataset_kind == _DatasetKind.Iterable and \
    633         self._IterableDataset_len_called is not None and \
    634         self._num_yielded > self._IterableDataset_len_called:

File /usr/local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:674, in _SingleProcessDataLoaderIter._next_data(self)
    672 def _next_data(self):
    673     index = self._next_index()  # may raise StopIteration
--> 674     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    675     if self._pin_memory:
    676         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File /usr/local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:54, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     52 else:
     53     data = self.dataset[possibly_batched_index]
---> 54 return self.collate_fn(data)

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:179, in CausalLanguageModelTrainer.create_collate_function.<locals>.collate_fn(batch)
    175     else:
    176         corrected_sequence = [
    177             jnp.array(f[key])[..., :max_sequence_length] for f in batch
    178         ]
--> 179     results[key] = jnp.stack(corrected_sequence).reshape(
    180         -1,
    181         corrected_sequence[0].shape[-1]
    182     )
    183 return results

File /usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:1796, in stack(arrays, axis, out, dtype)
   1794 for a in arrays:
   1795   if shape(a) != shape0:
-> 1796     raise ValueError("All input arrays must have the same shape.")
   1797   new_arrays.append(expand_dims(a, axis))
   1798 return concatenate(new_arrays, axis=axis, dtype=dtype)

ValueError: All input arrays must have the same shape

To Reproduce
Install dependencies

# !pip install git+https://github.com/erfanzar/EasyDeL.git
!pip install EasyDeL==0.0.50
!pip install sentencepiece
!pip install jaxlib==0.4.19
!pip install jax[tpu]==0.4.19 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 
!apt-get update && apt-get upgrade -y
!apt-get install golang -y 

Run the example on Kaggle using TPUs
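A hedged guess at the cause: padding="max_length" pads short sequences up to max_length but leaves longer ones untouched, so the collate function ends up stacking arrays of unequal shape. Adding truncation, as in this sketch (tokenizer and max_length are the PyPI example's own variables), forces a uniform length:

# Sketch: force every sequence to exactly max_length tokens so jnp.stack
# in the collate_fn sees uniform shapes.
tokenization_process = lambda data_chunk: tokenizer(
    data_chunk["prompt"],
    add_special_tokens=False,
    max_length=max_length,
    padding="max_length",
    truncation=True,
)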

Step time increasing as training progresses

In one of the longer training runs now running on a tpu-v3-8, I noticed the training ETA kept getting later and later.
Also, in the step-time wandb log (picture below), the higher the step number, the longer the step time.

[wandb chart: step time growing with step number]

Any ideas what could be the cause? I looked a bit at DataLoader's prefetch_factor, but it's only available when using multiprocessing (num_workers > 0).

PS: Thanks for creating EasyDel - it's amazing what you've created!
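One thing worth trying, as a sketch rather than a diagnosis: turning on DataLoader worker processes, which is what makes prefetch_factor take effect. The trainer builds its own dataloader internally, so this only illustrates the PyTorch API the question refers to; `dataset_train` and the batch size are placeholders.

from torch.utils.data import DataLoader

# Sketch: prefetch_factor only applies when num_workers > 0, so background
# workers must be enabled for prefetching to happen at all.
loader = DataLoader(
    dataset_train,      # placeholder: the tokenized training dataset
    batch_size=1,
    num_workers=2,
    prefetch_factor=2,
)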

ValueError: Dict key mismatch; expected keys: ['transformer'];

Getting the following error while finetuning the gpt2 model.

To Reproduce

  • Use the example training code and change the base model to gpt2
  • Use data samples to train the model
Time Took to Complete Task configure dataloaders (microseconds) : 0.33593177795410156
Time Took to Complete Task configure Model ,Optimizer ,Scheduler and Config (microseconds) : 677.0551204681396
Time Took to Complete Task configure functions and sharding them (microseconds) : 790.2214527130127
Action : Sharding Passed Parameters
Traceback (most recent call last):
  File "/home/xxx/research/transformers/train_gpt_easydel.py", line 92, in <module>
    output = trainer.train(flax.core.FrozenDict({"params": params}))
  File "/home/xxx/research/EasyDeL/lib/python/EasyDel/trainer/causal_language_model_trainer.py", line 669, in train
    sharded_state, shard_fns, gather_fns = self.init_state(
  File "/home/xxx/research/EasyDeL/lib/python/EasyDel/trainer/causal_language_model_trainer.py", line 585, in init_state
    params = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
  File "/home/xxx/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/tree_util.py", line 243, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  File "/home/xxx/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/tree_util.py", line 243, in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Dict key mismatch; expected keys: ['transformer']; dict: {'transformer': {'wte': {'embedding': array([[-0.11010301, -0.03926672,  0.03310751, ..., -0.1363697 ,
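For context, a minimal illustration (not EasyDeL code) of this error class: jax.tree_util.tree_map flattens the first tree and requires every other tree to match its dict structure, so an extra nesting level raises exactly this kind of ValueError.

import jax

# The shard functions here form a tree keyed by 'transformer'; wrapping the
# parameters in an extra level (e.g. a top-level 'params' key the shard
# functions were not built for) makes the structures diverge and tree_map
# raises "ValueError: Dict key mismatch".
shard_fns = {"transformer": lambda x: x}
params_tree = {"params": {"transformer": 1.0}}
jax.tree_util.tree_map(lambda f, x: f(x), shard_fns, params_tree)  # raises ValueError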


Example Code

import jax.numpy
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

huggingface_repo_id_or_path = "gpt2"
max_length = 512
trained_model_name = "changeme"
easydel_trained_model_name = f"{trained_model_name}.easydel"
training_data_files="changeme.json"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path, )

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_init_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

train_arguments = TrainArguments(
    model_class=type(model),
    model_name=easydel_trained_model_name,
    num_train_epochs=3,
    configs_to_init_model_class=configs_to_init_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear","cosine", "none" ,"warm_up_cosine" and "warm_up_linear"  are supported
    weight_decay=0.01,
    total_batch_size=1,
    max_steps=None,  # None lets the trainer decide
    do_train=True,
    do_eval=False,  # it's optional but supported 
    backend="tpu",  # default backed is set to cpu, so you must define you want to use tpu cpu or gpu
    max_length=max_length,  # Note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # how to shard the model across GPUs, CPUs, or TPUs;
    # (1, -1, 1, 1) shards everything over the fsdp axis, so training is fully automatic FSDP with data shared between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)


def ultra_chat_prompting_process(
        data_chunk
):
    return {"prompt": data_chunk['train']}


tokenization_process = lambda data_chunk: tokenizer(
    data_chunk["prompt"],
    add_special_tokens=False,
    max_length=max_length,
    padding="max_length"
)

dataset = load_dataset("json", data_files=training_data_files)
dataset_train = dataset["train"].map(ultra_chat_prompting_process, num_proc=12)
dataset_train = dataset_train.map(
    tokenization_process,
    num_proc=12,
    remove_columns=dataset_train.column_names
)

# you can do the same for evaluation process dataset

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")

RESOURCE_EXHAUSTED: XLA:TPU compile permanent error

To Reproduce
Use a TPU v2-8 VM

Time For configure functions and sharding them (ms) : 2012.7429962158203
Action : Sharding Passed Parameters
Model Contain  1.100048384  Billion Parameters
  0%|                                                                                                | 0/12000 [00:00<?, ?it/s]
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/***/research/EasyDeL/1.py", line 101, in <module>
    output = trainer.train(flax.core.FrozenDict({"params": params}))
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 708, in train
    sharded_state, loss, accuracy = self.sharded_train_step_fn(
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 21.90G of 7.48G hbm. Exceeded hbm capacity by 14.42G.
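A back-of-the-envelope sketch of why this is expected on a v2-8 (roughly 8 GB of HBM per core; all numbers below are rough assumptions): a 1.1B-parameter model trained with AdamW carries far more than the weights alone.

# Rough memory estimate for full AdamW training of a 1.1B-parameter model.
n_params = 1.1e9
bytes_weights = n_params * 2        # bf16 parameters
bytes_grads = n_params * 2          # bf16 gradients
bytes_adamw = n_params * 4 * 2      # two fp32 moment buffers
total_gb = (bytes_weights + bytes_grads + bytes_adamw) / 1e9
print(f"~{total_gb:.1f} GB before activations")  # ~13.2 GB, already past 7.48 GB

Activations and compilation scratch push that higher still, consistent with the 21.90G the compiler reports; stronger sharding, a larger TPU, or a smaller model are the usual ways out.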


Output of TinyLlama using EasyDeL vs the Hugging Face Transformers API differs


To Reproduce

Output from the Hugging Face Transformers API on a local env

<|system|>
You are an oracle who knows the anwers of everything. There should not any uncertainity in your answers</s>
<|user|>
how many stars there in the universe?
</s>
<|assistant|>

<|system|>
You are an oracle who knows the anwers of everything. There should not any uncertainity in your answers</s>
<|user|>
how many stars there in the universe?
</s>
<|assistant|>
The number of stars in the universe is estimated to be in the millions or billions. However, the precise number is still being studied and calculated. It is believed that the current number of stars in the universe is at least 100 billion, with a range of 10 to 100 trillion. The number of stars and galaxies is increasing with time, indicating that the universe is expanding.

Output from TinyLlama using EasyDeL

Command used to serve TinyLlama:

run examples/serving/causal-lm/tinyllama-2-chat.py --pretrained_model_name_or_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0" --max_length=2048 --max_new_tokens=2048 --max_compile_tokens=32 --temperature=0.6 --top_p=0.95 --top_k=50 --dtype="fp16" --use_prefix_tokenizer --mesh_axes_shape 1 -1 1 1
llama
/home/xxx/research/EasyDeL/.venv/lib/python3.10/site-packages/gradio-4.12.0-py3.10.egg/gradio/components/base.py:182: UserWarning: show_label has no effect when container is False.
  warnings.warn("show_label has no effect when container is False.")
Sharding Params: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 201/201 [00:25<00:00,  7.74it/s]
Compiling Model Forwards Greedy/Non-Greedy(Generate)
Compiling Greedy Functions
Compiling Non-Greedy(Generate) Functions
Launching App ...
/home/xxx/research/EasyDeL/.venv/lib/python3.10/site-packages/gradio-4.12.0-py3.10.egg/gradio/components/base.py:182: UserWarning: show_label has no effect when container is False.
  warnings.warn("show_label has no effect when container is False.")
Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://d74594468abe170a71.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
<IPython.core.display.HTML object>
Launching Server APIS (Fire) ...
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

INFO:     Started server process [975861]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:2059 (Press CTRL+C to quit)

Output

[screenshot: the EasyDeL server's generated response]

the prompt was

<|system|>
You are an oracle who knows the answers of everything. There should not any uncertainity in your answers</s>
<|user|>
how many stars there in the universe?
</s>
<|assistant|>
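A hedged observation rather than a fix: with temperature/top-p sampling enabled and fp16 weights on the EasyDeL side, exact agreement with the local Transformers run is not expected. Greedy decoding on both stacks makes the comparison meaningful; this sketch uses the standard Transformers GenerationConfig:

from transformers import GenerationConfig

# Sketch: disable sampling so both stacks decode deterministically before
# comparing outputs token by token.
generation_config = GenerationConfig(do_sample=False, max_new_tokens=128)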

Model loading

I received the file Mistral-1.566301941871643-69 at the end of the model's training, and I was wondering if there is a way to convert this save file to model.bin, or to load it on a TPU to see if it works.

Thank you for the support!

Will this project support LoRA?

Thanks for giving me an easy way to finetune models on TPU. I tried finetuning a full model using this project on Kaggle's TPU and it works.

I'm wondering whether there is a schedule for supporting LoRA or QLoRA finetuning? It's a more practical approach for those of us who only have access to limited computing resources...

Error while training a Phi2 model

To Reproduce

Time Took to Complete Task configure dataloaders (microseconds) : 0.3025531768798828
Time Took to Complete Task configure Model ,Optimizer ,Scheduler and Config (microseconds) : 2130.923271179199
Traceback (most recent call last):
  File "/home/xxx/research/transformers/train_phi2_easydel.py", line 86, in <module>
    trainer = CausalLanguageModelTrainer(
  File "/home/xxx/research/EasyDeL/lib/python/EasyDel/trainer/causal_language_model_trainer.py", line 244, in __init__
    self.init_functions()
  File "/home/xxx/research/EasyDeL/lib/python/EasyDel/trainer/causal_language_model_trainer.py", line 310, in init_functions
    funcs = self.configure_functions()
  File "/home/xxx/research/EasyDeL/lib/python/EasyDel/trainer/causal_language_model_trainer.py", line 477, in configure_functions
    create_sharded_state_from_params_fn = pjit(
  File "/home/xxx/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 762, in pjit
    static_argnames) = pre_infer_params(
  File "/home/xxx/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 309, in pre_infer_params
    in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings')
  File "/home/xxx/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 1093, in prepare_axis_resources
    new_entries.append(ParsedPartitionSpec.from_user_input(
  File "/home/xxx/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 985, in from_user_input
    raise TypeError(f"{arg_name} are expected to be "
TypeError: in_shardings leaf specifications are expected to be PartitionSpec instances or None, but got *
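For context, a sketch of what recent jax versions accept here (this is not the EasyDeL fix itself; pinning jax to the version EasyDeL was built against may be the practical remedy). The legacy '*' leaf is rejected; every in_shardings leaf must be a PartitionSpec instance or None:

import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec

# Sketch: a pjit call whose sharding leaves are PartitionSpec instances,
# the form recent jax versions require.
mesh = Mesh(np.array(jax.devices()).reshape(-1), ("fsdp",))
with mesh:
    f = pjit(
        lambda x: x + 1.0,
        in_shardings=PartitionSpec("fsdp"),
        out_shardings=PartitionSpec("fsdp"),
    )
    y = f(np.arange(jax.device_count() * 2, dtype=np.float32))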

Example Code

import jax.numpy
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

huggingface_repo_id_or_path = "microsoft/phi-2"
max_length = 512
trained_model_name = "changeme"
easydel_trained_model_name = f"{trained_model_name}.easydel"
training_data_files="changeme.json"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path, trust_remote_code=True)

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_init_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

train_arguments = TrainArguments(
    model_class=type(model),
    model_name=easydel_trained_model_name,
    num_train_epochs=3,
    configs_to_init_model_class=configs_to_init_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear","cosine", "none" ,"warm_up_cosine" and "warm_up_linear"  are supported
    weight_decay=0.01,
    total_batch_size=1,
    max_steps=None,  # None lets the trainer decide
    do_train=True,
    do_eval=False,  # it's optional but supported 
    backend="tpu",  # default backed is set to cpu, so you must define you want to use tpu cpu or gpu
    max_length=max_length,  # Note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # how to shard the model across GPUs, CPUs, or TPUs;
    # (1, -1, 1, 1) shards everything over the fsdp axis, so training is fully automatic FSDP with data shared between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)


def ultra_chat_prompting_process(
        data_chunk
):
    return {"prompt": data_chunk['train']}


tokenization_process = lambda data_chunk: tokenizer(
    data_chunk["prompt"],
    add_special_tokens=False,
    max_length=max_length,
    padding="max_length"
)

dataset = load_dataset("json", data_files=training_data_files)
dataset_train = dataset["train"].map(ultra_chat_prompting_process, num_proc=12)
dataset_train = dataset_train.map(
    tokenization_process,
    num_proc=12,
    remove_columns=dataset_train.column_names
)

# you can do the same for evaluation process dataset

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")

Dependency Error on latest version

Describe the bug
pip install easydel
/usr/local/lib/python3.8/dist-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.1.36ubuntu1 is an invalid version and will not be supported in a future release
warnings.warn(
/usr/local/lib/python3.8/dist-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.23ubuntu1 is an invalid version and will not be supported in a future release
warnings.warn(
Collecting easydel
Downloading EasyDeL-0.0.41-py3-none-any.whl (194 kB)
|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 194 kB 14.1 MB/s
ERROR: Could not find a version that satisfies the requirement chex~=0.1.84 (from easydel) (from versions: 0.0.1, 0.0.2, 0.0.3, 0.0.4, 0.0.5, 0.0.6, 0.0.7, 0.0.8, 0.0.9, 0.1.0, 0.1.1, 0.1.2, 0.1.3, 0.1.4, 0.1.5, 0.1.6, 0.1.7)
ERROR: No matching distribution found for chex~=0.1.84 (from easydel)
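A hedged explanation, not confirmed by the log itself: the paths show Python 3.8, and chex 0.1.84 is only published for newer interpreters, which would be why the resolver stops offering versions at 0.1.7. A quick check before installing:

import sys

# Hedged assumption: EasyDeL's chex~=0.1.84 pin needs Python >= 3.9, so an
# interpreter check fails fast instead of a confusing resolver error.
assert sys.version_info >= (3, 9), "upgrade Python before installing EasyDeL"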

The function llama_from_pretrained expects a device param that is missing when calling it

Describe the bug
It seems the code is broken at the main branch and also at tag 0.0.42: llama_from_pretrained expects a device param that the example never passes.

To Reproduce

  1. Check out the latest code
  2. Install and set up
  3. Run it
python -m examples.serving.causal-lm.llama-2-chat   --pretrained_model_name_or_path="mediocredev/open-llama-3b-v2-chat" --max_length=4096   --max_new_tokens=2048 --max_compile_tokens=32 --temperature=0.6   --top_p=0.95 --top_k=50   --dtype="fp32" --use_prefix_tokenizer
Traceback (most recent call last):
  File "/home/**/.pyenv/versions/3.10.0/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/neo/.pyenv/versions/3.10.0/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/***/research/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 193, in <module>
    server = Llama2Host.load_from_torch(
  File "/home/***/research/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 53, in load_from_torch
    param, config_model = llama_from_pretrained(
TypeError: llama_from_pretrained() missing 1 required positional argument: 'device'
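A sketch of the call with the new required argument supplied, mirroring the device= convention used elsewhere in these examples (the import path is an assumption):

import jax
from EasyDel.transform import llama_from_pretrained  # assumed import path

# Sketch: pass the now-required device explicitly, offloading to CPU.
param, config_model = llama_from_pretrained(
    "mediocredev/open-llama-3b-v2-chat",
    device=jax.devices("cpu")[0],
)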

Running every model gives the error "Shapes must be 1D sequences of concrete values of integer type"

To Reproduce
Run it on any model and it gives the following error

python -m examples.serving.causal-lm.llama-2-chat   --pretrained_model_name_or_path="mediocredev/open-llama-3b-v2-chat" --max_length=4096   --max_new_tokens=2048 --max_compile_tokens=32 --temperature=0.6   --top_p=0.95 --top_k=50   --dtype="fp32" --use_prefix_tokenizer
Loading checkpoint shards: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 2/2 [00:01<00:00,  1.21it/s]
JAXServerConfig(host='0.0.0.0', port=2059, batch_size=1, contains_auto_format=False, max_length=4096, max_new_tokens=2048, max_compile_tokens=32, temperature=0.6, top_p=0.95, top_k=50, logging=False, mesh_axes_names=['dp', 'fsdp', 'tp', 'sp'], mesh_axes_shape=[(1, -1, 1, 1)], generation_ps=PartitionSpec('dp', 'fsdp'), dtype='fp32', stream_tokens_for_gradio=True, use_prefix_tokenizer=True, pre_compile=True)
Traceback (most recent call last):
  File "/home/***/.pyenv/versions/3.10.0/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/***/.pyenv/versions/3.10.0/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/***/research/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 195, in <module>
    server = Llama2Host.load_from_torch(
  File "/home/***/research/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 66, in load_from_torch
    return cls.load_from_params(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 476, in load_from_params
    server = cls(config=config)
  File "/home/***/research/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 40, in __init__
    super().__init__(config=config)
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 123, in __init__
    array = jnp.ones((len(jax.devices()), 1)).reshape(self.config.mesh_axes_shape)
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 155, in _reshape
    newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 123, in _compute_newshape
    newshape = core.canonicalize_shape(newshape)  # type: ignore[arg-type]
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 2130, in canonicalize_shape
    raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [(1, -1, 1, 1)].
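A minimal illustration of the failure (not the serve-script fix): the config above carries mesh_axes_shape=[(1, -1, 1, 1)], a list wrapping a tuple, and reshape wants a flat sequence of ints.

import jax
import jax.numpy as jnp

x = jnp.ones((jax.device_count(), 1))
x.reshape(1, -1, 1, 1)      # fine: a flat integer shape
x.reshape([(1, -1, 1, 1)])  # TypeError: Shapes must be 1D sequences ...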

Trouble running demo

Hello, thanks for this really cool repository. I've recently been learning about pjit, and your repo is a valuable reference resource.

I was having some issues running the example code. In particular, if I run it I get:

FATAL Flags parsing error:
  flag --dataset_name=None: Flag --dataset_name must have a value other than None.
  flag --ckpt_path=None: Flag --ckpt_path must have a value other than None.
Pass --helpshort or --helpfull to see help on flags

It appears that I need to convert the models according to the instructions at https://github.com/erfanzar/EasyDeL#step-one. I gave it a try and got:

import jax
from EasyDel.transform import falcon_convert_pt_to_flax_7b
from fjutils.utils import save_ckpt
from transformers import AutoModelForCausalLM
number_of_layers = 32  # Falcon-7B has 32 hidden layers
device = jax.devices('cpu')[0]  # offload on CPU
model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True, use_auth_token=True)
pytorch_model_state_dict = model.state_dict()
flax_params = falcon_convert_pt_to_flax_7b(pytorch_model_state_dict, number_of_layers, device)
save_ckpt(flax_params, 'flax_param_easydel')
/admin/home/costa/.cache/pypoetry/virtualenvs/easydel-2Au1A9J8-py3.8/lib/python3.8/site-packages/transformers/configuration_utils.py:483: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.
  warnings.warn(
/admin/home/costa/.cache/pypoetry/virtualenvs/easydel-2Au1A9J8-py3.8/lib/python3.8/site-packages/transformers/modeling_utils.py:2193: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.
  warnings.warn(
Loading checkpoint shards: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 2/2 [00:36<00:00, 18.02s/it]
Traceback (most recent call last):
  File "test.py", line 12, in <module>
    flax_params = falcon_convert_pt_to_flax_7b(pytorch_model_state_dict, number_of_layers, device)
  File "/fsx/costa/EasyDeL/EasyDel/transform/falcon.py", line 29, in falcon_convert_pt_to_flax_7b
    state_dict_flax[('transformer', 'h', f'{i}', 'post_attention_layernorm', 'scale')] = state_dict_pt[
KeyError: 'transformer.h.0.post_attention_layernorm.weight'

Do you have any end-to-end hello_world script/command that just runs? Those would really help :)
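A hedged guess at the KeyError: Falcon-7B's parallel-attention blocks carry only an input_layernorm, so a converter that looks up transformer.h.N.post_attention_layernorm.weight will miss on the 7B checkpoint. A diagnostic line (reusing pytorch_model_state_dict from the snippet above) makes the actual key layout visible:

# Diagnostic sketch: list the layer-norm keys that actually exist in the
# checkpoint to compare against what the converter expects.
print(sorted(k for k in pytorch_model_state_dict if "layernorm" in k))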
