
llm-shearing's Introduction

🦙 Sheared LLaMA: Accelerating Language Model Pre-training via Structured Pruning

🌟 ArXiv Preprint | Blog Post

Base models: Sheared-LLaMA-1.3B | Sheared-LLaMA-2.7B | Sheared-Pythia-160m
Pruned Models without Continued Pre-training: Sheared-LLaMA-1.3B-Pruned, Sheared-LLaMA-2.7B-Pruned
Instruction-tuned models: Sheared-LLaMA-1.3B-ShareGPT | Sheared-LLaMA-2.7B-ShareGPT

Thank you for your interest in our work! This is a joint work by Mengzhou Xia, Tianyu Gao, Zhiyuan Zeng, and Danqi Chen. Here, we provide our codebase for Sheared-LLaMA's pruning and continued pre-training algorithms :) We find that pruning strong base models is an extremely cost-effective way to get strong small-scale language models compared to pre-training them from scratch. The following graph shows that, given the existence of the Llama-2-7B model (pre-trained with 2T tokens), pruning it produces a model as strong as an OpenLLaMA model at 3% of its pre-training cost.

teaser

Update

  • [12/19/2023] Updated the evaluation scripts and pruning logs in the repo.
  • [11/22/2023] We released the instruction-tuned models Sheared-LLaMA-1.3B-ShareGPT and Sheared-LLaMA-2.7B-ShareGPT.
  • [11/19/2023] We released the Sheared-Pythia-160m model developed at early stages. It was produced using the same shearing recipe and the Pile dataset.
  • [11/05/2023] We released the code on LLM-Shearing - excited to see it being applied to more models of different scales.
  • [10/10/2023] We released the Sheared-LLaMA paper, two Sheared-LLaMA models, and tweeted about it 🚀!

🔗 Quick Links

Brief Introduction

This codebase is built on MosaicML's amazing Composer package, which is specially designed and optimized for large language model pre-training. The entire implementation, including the pruning logic and the dynamic batch loading logic, is written as callback functions without touching the vanilla Composer trainer. Here's a concise overview of each folder within the codebase:

  • shearing.data: Contains sample data and scripts for data processing.
  • shearing.datasets: Implements customized datasets to enable dynamic data loading.
  • shearing.callbacks: Implements dynamic loading callbacks and pruning callbacks.
  • shearing.models: Implements the model files.
  • shearing.scripts: Contains scripts for running the code.
  • shearing.utils: Includes all utility functions, such as model conversion and pruning tests.
  • train.py: The main entry point for running the code.

Install Requirements

Step 1: To get started with this repository, follow these installation steps. Before proceeding, make sure you have PyTorch and Flash Attention installed. You can install both via pip using the following commands:

pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
pip install flash-attn==1.0.3.post

Please note that Flash Attention version 2 is not currently supported and may require manual modifications to the model file.

Step 2: Then install the rest of the required packages:

cd llmshearing
pip install -r requirement.txt

Step 3: Finally, install the llmshearing package in editable mode to make it accessible for your development environment:

pip install -e .
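
If everything installed correctly, a quick sanity check like the following should run without errors. This is a minimal sketch; it only assumes the torch and flash-attn packages installed above.

# Minimal environment sanity check for the packages installed above.
import torch
import flash_attn

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("flash-attn:", getattr(flash_attn, "__version__", "unknown"))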

Data Preparation

Please refer to llmshearing/data for details on how to prepare data with MosaicML's Streaming package.
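
For orientation, writing tokenized samples into the MDS format that Streaming expects looks roughly like the sketch below. This is not the repo's actual processing script (that lives in llmshearing/data); the column name, output directory, and sequence length are illustrative placeholders.

# Sketch: write pre-tokenized samples into MDS shards with MosaicML Streaming.
# The column name ("tokens"), output directory, and sequence length are
# illustrative placeholders, not the repo's exact settings.
import numpy as np
from streaming import MDSWriter

columns = {"tokens": "bytes"}          # each sample stores one packed token sequence
out_dir = "data/for_prune/cc"          # hypothetical output directory (one per domain)

with MDSWriter(out=out_dir, columns=columns, compression=None) as writer:
    for _ in range(8):                 # replace with an iterator over tokenized documents
        token_ids = np.random.randint(0, 32000, size=4096, dtype=np.uint16)
        writer.write({"tokens": token_ids.tobytes()})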

Model Preparation

To utilize Hugging Face transformer models with Composer, you'll need to convert the model weights to the key format expected by Composer. Here's an example of how to convert the weights from the Hugging Face model 'llama2' into a compatible format for Composer:

# Define the Hugging Face model name and the output path
HF_MODEL_NAME=meta-llama/Llama-2-7b-hf
OUTPUT_PATH=models/Llama-2-7b-composer/state_dict.pt

# Create the necessary directory if it doesn't exist
mkdir -p $(dirname $OUTPUT_PATH)

# Convert the Hugging Face model to Composer key format
python3 -m llmshearing.utils.composer_to_hf save_hf_to_composer $HF_MODEL_NAME $OUTPUT_PATH

Additionally, you can use the following utility function to test the equivalence between the Hugging Face model and the converted Composer model:

MODEL_SIZE=7B
python3 -m llmshearing.utils.test_composer_hf_eq $HF_MODEL_NAME $OUTPUT_PATH $MODEL_SIZE

These functions exclusively work for LLaMA/LLaMA2 models. However, it should be straightforward to adapt them for use with other models such as Mistral-7B.

Sample Scripts for Pruning and Continued Pre-training

For pruning, you can reference an example script located in llmshearing/scripts/pruning.sh. In this script, you will need to make adjustments to incorporate data configurations, basic training configurations, pruning configurations and dynamic batch loading configurations.

Due to the higher computational cost of pruning compared to continued pre-training, we halt training with the pruning objective after a specific number of steps (3200 steps in all our experiments). Subsequently, we proceed with further pre-training of the pruned model. To ensure compatibility, it is necessary to convert the state dictionary keys of the model to align with a standard target model structure. Detailed instructions for this conversion can be found at Convert Pruned Model.

After completing the model conversion, you can continue with the pre-training of the pruned model. The process is similar to pre-training a standard model. To do this, you can refer to an example script located at llmshearing/scripts/continue_pretraining.sh. In this script, the pruning configurations are removed.

After training the model, you can use the conversion script to convert the composer model into a transformers model. Please refer to Section Convert Composer Model to Huggingface Model for more details.

Convert Pruned Model

Following the completion of training with llmshearing/scripts/pruning.sh, the saved models consist of the entire parameters of the source model, accompanied by a set of masks. We then act upon the masking variables by 1) removing the substructures where the masking variables are near $0$ and 2) subsuming the masking variables into the model parameters by matrix-vector multiplication, resulting in a more compact model. Simultaneously, it becomes necessary to rename the weight keys so that they can be seamlessly loaded into a target model architecture, ensuring that the layer names are all consecutive.

MODEL_PATH=$MODEL_DIR/latest-rank0.pt
python3 -m llmshearing.utils.post_pruning_processing prune_and_save_model $MODEL_PATH

The pruned model will be saved in $(dirname $MODEL_PATH)/pruned-latest-rank0.pt.
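
Conceptually, the mask-absorption step described above looks roughly like the sketch below. This is an illustration only, not the repo's prune_and_save_model implementation; the tensor names and the 0.5 threshold are made up for the example.

# Illustration: absorb a near-binary mask into a weight matrix and drop the
# pruned substructure. The real logic lives in
# llmshearing.utils.post_pruning_processing.prune_and_save_model.
import torch

def absorb_and_prune(weight: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """weight: (out_dim, in_dim); z: (out_dim,) masking variables in [0, 1]."""
    kept = z > 0.5                       # 1) remove substructures whose mask is near 0
    return weight[kept] * z[kept, None]  # 2) fold the surviving mask values into the weights

weight = torch.randn(11008, 4096)        # e.g. an up-projection of the 7B source model
z_intermediate = torch.rand(11008)       # learned masking variables for intermediate dims
pruned_weight = absorb_and_prune(weight, z_intermediate)
print(pruned_weight.shape)               # (number of kept intermediate dims, 4096)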

Convert Composer Model to Huggingface Model

After training, if you'd like to use Hugging Face for inference or fine-tuning, you may opt to convert your Composer model into a Hugging Face model using the llmshearing/scripts/composer_to_hf.py script. Here's an example of how to use the script:

MODEL_PATH=$MODEL_DIR/latest-rank0.pt
OUTPUT_PATH=$MODEL_DIR/hf-latest_rank0
MODEL_CLASS=LlamaForCausalLM
HIDDEN_SIZE=2048
NUM_ATTENTION_HEADS=16
NUM_HIDDEN_LAYERS=24
INTERMEDIATE_SIZE=5504
MODEL_NAME=Sheared-Llama-1.3B

python3 -m llmshearing.utils.composer_to_hf save_composer_to_hf $MODEL_PATH $OUTPUT_PATH \
        model_class=${MODEL_CLASS} \
        hidden_size=${HIDDEN_SIZE} \
        num_attention_heads=${NUM_ATTENTION_HEADS} \
        num_hidden_layers=${NUM_HIDDEN_LAYERS} \
        intermediate_size=${INTERMEDIATE_SIZE} \
        num_key_value_heads=${NUM_ATTENTION_HEADS} \
        _name_or_path=${MODEL_NAME}

Please be aware that the parameter names mentioned here are tailored to Llama2's Hugging Face configurations and may differ when dealing with other model types.
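
Once converted, the output directory should load with the standard transformers API. The snippet below is a quick, illustrative smoke test; it assumes the conversion wrote a config alongside the weights, and it borrows the tokenizer from the original Llama-2 repository since the conversion script only handles model weights.

# Quick generation test with the converted checkpoint (illustrative paths).
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "models/hf-latest_rank0"   # the OUTPUT_PATH used above
model = AutoModelForCausalLM.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))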

Training Configurations

In this section, we provide an in-depth guide on configuring parameters within YAML configuration files for training. These configurations encompass several key aspects, including data setup, fundamental training settings, pruning settings, and dynamic data loading configurations.

Data configurations

  • data_local: The local directory containing the data.
  • eval_loader.dataset.split: For evaluation, provide the name of a combined split that includes data from all domains.
  • train_loader.dataset.split: When dynamic=True in the dynamic loading configuration (see the dynamic batch loading section), there's no need to set this value. However, if dynamic=False, you must specify a training split.

Basic training configurations

The basic training configurations largely follow the original Composer package. For comprehensive details on these configurations, please refer to Composer's official documentation. Here are some key training parameters to take note of:

  • max_duration: This parameter defines the maximum training duration and can be specified in either the number of steps (e.g., 3200ba) or epochs (e.g., 1ep). In our experiments, the pruning duration was set to 3200ba, and the continued pre-training duration was set to 48000ba.
  • save_interval: This parameter determines how frequently the model state is saved. We set it to 3200ba for both the pruning and continued pre-training stages.
  • t_warmup: This parameter specifies the duration of the learning rate warm-up for the learning rate scheduler. In the case of pruning, it is set to 320ba (10% of training), while for continued pre-training, it is set to 1440ba (3% of training).
  • optimizer.lr: This parameter defines the learning rate for the primary model parameters, with the default value being 1e-4.
  • max_seq_len: Following the Llama 2 training methodology, we accommodate a maximum sequence length of 4096.
  • device_train_microbatch_size: This parameter determines the batch size per device during training. For the pruning stage, we configure it to 4, whereas for continued pre-training, it is set to 16.
  • global_train_batch_size: This parameter specifies the global batch size across all GPUs during training. During the pruning stage, it is configured as 32, while for continued pre-training, it is increased to 256.
  • autoresume: This parameter can be enabled by setting it to true when resuming a run. However, it's important to note that while we have used it successfully during the continued pretraining stage, there is no guarantee of its compatibility with the pruning stage.

Due to computational constraints, an exhaustive hyperparameter search was not conducted, and there may exist better hyper-parameters for improved performance.

Pruning configurations

The pruning process allows pruning a source model to a specific target shape, and the script includes essential parameters such as:

  • from_model: This parameter specifies the source model size and corresponds to a config_file.
  • to_model: This parameter defines the target model size, and the source model will be pruned to match the target configuration.
  • optimizer.lag_lr: This parameter specifies the learning rate to learn the masking variables and Lagrangian multipliers during pruning. The default value is $1.0$.

The pruning-specific arguments are all grouped under model.l0_module:

  • model.l0_module.lagrangian_warmup_steps: In the initial warm-up phase, the pruning rate incrementally rises from 0 to reach the desired target value. The specific target value is determined by the predefined structure of the target model. It's important to note that this value might differ from the warm-up steps associated with learning rates. Typically, we allocate approximately 20% of the total number of steps for this pruning warm-up process.
  • model.l0_module.pruning_modules: By default, this setting prunes various aspects of the model, including the head, intermediate dimensions, hidden dimensions, and layers.
  • model.l0_module.eval_target_model: When set to true, the evaluation process assesses a submodel that exactly matches the target model's structure. If set to false, the evaluation process considers the current model, taking into account the masking values. Since the mask may take some time to converge to the target model shape, we evaluate based on the current model shape rather than the target structure during training.
  • model.l0_module.target_model.d_model: Specifies the hidden dimension of the target model.
  • model.l0_module.target_model.n_heads: Specifies the number of heads in the target model.
  • model.l0_module.target_model.n_layers: Specifies the number of layers in the target model.
  • model.l0_module.target_model.intermediate_size: Specifies the number of intermediate dimensions in the target model.

These parameters allow you to configure and control the pruning process according to your specific requirements.
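
For concreteness, a target-model specification at the 1.3B scale (using the same shape values as the Hugging Face conversion example earlier in this README) could be assembled roughly as follows. This is an illustrative OmegaConf sketch; the exact key names and value formats in the actual YAML configs may differ.

# Sketch of an l0_module configuration for a 1.3B target (illustrative only;
# key names and value formats may differ from the repo's YAML files).
from omegaconf import OmegaConf

l0_module_cfg = OmegaConf.create({
    "lagrangian_warmup_steps": "640ba",  # ~20% of the 3200ba pruning run
    "pruning_modules": ["head", "intermediate", "hidden", "layer"],
    "eval_target_model": False,          # evaluate the current masked shape during training
    "target_model": {
        "d_model": 2048,
        "n_heads": 16,
        "n_layers": 24,
        "intermediate_size": 5504,
    },
})
print(OmegaConf.to_yaml(l0_module_cfg))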

Dynamic batch loading configurations

We extend Streaming's StreamingDataset in datasets/streaming_dataset.py to support loading data dynamically. The parameters for configuring dynamic batch loading are primarily defined within the DynamicLoadingCallback. Most of the following configurations can be specified in a YAML configuration file under the callbacks.data_loading section. Here's an explanation of each parameter:

  • callbacks.data_loading.dynamic: This boolean parameter determines whether dynamic data loading is enabled. When set to true, data is loaded dynamically from various domains or streams. If set to false, dynamic data loading is disabled.
  • callbacks.data_loading.set_names: Specify the domain names or stream names that will be used for dynamic data loading.
  • callbacks.data_loading.proportion: This parameter defines the initial data loading proportion for each domain or stream. The sum of all proportions must equal 1, indicating the relative weights of each source in the initial data loading configuration.
  • callbacks.data_loading.update_type: Choose the update type for adjusting the data loading proportions during training. There are two options:
    • doremi: In this mode, the data loading proportions are updated using an exponential descent approach, similar to the method described in Doremi. This allows for adaptive adjustment of data loading proportions over time (a rough sketch of this update appears after this section).
    • constant: Selecting this option keeps the data loading proportions constant throughout training. It's equivalent to disabling dynamic data loading.
  • callbacks.data_loading.target_loss: Specify the target validation loss for the training process. This target loss value should be calculated or predetermined before training begins. The loading proportions will be dynamically adjusted based on the difference between the model's current loss and the target loss. This adjustment helps guide the training process towards the desired performance level.
  • eval_interval: Determine how often evaluations are performed during training. If dynamic=True, the data loading proportion will be adjusted after each evaluation.

The code is designed to exclusively accommodate local data and does not support remote streaming data. Additionally, it currently only functions with a single worker for the dataloader and does not offer prefetch support. In our testing, this restriction does not incur any additional compute overhead.
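
As a rough illustration of the doremi-style update sketched above, the proportions can be nudged toward domains whose loss is still far from its target and then renormalized. This is not the DynamicLoadingCallback itself; the step size, smoothing constant, and the loss/target values below are made up, while the initial proportions mirror the RedPajama weights used elsewhere in this repo.

# Sketch of a doremi-style proportion update driven by the loss-to-target gap.
import numpy as np

def update_proportions(proportion, losses, target_loss, step_size=1.0, c=1e-4):
    proportion = np.asarray(proportion, dtype=np.float64)
    diff = np.asarray(losses) - np.asarray(target_loss)   # remaining gap per domain
    updated = proportion * np.exp(step_size * diff)       # exponential re-weighting
    updated /= updated.sum()                              # renormalize to a distribution
    return (1 - c) * updated + c / len(proportion)        # smooth toward uniform

proportions = [0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15]  # initial domain weights
losses      = [2.1, 1.0, 1.3, 0.9, 1.2, 1.1, 2.0]             # current eval losses (made up)
targets     = [1.9, 0.7, 1.2, 0.8, 1.0, 1.0, 1.8]             # predetermined target losses (made up)
print(update_proportions(proportions, losses, targets))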

Throughput

Here is the throughput of running the pruning and continued pre-training steps on A100 80GB GPUs. Throughput is quantified in tokens processed per second. Please refer to the standard throughput of llm-foundry for comparison.

| | GPUs | Throughput per Device | Throughput |
| --- | --- | --- | --- |
| Pruning 7B | 8 | 1844 | 14750 |
| Pre-training 3B | 16 | 4957 | 79306 |
| Pre-training 1.3B | 16 | 8684 | 138945 |

Future Work

Source models: While large models are undoubtedly powerful and have the potential to become stronger in the near future, we believe that small-scale models (those with fewer than 7 billion parameters) have untapped potential. However, little effort has been dedicated to making small models stronger, and our work pushes towards this goal. A natural extension of this work is to extend the codebase to prune:

  • Stronger base models, such as Mistral-7B
  • Domain-specific language models, such as code models, including CodeLlama and DeepSeek-Coder
  • Models from different scales. We mainly worked with 7B models due to computational constraints. It's unclear if pruning from larger models will be more beneficial.

To adapt the codebase to other models, one key component is to make sure that running the model with masks is equivalent to running the pruned model. We use llmshearing/utils/test_pruning.py to run such tests to ensure the correctness of the function prune_params in model files.
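
A minimal version of such a check is sketched below; the repo's actual test lives in llmshearing/utils/test_pruning.py, and the two models and their interfaces here are placeholders.

# Sketch: the masked forward pass of the source model should match the forward
# pass of the physically pruned model. Models and tokenizer are placeholders.
import torch

@torch.no_grad()
def check_mask_prune_equivalence(masked_model, pruned_model, tokenizer, atol=1e-4):
    inputs = tokenizer("A quick equivalence test.", return_tensors="pt")
    logits_masked = masked_model(**inputs).logits   # full model with masks applied in forward
    logits_pruned = pruned_model(**inputs).logits   # compact model after prune_params()
    assert torch.allclose(logits_masked, logits_pruned, atol=atol), \
        "masked and pruned forward passes disagree"
    return True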

Data Sources: Please keep in mind that the performance of the resulting model is contingent not only on the pruning algorithm and the base model but also on the quality of the data. In our experiments, we mainly worked with the RedPajama v1 data. However, here are some additional resources that could be considered for inclusion:

  • Dolma, a 3T-token pre-training dataset covering the domains of CommonCrawl, C4, peS2o, The Stack, Project Gutenberg, and Wikipedia
  • proof-pile-2, a 55 billion token dataset of mathematical and scientific documents.
  • RedPajama-v2, a 30T token pre-training dataset.

Bugs or Questions?

If you have any questions related to the code or the paper, feel free to email Mengzhou ([email protected]). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please try to describe the problem in detail so we can help you better and more quickly!

Citation

Please cite our paper if you find the repo helpful in your work:

@article{xia2023sheared,
  title={Sheared llama: Accelerating language model pre-training via structured pruning},
  author={Xia, Mengzhou and Gao, Tianyu and Zeng, Zhiyuan and Chen, Danqi},
  journal={arXiv preprint arXiv:2310.06694},
  year={2023}
}

llm-shearing's People

Contributors

gaotianyu1350, longyichen, xiamengzhou, zhiyuan-zeng


llm-shearing's Issues

Docker Request

Hi! Would anyone be willing to share a Docker image in which the code runs successfully? I've tried installing in different environments but keep getting errors from MosaicML Composer.

cannot reshape array of size 4 into shape (1,newaxis,8)

Traceback (most recent call last):
  File "/mnt/workspace/workgroup/qianqin.rzr/LLM-Shearing/llmshearing/train.py", line 317, in <module>
    main(cfg)
  File "/mnt/workspace/workgroup/qianqin.rzr/LLM-Shearing/llmshearing/train.py", line 301, in main
    trainer.fit()
  File "/mnt/workspace/workgroup/qianli.myf/anaconda3/envs/rzr_llmshearing/lib/python3.9/site-packages/composer/trainer/trainer.py", line 1876, in fit
    self._train_loop()
  File "/mnt/workspace/workgroup/qianli.myf/anaconda3/envs/rzr_llmshearing/lib/python3.9/site-packages/composer/trainer/trainer.py", line 2018, in _train_loop
    for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
  File "/mnt/workspace/workgroup/qianli.myf/anaconda3/envs/rzr_llmshearing/lib/python3.9/site-packages/composer/trainer/trainer.py", line 3024, in _iter_dataloader
    batch = next(dataloader_iter)
  File "/mnt/workspace/workgroup/qianli.myf/anaconda3/envs/rzr_llmshearing/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/mnt/workspace/workgroup/qianli.myf/anaconda3/envs/rzr_llmshearing/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/mnt/workspace/workgroup/qianli.myf/anaconda3/envs/rzr_llmshearing/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
    data.append(next(self.dataset_iter))
  File "/mnt/workspace/workgroup/qianqin.rzr/LLM-Shearing/llmshearing/datasets/streaming_dataset.py", line 384, in __iter__
    sample_ids_per_stream = self._get_work(world, epoch, used_sample_ids)
  File "/mnt/workspace/workgroup/qianqin.rzr/LLM-Shearing/llmshearing/datasets/streaming_dataset.py", line 338, in _get_work
    sample_ids_per_stream = generate_work(self, world, epoch, used_domain_ids)
  File "/mnt/workspace/workgroup/qianqin.rzr/LLM-Shearing/llmshearing/datasets/streaming_dataset.py", line 78, in generate_work
    stream_partition = get_partitions_orig(samples_in_stream, dataset.num_canonical_nodes, ...)
  File "/mnt/workspace/workgroup/qianqin.rzr/LLM-Shearing/llmshearing/datasets/partition.py", line 116, in get_partitions_orig
    ids = ids.reshape(num_physical_nodes, -1, ranks_per_node)
ValueError: cannot reshape array of size 4 into shape (1,newaxis,8)

Has anyone met this problem? Asking for help.

Finetuning using LoRA

Is it possible to finetune one of your checkpoints using LoRA on a dataset of our choice? If yes, how might we go about doing it?

meta-llama/Llama-2-7b-hf Model Preparation failed

(rzr_llmshearing) root@dsw70428-7485c78d87-pp4rr:/mnt/workspace/workgroup/qianqin.rzr/LLM-Shearing# sh modelprepare.sh
Traceback (most recent call last):
  File "/mnt/workspace/workgroup/qianli.myf/anaconda3/envs/rzr_llmshearing/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None, "__main__", mod_spec)
  File "/mnt/workspace/workgroup/qianli.myf/anaconda3/envs/rzr_llmshearing/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/mnt/workspace/workgroup/qianqin.rzr/LLM-Shearing/llmshearing/utils/composer_to_hf.py", line 107, in <module>
    save_composer_to_hf(composer_model_path, output_path, cli_cfg)
  File "/mnt/workspace/workgroup/qianqin.rzr/LLM-Shearing/llmshearing/utils/composer_to_hf.py", line 89, in save_composer_to_hf
    weights = torch.load(composer_model_path)["state"]["model"]
  File "/mnt/workspace/workgroup/qianli.myf/anaconda3/envs/rzr_llmshearing/lib/python3.9/site-packages/torch/serialization.py", line 791, in load
    with _open_file_like(f, 'rb') as opened_file:
  File "/mnt/workspace/workgroup/qianli.myf/anaconda3/envs/rzr_llmshearing/lib/python3.9/site-packages/torch/serialization.py", line 271, in _open_file_like
    return _open_file(name_or_buffer, mode)
  File "/mnt/workspace/workgroup/qianli.myf/anaconda3/envs/rzr_llmshearing/lib/python3.9/site-packages/torch/serialization.py", line 252, in __init__
    super().__init__(open(name, mode))
IsADirectoryError: [Errno 21] Is a directory: '/mnt/workspace/workgroup/qianqin.rzr/model/llama-2-7b-hf'

these are my files in (rzr_llmshearing) root@dsw70428-7485c78d87-pp4rr:/mnt/workspace/workgroup/qianqin.rzr/LLM-Shearing# ls ../model/llama-2-7b-hf/
config.json LICENSE.txt model-00002-of-00002.safetensors pytorch_model-00001-of-00002.bin pytorch_model.bin.index.json Responsible-Use-Guide.pdf tokenizer_config.json tokenizer.model
generation_config.json model-00001-of-00002.safetensors model.safetensors.index.json pytorch_model-00002-of-00002.bin README.md special_tokens_map.json tokenizer.json USE_POLICY.md

What can I do to fix this error, please?

NotImplementedError: offload_to_cpu=True and NO_SHARD is not supported yet

This problem occurs when the pruning run finishes and tries to save the checkpoint.
As a result, checkpoints cannot be saved normally.

Traceback (most recent call last):
  [earlier frames truncated]
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/composer/core/callback.py", line 96, in run_event
    return event_cb(state, logger)
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/composer/callbacks/checkpoint_saver.py", line 294, in batch_checkpoint
    self._save_checkpoint(state, logger)
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/composer/callbacks/checkpoint_saver.py", line 332, in _save_checkpoint
    saved_path = checkpoint.save_checkpoint(state=state, filename=filename_with_placeholders, weights_only=self.weights_only, ...)
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/composer/utils/checkpoint.py", line 761, in save_checkpoint
    'state': state.state_dict(),
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/composer/core/state.py", line 891, in state_dict
    serialized_value = self.get_model_state_dict()
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/composer/core/state.py", line 868, in get_model_state_dict
    model_state_dict = self.model.state_dict()
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1818, in state_dict
    module.state_dict(destination=destination, prefix=...)
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1815, in state_dict
    self._save_to_state_dict(destination, prefix, keep_vars)
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1722, in _save_to_state_dict
    hook(self, prefix, keep_vars)
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 669, in _pre_state_dict_hook
    _pre_state_dict_hook_fn[fsdp_state._state_dict_type](fsdp_state, module, *args, ...)
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 271, in _full_pre_state_dict_hook
    _common_unshard_pre_state_dict_hook(module, fsdp_state, offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu, ...)
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 143, in _common_unshard_pre_state_dict_hook
    _enter_unshard_params_ctx(module, fsdp_state, writeback=False, ...)
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 109, in _enter_unshard_params_ctx
    fsdp_state._unshard_params_ctx[module].__enter__()
  File "/root/miniconda3/envs/shearing/lib/python3.10/contextlib.py", line 135, in __enter__
    return next(self.gen)
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 171, in _unshard_fsdp_state_params
    _validate_unshard_params_args(state, writeback, rank0_only, offload_to_cpu, with_grads)
  File "/root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 140, in _validate_unshard_params_args
    raise NotImplementedError("offload_to_cpu=True and NO_SHARD is not supported yet")
NotImplementedError: offload_to_cpu=True and NO_SHARD is not supported yet

There are related discussions about this problem online:
huggingface/transformers#24874
It seems to have been fixed in the transformers library. Are you considering updating the version or adopting other fixes?

The implementation of dynamic batch loading code seems inconsistent with the pseudo-code in the paper

For example, truncating the loss difference to 0 does not seem to be implemented.

diff = torch.tensor(losses) - torch.tensor(self.target_loss)

And, what is the purpose of this line?

updated_domain_weights = (1-c) * updated_alpha + c / self.n_domains

Training hangs during "Building trainer"

Hi, I'm using the sample test set and trying to run through the README, but I found that training hangs and then times out:
[batch=23/3200]:
Train time/batch: 22
Train time/sample: 198
Train time/batch_in_epoch: 6
Train time/sample_in_epoch: 54
Train time/token: 811008
Train time/token_in_epoch: 221184
Train metrics/train/cc_weight: 0.6700
Train metrics/train/github_weight: 0.0450
Train metrics/train/book_weight: 0.0450
Train metrics/train/stackexchange_weight: 0.0200
Train metrics/train/wiki_weight: 0.0450
Train metrics/train/arxiv_weight: 0.0250
Train metrics/train/c4-rp_weight: 0.1500
Train memory/current_allocated_mem: 36.8820
Train memory/current_active_mem: 36.8820
Train memory/current_inactive_mem: 0.1744
Train memory/current_reserved_mem: 55.9060
Train memory/peak_allocated_mem: 42.9380
Train memory/peak_active_mem: 42.9380
Train memory/peak_inactive_mem: 7.8742
Train memory/peak_reserved_mem: 55.9060
Train memory/alloc_retries: 0
Train metrics/train/expected_head_sparsity: 0.0039
Train metrics/train/target_head_sparsity: 0.0129
Train metrics/train/expected_intermediate_sparsity: 0.0039
Train metrics/train/target_intermediate_sparsity: 0.0128
Train metrics/train/expected_layer_sparsity: 0.0039
Train metrics/train/target_layer_sparsity: 0.0000
Train metrics/train/expected_hidden_sparsity: 0.0039
Train metrics/train/target_hidden_sparsity: 0.0129
Train metrics/train/expected_sparsity: 0.0117
Train metrics/train/target_sparsity: 0.0209
Train trainer/device_train_microbatch_size: 3
Train loss/train/total: 1.4801
Train loss/train/ce_loss: 1.4716
Train loss/train/lag_loss: 0.0085
Train metrics/train/LanguageCrossEntropy: 1.4716
Train metrics/train/Perplexity: 4.3561
Train metrics/train/cc_LanguageCrossEntropy: 1.1558
Train metrics/train/cc_count: 65
Train metrics/train/github_LanguageCrossEntropy: nan
Train metrics/train/github_count: 7
Train metrics/train/book_LanguageCrossEntropy: nan
Train metrics/train/book_count: 7
Train metrics/train/stackexchange_LanguageCrossEntropy: 2.1491
Train metrics/train/stackexchange_count: 3
Train metrics/train/wiki_LanguageCrossEntropy: 1.5306
Train metrics/train/wiki_count: 8
Train metrics/train/arxiv_LanguageCrossEntropy: nan
Train metrics/train/arxiv_count: 6
Train metrics/train/c4-rp_LanguageCrossEntropy: 1.6471
Train metrics/train/c4-rp_count: 111
Train throughput/batches_per_sec: 0.0914
Train throughput/samples_per_sec: 0.8223
Train throughput/device/batches_per_sec: 0.0305
Train throughput/device/samples_per_sec: 0.2741
Train throughput/tokens_per_sec: 3368.2385
Train throughput/device/tokens_per_sec: 1122.7462
Train throughput/flops_per_sec: 157886485043818.8125
Train throughput/device/flops_per_sec: 52628828347939.6016
Train throughput/device/mfu: 0.1687
Train time/train: 0.0709
Train time/val: 0.0000
Train time/total: 0.0709
Train lr-DecoupledAdamW/group0: 0.0000
Train lr-DecoupledAdamW/group1: 0.0688
Train lr-DecoupledAdamW/group2: -0.0688
[E ProcessGroupNCCL.cpp:828] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=3777, OpType=_ALLGATHER_BASE, Timeout(ms)=1800000) ran for 1802129 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.
terminate called after throwing an instance of 'std::runtime_error'
what(): [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=3777, OpType=_ALLGATHER_BASE, Timeout(ms)=1800000) ran for 1802129 milliseconds before timing out.

When should we apply hidden_z?

I notice that hidden_z is applied in almost every module in every layer, and I'm curious whether this will result in issues like gradient vanishing or exploding. Also, will it have a large influence on the magnitude of the last hidden state, as the same scale is repeatedly multiplied?

Pruning crash at iteration 592.

@xiamengzhou
[batch=592/3200]
Train time/batch: 591
Train time/sample: 18912
Train time/batch_in_epoch: 591
Train time/sample_in_epoch: 18912
Train time/token: 77463552
Train time/token_in_epoch: 77463552
Train metrics/train/cc_weight: 0.2292
Train metrics/train/github_weight: 0.0121
Train metrics/train/book_weight: 0.0220
Train metrics/train/stackexchange_weight: 0.0059
Train metrics/train/wiki_weight: 0.5933
Train metrics/train/arxiv_weight: 0.0038
Train metrics/train/c4-rp_weight: 0.1336
Train memory/current_allocated_mem: 14.6140
Train memory/current_active_mem: 14.6140
Train memory/current_inactive_mem: 1.9265
Train memory/current_reserved_mem: 43.4220
Train memory/peak_allocated_mem: 28.0710
Train memory/peak_active_mem: 28.0710
Train memory/peak_inactive_mem: 11.7290
Train memory/peak_reserved_mem: 43.4220
Train memory/alloc_retries: 0
Train metrics/train/expected_head_sparsity: 0.3583
Train metrics/train/target_head_sparsity: 0.3463
Train metrics/train/expected_intermediate_sparsity: 0.3196
Train metrics/train/target_intermediate_sparsity: 0.3436
Train metrics/train/expected_layer_sparsity: 0.0039
Train metrics/train/target_layer_sparsity: 0.0000
Train metrics/train/expected_hidden_sparsity: 0.4266
Train metrics/train/target_hidden_sparsity: 0.3463
Train metrics/train/expected_sparsity: 0.6188
Train metrics/train/target_sparsity: 0.5616
Train trainer/device_train_microbatch_size: 4
Train loss/train/total: 3.5578
Train loss/train/ce_loss: 2.8953
Train loss/train/lag_loss: 0.6625
Train metrics/train/LanguageCrossEntropy: 2.8953
Train metrics/train/Perplexity: 18.0886
Train metrics/train/cc_LanguageCrossEntropy: 3.0387
Train metrics/train/cc_count: 9884
Train metrics/train/github_LanguageCrossEntropy: nan
Train metrics/train/github_count: 652
Train metrics/train/book_LanguageCrossEntropy: nan
Train metrics/train/book_count: 712
Train metrics/train/stackexchange_LanguageCrossEntropy: nan
Train metrics/train/stackexchange_count: 236
Train metrics/train/wiki_LanguageCrossEntropy: 2.7964
Train metrics/train/wiki_count: 4011
Train metrics/train/arxiv_LanguageCrossEntropy: nan
Train metrics/train/arxiv_count: 267
Train metrics/train/c4-rp_LanguageCrossEntropy: 3.1243
Train metrics/train/c4-rp_count: 3182
Train throughput/batches_per_sec: 0.1329
Train throughput/samples_per_sec: 4.2523
Train throughput/device/batches_per_sec: 0.0166
Train throughput/device/samples_per_sec: 0.5315
Train throughput/tokens_per_sec: 17417.3748
Train throughput/device/tokens_per_sec: 2177.1719
Train throughput/flops_per_sec: 816440956730026.0000
Train throughput/device/flops_per_sec: 102055119591253.2500
Train time/train: 1.2715
Train time/val: 0.6538
Train time/total: 1.9253
Traceback (most recent call last):
  File "/llm-shearing//llmshearing/train.py", line 317, in <module>
    main(cfg)
  File "/llm-shearing//llmshearing/train.py", line 301, in main
    trainer.fit()
  File "/pyenv/py310-shear/lib/python3.10/site-packages/composer/trainer/trainer.py", line 1876, in fit
    self._train_loop()
  File "/pyenv/py310-shear/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2018, in _train_loop
    for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
  File "/pyenv/py310-shear/lib/python3.10/site-packages/composer/trainer/trainer.py", line 3024, in _iter_dataloader
    batch = next(dataloader_iter)
  File "/pyenv/py310-shear/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/pyenv/py310-shear/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/pyenv/py310-shear/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
    data.append(next(self.dataset_iter))
  File "/llm-shearing/llmshearing/datasets/streaming_dataset.py", line 392, in __iter__
    domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id]
IndexError: index 552 is out of bounds for axis 0 with size 552

Steps to reproduce:
Prepare the data for pruning as described in the paper,
then execute the following command:
/bin/bash llmshearing/scripts/prunning.sh

Google Drive link error

The Google Drive address is broken and I cannot open it. Could you share a new link? Thanks.


AssertionError: Currently only supports dynamic loading from each domain for once.

When I use a single node, 8*A100 80G configuration, I find that an error occurs:

LLM-Shearing/llmshearing/datasets/streaming_dataset.py:46 in generate_work
    43         List[List[int]]: The epoch for each domain of data (num physical nodes,
    44         ranks per node, workers per rank, batches per worker, batch size).
    45     """
  ❱ 46     assert epoch == 0, "Currently only supports dynamic loading from each domain for onc
    47     # Ensure that num_canonical_nodes has been set.
    48     if dataset.num_canonical_nodes is None:
    49         raise RuntimeError(f'`num_canonical_nodes` can never be None. ' +

AssertionError: Currently only supports dynamic loading from each domain for once.

If I delete the assert epoch == 0 line ("Currently only supports dynamic loading from each domain for once."), another error occurs in:

# Currently only supports dynamically loading data from each domain for once. 
        # Issues could occur if one domain of data is used up. 
        while True:
            proportion = self.proportion
            stream_id = np.random.choice(range(self.num_streams), 1, p=proportion)[0].item()
            domain_sample_id = sample_ids_per_stream[stream_id]
            domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] \
                                % self.samples_per_stream[stream_id]]
            self.used_num_samples_per_stream[stream_id] += 1
            yield self[domain_sample_id]
---
IndexError: index 24 is out of bounds for axis 0 with size 24

If I add "if world.is_local_leader and epoch==0:", then SharedMemory in _attach_work raises an error:

    304             # Load the generated epoch shape from shared memory.
    305             name = _get_path(self._shm_prefix_int, EPOCH_SHAPE + f"_{stream_id}")
    306             size = ndim * np.int64().nbytes
  ❱ 307             shape_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=Fa
    308             shape = tuple(np.ndarray(5, buffer=shape_shm.buf, dtype=np.int64))
    309
    310             # Attach to the generated epoch data in shared memory.
---
FileNotFoundError: [Errno 2] No such file or directory: '/000000_epoch_shape_0'
Exception ignored in atexit callback: <function Engine._close at 0x7fe6dc766a70>
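For reference, a hedged sketch of the indexing pattern that fails above: taking the modulo against the length of the local id array (rather than the global samples_per_stream count) keeps the index in bounds. Whether that matches the intended semantics of the dataset is an assumption on my part; the variable names simply mirror the traceback.

# Sketch only: wrap the per-stream cursor around the *local* id partition.
import numpy as np

used_num_samples_per_stream = {0: 0}
sample_ids_per_stream = {0: np.arange(552)}   # per-rank partition of stream 0

def next_sample_id(stream_id: int) -> int:
    ids = sample_ids_per_stream[stream_id]
    idx = used_num_samples_per_stream[stream_id] % len(ids)   # modulo against the local length
    used_num_samples_per_stream[stream_id] += 1
    return int(ids[idx])

for _ in range(600):   # more draws than samples: the index simply wraps around
    next_sample_id(0)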

Flash-attn dependency issues

You said that Flash Attention version 2 is not currently supported and may require manual modifications to the model file, but requirement.txt still installs flash-attn 2. Does this cause incompatibility issues?

Dynamic Batch Loading vs. Domain Reweighting


We can see from Table 3 that dynamic batch loading results in down-sampling CC & GitHub and up-sampling Book & C4. I wonder: if I re-sample the pretraining data according to this recipe ahead of time, and then do pruning and continued training on the re-sampled data directly, will it reach the same performance?
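For illustration, a small sketch of what such ahead-of-time re-sampling could look like; the domain weights and token budget below are placeholders, not the values from Table 3.

# Hypothetical static re-sampling to fixed domain weights (placeholder numbers).
import random

final_weights = {"cc": 0.35, "github": 0.02, "book": 0.10, "stackexchange": 0.02,
                 "wiki": 0.25, "arxiv": 0.02, "c4-rp": 0.24}
total_tokens = 50_000_000_000   # e.g. a continued pre-training token budget

# fixed per-domain token counts, decided once before training
tokens_per_domain = {d: int(w * total_tokens) for d, w in final_weights.items()}
print(tokens_per_domain)

# static sampling of a batch's domains, in contrast to dynamic batch loading
domains, probs = zip(*final_weights.items())
print(random.choices(domains, weights=probs, k=8))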

Sample.py error when sampling stackexchange

Running SLURM_ARRAY_TASK_ID=5 python sample.py --target_dir mds_redpajama --tokenized_dir hadoop/tokenized_redpajama raises an error. The reason is that stackexchange only has one file, while sample.py expects multiple files to split.

> /root/personal-code/LLM-Shearing/llmshearing/data/sample.py(69)<module>()
-> folder_eval_target = args.eval_seq
(Pdb) c
Traceback (most recent call last):
  File "/root/personal-code/LLM-Shearing/llmshearing/data/sample.py", line 70, in <module>
    num_sample_each_file = max(1, folder_eval_target // len(selected) + 1)
ZeroDivisionError: integer division or modulo by zero
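A hedged sketch of a guard for this failure mode, reusing the names from the traceback; the fallback policy for a single-file domain is an assumption, not the repo's fix.

# Guard against an empty file selection instead of dividing by zero.
selected = []             # no extra files selected for stackexchange in this run (illustrative)
folder_eval_target = 500  # illustrative target number of eval sequences

if selected:
    num_sample_each_file = max(1, folder_eval_target // len(selected) + 1)
else:
    # single-file domain: take everything needed from that one file
    num_sample_each_file = folder_eval_target
print(num_sample_each_file)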

Any updates on the code?

I'd love to experiment with LLM shearing; is progress on finalizing the code still ongoing?

TypeError: load_data() missing 1 required positional argument: 'tokenizer_name'

Hello, I was following the command in data/Readme.md:

python3 -m llmshearing.data.merge_data \
    --input_dir $INPUT_DIR \
    --output_dir $OUTPUT_DIR \
    --output_split eval_merge \
    --split_names domain1 domain2

However, I found two issues:
one is at line 31: sys is not imported, and sys.argv[2:] should be changed to sys.argv[1:], otherwise the input dir will always be None;
the other is at line 50: it seems there should be a default value or guidance for 'tokenizer_name', since I got TypeError: load_data() missing 1 required positional argument: 'tokenizer_name'.
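A minimal sketch of the two fixes described above, assuming the script reads sys.argv directly; the tokenizer choice is hypothetical and only shows what would need to be supplied.

import sys

args = sys.argv[1:]   # was sys.argv[2:], which silently dropped --input_dir
print("raw CLI args:", args)

# load_data() reportedly has no default for tokenizer_name, so it would have to be
# passed explicitly, e.g. the base model's tokenizer (a hypothetical choice):
tokenizer_name = "meta-llama/Llama-2-7b-hf"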

Thank you very much in advance

KV head count on princeton-nlp/Sheared-LLaMA-1.3B-ShareGPT ?

Hi! While testing the model with vLLM (vllm-project/vllm#1913), I found that the KV head count seems strange: why is it 32 instead of 16 like the base Sheared-LLaMA?

Is it safe for me to just change that value to 16 and use the model that way?

Thanks for the great work! I've dreamed about just pruning LLMs for speculative decoding instead of training a separate model for a long time! :)
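For anyone who wants to try it, a hedged sketch of the config edit being asked about; whether 16 is actually the right value for this checkpoint is exactly the open question, so treat this as an experiment rather than a fix. The local path is a placeholder.

# Override num_key_value_heads in a locally downloaded config (experimental).
import json

cfg_path = "Sheared-LLaMA-1.3B-ShareGPT/config.json"   # local snapshot path (assumed)
with open(cfg_path) as f:
    cfg = json.load(f)

cfg["num_key_value_heads"] = 16   # match the base Sheared-LLaMA-1.3B head count

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2)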

Small typo on Table 2


Pythia-1.4B seems to be better than OPT-1.3B and Sheared-LLaMA-1.3B on LogiQA. :)

Mismatched shape

ไธ‹้ขๆ˜ฏๆˆ‘็š„config.pt็š„ๆ–‡ไปถๅ†…ๅฎน

{'data_local': '/workspace/LLM-shearing/LLM-Shearing/llmshearing/data/mds_sample_redpajama/for_prune', 'data_remote': None, 'tokenizer_name': '/workspace/LLM-shearing/models/Llama-2-7b-hf', 'max_seq_len': 4096, 'global_seed': 17, 'run_name': 'llama2_7b_pruning_scaling_constant_to2.7b_sl4096', 'model': {'name': 'mosaic_llama2_7b', 'path': '/workspace/LLM-shearing/models/Llama-2-7b-composer/state_dict.pt', 'init_device': 'cpu', 'tokenizer_name': '${tokenizer_name}', 'd_model': 4096, 'n_heads': 32, 'n_layers': 32, 'intermediate_size': 11008, 'max_seq_len': '${max_seq_len}', 'vocab_size': 32000, 'init_std': 0.02, 'attn_pdrop': 0.0, 'resid_pdrop': 0.0, 'emb_pdrop': 0.0, 'attn_impl': 'flash', 'rms_norm_eps': 1e-05, 'l0_module': {'start_sparsity': 0.0, 'target_sparsity': 0.5, 'pruning_modules': ['head', 'intermediate', 'layer', 'hidden'], 'lagrangian_warmup_steps': '640ba', 'target_model': {'d_model': 2560, 'n_layers': 32, 'n_heads': 20, 'intermediate_size': 6912, 'vocab_size': 32000}, 'eval_target_model': False}}, 'tokenizer': {'type': 'hftokenizer', 'args': {'tokenizer_name': '${tokenizer_name}', 'max_seq_len': '${max_seq_len}'}}, 'train_loader': {'name': 'text', 'dataset': {'local': '${data_local}', 'remote': '${data_remote}', 'split': 'wikipedia', 'shuffle': True, 'tokenizer_name': '${tokenizer_name}', 'max_seq_len': '${max_seq_len}', 'shuffle_seed': '${global_seed}', 'is_uint16': True}, 'drop_last': True, 'num_workers': 0, 'prefetch_factor': None, 'persistent_workers': False}, 'eval_loader': {'name': 'text', 'dataset': {'local': '/workspace/LLM-shearing/LLM-Shearing/llmshearing/data/mds_sample_redpajama/eval', 'remote': '${data_remote}', 'split': 'eval_merge', 'shuffle': False, 'tokenizer_name': '${tokenizer_name}', 'max_seq_len': '${max_seq_len}', 'shuffle_seed': '${global_seed}', 'is_uint16': True}, 'drop_last': False, 'num_workers': 8}, 'scheduler': {'name': 'cosine_with_warmup', 't_warmup': '320ba', 'alpha_f': 0.1}, 'optimizer': {'name': 'decoupled_adamw', 'lr': 0.0001, 'betas': [0.9, 0.95], 'eps': 1e-08, 'weight_decay': 0.0, 'lag_lr': 1.0}, 'algorithms': {'gradient_clipping': {'clipping_type': 'norm', 'clipping_threshold': 1.0}}, 'max_duration': '3200ba', 'eval_interval': '50ba', 'eval_subset_num_batches': 1000, 'global_train_batch_size': 8, 'seed': '${global_seed}', 'device_eval_batch_size': 2, 'device_train_microbatch_size': 4, 'precision': 'amp_bf16', 'fsdp_config': {'sharding_strategy': 'FULL_SHARD', 'mixed_precision': 'DEFAULT', 'activation_checkpointing': True, 'activation_cpu_offload': False, 'verbose': False}, 'progress_bar': False, 'log_to_console': True, 'console_log_interval': '1ba', 'callbacks': {'speed_monitor': {'window_size': 10}, 'memory_monitor': {}, 'lr_monitor': {}, 'data_loading': {'dynamic': True, 'update_type': 'constant', 'proportion': [0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], 'set_names': ['cc', 'github', 'book', 'stackexchange', 'wiki', 'arxiv', 'c4-rp'], 'target_loss': None}}, 'loggers': {'wandb': {'project': 'pruning', 'name': '${run_name}', 'entity': 'pruning', 'init_kwargs': {'mode': 'offline', 'dir': '/workspace/LLM-shearing/models/llama2_7b_pruning_scaling_constant_to2.7b_sl4096'}}}, 'save_interval': '3200ba', 'save_folder': '/workspace/LLM-shearing/models/llama2_7b_pruning_scaling_constant_to2.7b_sl4096', 'eval_first': False, 'autoresume': False}

I hit two problems when running composer_to_hf:
The first is that converting the model after training raises a KeyError: keys related to l0_module cannot be found. The code is:

num_layers = get_layer_num_from_weights(weights)
keymap = get_key_map_from_composer_to_hf(num_layers)
hf_weights = {keymap[key]: weights[key] for key in weights if "rotary" not in key}

This code only filters out the weights containing "rotary", but the checkpoint also contains l0_module-related weights, so they cannot be mapped to HF keys and the conversion fails. I tried dropping the l0-related weights, but the saved model then reports a shape mismatch.
1. Is it valid to simply drop the l0-related weights?
2. Is there a problem with my training config?
Thanks for your support!
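A hedged sketch of the workaround discussed above: skip both rotary buffers and l0_module entries when mapping composer keys to HF keys. Whether dropping the l0_module weights is safe for the final checkpoint is precisely question 1, so this only shows the filtering itself.

# Filter out rotary and l0_module entries before applying the composer-to-HF keymap.
def composer_to_hf_weights(weights: dict, keymap: dict) -> dict:
    skip = ("rotary", "l0_module")
    return {keymap[key]: weights[key] for key in weights
            if not any(s in key for s in skip)}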

Composer Model Transform problems encountered when shearing Pythia 1.4b

Pythia helps study the effects of pruning at different LM scales.
I noticed that you provide composer_pythia.py, so I tried running this experiment on the 1.4b-scale Pythia, and I'm running into some problems.

The first concerns the pythia_to_composer key-conversion function I wrote for my setup.
The keymap here is what I inferred from your composer_pythia.py:

def get_gpt_key_map_from_hf_to_composer(num_layers):
    """ get the keymap from hf to composer """
    key_map = {}
    key_map.update({"gpt_neox.embed_in.weight": "model.transformer.wte.weight",
                    "gpt_neox.final_layer_norm.weight": "model.transformer.ln_f.weight",
                    "gpt_neox.final_layer_norm.bias": "model.transformer.ln_f.bias",
                    "embed_out.weight": "model.transformer.output.weight",
                    "embed_out.bias": "model.transformer.output.bias"})
    for i in range(num_layers):
        key_map.update({
                        f"gpt_neox.layers.{i}.input_layernorm.weight": f"model.transformer.blocks.{i}.ln_1.weight",
                        f"gpt_neox.layers.{i}.post_attention_layernorm.weight": f"model.transformer.blocks.{i}.ln_2.weight",
                        f"gpt_neox.layers.{i}.attention.query_key_value.weight": f"model.transformer.blocks.{i}.attn.query_key_value.weight",
                        f"gpt_neox.layers.{i}.attention.dense.weight": f"model.transformer.blocks.{i}.attn.out_proj.weight",
                        f"gpt_neox.layers.{i}.mlp.dense_h_to_4h.weight": f"model.transformer.blocks.{i}.mlp.up_proj.weight",
                        f"gpt_neox.layers.{i}.mlp.dense_4h_to_h.weight": f"model.transformer.blocks.{i}.mlp.down_proj.weight",
                        f"gpt_neox.layers.{i}.input_layernorm.bias": f"model.transformer.blocks.{i}.ln_1.bias",
                        f"gpt_neox.layers.{i}.post_attention_layernorm.bias": f"model.transformer.blocks.{i}.ln_2.bias",
                        f"gpt_neox.layers.{i}.attention.query_key_value.bias": f"model.transformer.blocks.{i}.attn.query_key_value.bias",
                        f"gpt_neox.layers.{i}.attention.dense.bias": f"model.transformer.blocks.{i}.attn.out_proj.bias",
                        f"gpt_neox.layers.{i}.mlp.dense_h_to_4h.bias": f"model.transformer.blocks.{i}.mlp.up_proj.bias",
                        f"gpt_neox.layers.{i}.mlp.dense_4h_to_h.bias": f"model.transformer.blocks.{i}.mlp.down_proj.bias",
                        f"gpt_neox.layers.{i}.attention.rotary_emb.inv_freq": f"model.transformer.blocks.{i}.attn.rotary_emb.inv_freq",
                       })
    return key_map

In test_composer_hf_equal.py

def construct_example_cfg(model_size, path=None, add_l0_module=False):
    """ construct example cfg for mosaicml llama models """
    if model_size == "1.4b":
        cfg = om.create({"name": "mosaic_pythia_1", "init_device": "cpu", "d_model": 2048, 
                         "n_heads": 16, "n_layers": 24, "intermediate_size": 8192,
                         "rotary_pct": 0.25, "rotary_emb_base": 10000
                         })

    # add default values
    cfg = om.merge(cfg, om.create({"max_seq_len": 2048, "vocab_size": 50304, "init_std": 0.02, "attn_pdrop": 0.0, "resid_pdrop": 0.0, "emb_pdrop": 0.0, "attn_impl": "norm", "layer_norm_eps": 1e-5}))
    if add_l0_module:
        cfg["l0_module"] = {"start_sparsity": 0, "target_sparsity": 0.6, "pruning_modules": ["head", "head_layer", "mlp", "intermediate", "hidden"], "lagrangian_warmup_steps": "320ba"}
    return cfg

final result

(Pdb) composer_model
ComposerMosaicPythia(
  (model): PythiaModel(
    (transformer): ModuleDict(
      (wte): PythiaEmbedding(50304, 2048)
      (blocks): ModuleList(
        (0-23): 24 x PythiaBlock(
          (ln_1): CoFiLayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (attn): PythiaAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (rotary_emb): RotaryEmbedding()
          )
          (ln_2): CoFiLayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): PythiaMLP(
            (down_proj): Linear(in_features=8192, out_features=2048, bias=True)
            (up_proj): Linear(in_features=2048, out_features=8192, bias=True)
          )
        )
      )
      (output): Linear(in_features=2048, out_features=50304, bias=False)
      (ln_f): CoFiLayerNorm((2048,), eps=1e-05, elementwise_affine=True)
    )
  )
)
(Pdb) hf_model
GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 2048)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-23): 24 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXAttention(
          (rotary_emb): GPTNeoXRotaryEmbedding()
          (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
          (dense): Linear(in_features=2048, out_features=2048, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
          (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          (act): GELUActivation()
        )
      )
    )
    (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
  )
  (embed_out): Linear(in_features=2048, out_features=50304, bias=False)
)

The structure of the two models looks very different. I don't know if this is an issue with the structural organization in composer_pythia.py. Finally, I tested the output of the two models, and there was a huge difference.

Will you release the Pythia-related experimental code? Or would you like me to contribute the relevant code?

KeyError: 'state'

Hello,
Thank you for your work on pruning.

When I convert a pruned model to a Hugging Face model, I get an error about a missing key.

MODEL_PATH='models/Llama-2-7b-composer/pruned-state_dict.pt'
OUTPUT_PATH='models/Llama-2-7b-composer/hf-state_dict'
MODEL_CLASS='LlamaForCausalLM'
HIDDEN_SIZE=2048
NUM_ATTENTION_HEADS=16
NUM_HIDDEN_LAYERS=24
INTERMEDIATE_SIZE=5504
MODEL_NAME='Sheared-Llama-1.3B'

!python3 -m llmshearing.utils.composer_to_hf $MODEL_PATH $OUTPUT_PATH \
model_class=$MODEL_CLASS \
hidden_size=$HIDDEN_SIZE \
num_attention_heads=$NUM_ATTENTION_HEADS \
num_hidden_layers=$NUM_HIDDEN_LAYERS \
intermediate_size=$INTERMEDIATE_SIZE \
num_key_value_heads=$NUM_ATTENTION_HEADS \
_name_or_path=$MODEL_NAME

/tf/LLM-Shearing/llmshearing/utils/composer_to_hf.py:108 in <module>
    105 if __name__ == "__main__":
    106     composer_model_path, output_path, other_args = sys.argv[1], sys.ar
    107     cli_cfg = om.from_cli(other_args)
  ❱ 108     save_composer_to_hf(composer_model_path, output_path, cli_cfg)
    109     #save_hf_to_composer(composer_model_path, output_path)
    110

/tf/LLM-Shearing/llmshearing/utils/composer_to_hf.py:90 in save_composer_to_hf
     87 def save_composer_to_hf(composer_model_path, output_path=None, model_c
     88     """ convert composer ckpt's weights to huggingface """
     89
  ❱  90     weights = torch.load(composer_model_path)["state"]["model"]
     91     num_layers = get_layer_num_from_weights(weights)
     92     keymap = get_key_map_from_composer_to_hf(num_layers)
     93     hf_weights = {keymap[key]: weights[key] for key in weights if "rot
KeyError: 'state'

Please, what can I do to fix this error?

Thank you.
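A hedged sketch of a fallback when the checkpoint is not wrapped in Composer's {"state": {"model": ...}} structure; whether your pruned-state_dict.pt is already a flat model state dict is an assumption based on the KeyError above.

# Load either a full Composer trainer checkpoint or a plain model state dict.
import torch

ckpt = torch.load("models/Llama-2-7b-composer/pruned-state_dict.pt", map_location="cpu")
if "state" in ckpt:
    weights = ckpt["state"]["model"]   # full Composer trainer checkpoint
else:
    weights = ckpt                     # already a flat model state dict
print(f"{len(weights)} tensors loaded")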

Repeated assignment in l0_module.py

In https://github.com/princeton-nlp/LLM-Shearing/blob/main/llmshearing/models/l0_module.py

def param_init_fn(self, module):
        """ Initialize the parameters for masking variables. """
        mean = math.log(1 - self.droprate_init) - math.log(self.droprate_init)
        mean = 5
        if isinstance(module, nn.Parameter):
            module.data.normal_(mean, 1e-2)
        else:
            for tensor in module.parameters():
                tensor.data.normal_(mean, 1e-2)
 

mean = math.log(1 - self.droprate_init) - math.log(self.droprate_init)
mean = 5

This is a repeated assignment, which may not be the expected behavior.
For what reason is this parameter overwritten with 5?
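A small worked comparison of the two mean values, assuming the usual hard-concrete style parameterization where a gate behaves roughly like sigmoid(log_alpha): the droprate_init formula ties the initialization to the desired drop rate, while the hard-coded 5 starts every mask almost fully open. This is only an interpretation, not a statement about the authors' intent.

import math

droprate_init = 0.5
mean_from_droprate = math.log(1 - droprate_init) - math.log(droprate_init)   # 0.0
mean_hardcoded = 5.0

sigmoid = lambda x: 1 / (1 + math.exp(-x))
print(sigmoid(mean_from_droprate))   # 0.5    -> masks start half open
print(sigmoid(mean_hardcoded))       # ~0.993 -> masks start almost fully open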

Metric Scores and NQ Evaluation

Hi there, thanks very much for such amazing work!
Currently, I'm trying to reproduce the results from the paper, but I have the following problems:

  • Metric scores: lm-eval-harness returns both acc and acc_norm. Do you have a preference for which one to use?
  • NQ evaluation: I've checked the evaluation codebase and your starter scripts, but found no clues about evaluation on the Natural Questions (NQ) dataset (it is not implemented in lm-eval-harness yet).
    Did you write an evaluation script, or adopt an existing repo to get the scores?

Thank you so much for your time and response~

License is missing

Dear authors,
Please add an Apache 2.0 or MIT license so that it is an open-source project. As it stands, users are not allowed to modify or distribute your code because there is no license available.

wiki proportion finally dominates at the end of the pruning stage

[image: pruning_wiki_dynamic]

The pruning script is as follows:

# Specify $PROJ_DIR in scripts/launch.sh and scripts/srun_launch.sh if using slurm

test=True

from_model=7b # source model size
to_model=3b # target model size
config_file=${PROJ_DIR}/llmshearing/configs/llama2/${from_model}.yaml
path=$MODEL_PATH/llama-2-7b-composer/state_dict.pt

# data setup
data_local=${DATA_DIR}

# basic setup
max_seq_len=4096
device_train_microbatch_size=4
global_train_batch_size=32
device_eval_batch_size=8

# learning setup
lr=1e-4 # learning rate for the main parameters
# max_duration=3200ba # 0.42B tokens
# save_interval=3200ba # save in the end
# t_warmup=320ba # 10% learning rate warmup 

max_duration=3000ba # 0.39B tokens
save_interval=3000ba # save in the end
t_warmup=300ba # 10% learning rate warmup 

# dynamic loading setup
dynamic=True
set_names=[cc,github,book,stackexchange,wiki,arxiv,c4-rp] # domain names
proportion=[0.67,0.045,0.045,0.02,0.045,0.025,0.15] # initial proportion of RP, make sure that the sum(proportion) = 1
# doremi: update weights with exponential descent
# constant: keep the weights constant
update_type=doremi 
if [[ $to_model == 1.3b ]]; then
    target_loss=[1.9643,0.7459,2.1393,1.6117,1.7590,1.4449,2.1251] # 1.3b predicted loss from scaling law
else
    target_loss=[1.8712,0.6883,2.0325,1.5353,1.6297,1.3560,2.0328] # 2.7b predicted loss from scaling law
fi
eval_split_name=../eval/eval_merge # eval on all domains
eval_target_model=false # evaluate on the current model, not the target model, otherwise the loss will be inaccurate
eval_interval=50ba # eval every 50 batches and update the loading proportion


# pruning setup
lag_lr=1.0 # learning rate for the l0_module
lagr_warmup=640ba # 20% sparsity warmup
if [[ $to_model == 1.3b ]]; then
    target_d_model=2048; target_n_heads=16; target_n_layers=24; target_intermediate_size=5504
elif [[ $to_model == 3b ]]; then
    target_d_model=2560; target_n_heads=20; target_n_layers=32; target_intermediate_size=6912
fi

# save directory
run_name=llama2_${from_model}_pruning_scaling_${update_type}_to${to_model}_sl${max_seq_len}
save_dir=${OUTPUT_DIR}/${run_name}
wandb_dir=${save_dir} # save locally

if [[ $test == True ]]; then t=00-01:00:00; else t=01-00:00:00; fi

# Run in bash, it will automatically use resources available in the current environment
# composer $TRAIN_SCRIPT \

# Run with slurm    
# sbatch -p cli \
#     --job-name ${run_name} \
#     --nodes=4 \
#     --gpus-per-node=2 \
#     --mem=512gb \
#     --cpus-per-task=8 \
#     --time $t \
composer $TRAIN_SCRIPT \
    $config_file \
    run_name=${run_name} \
    data_local=${data_local} \
    eval_loader.dataset.split=${eval_split_name} \
    global_train_batch_size=${global_train_batch_size} \
    device_train_microbatch_size=${device_train_microbatch_size} \
    device_eval_batch_size=${device_eval_batch_size} \
    max_seq_len=${max_seq_len} \
    max_duration=${max_duration} \
    eval_first=false \
    scheduler.t_warmup=${t_warmup} \
    save_folder=${save_dir} \
    loggers.wandb.init_kwargs.dir=${wandb_dir} \
    eval_interval=${eval_interval} \
    save_interval=${save_interval} \
    optimizer.lr=${lr} \
    optimizer.lag_lr=${lag_lr} \
    model.path=${path} \
    model.l0_module.lagrangian_warmup_steps=${lagr_warmup} \
    model.l0_module.pruning_modules='[head,intermediate,layer,hidden]' \
    model.l0_module.eval_target_model=${eval_target_model} \
    model.l0_module.target_model.d_model=${target_d_model} \
    model.l0_module.target_model.n_heads=${target_n_heads} \
    model.l0_module.target_model.n_layers=${target_n_layers} \
    model.l0_module.target_model.intermediate_size=${target_intermediate_size} \
    callbacks.data_loading.dynamic=${dynamic} \
    callbacks.data_loading.set_names=${set_names} \
    callbacks.data_loading.proportion=${proportion} \
    callbacks.data_loading.update_type=${update_type} \
    callbacks.data_loading.target_loss=${target_loss} \
    train_loader.num_workers=0 \
    train_loader.prefetch_factor=null \
    train_loader.persistent_workers=false \
    autoresume=true
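For intuition, a generic multiplicative-weights sketch of the doremi-style update named in the comments above: domains whose current loss exceeds their target loss are up-weighted exponentially. The exact formula and step size used by the data_loading callback may differ; the loss gaps below are made up, only to illustrate why the domain with the largest gap (wiki in this run) can come to dominate.

# Generic exponential reweighting sketch, not the callback's exact update rule.
import numpy as np

domains = ["cc", "github", "book", "stackexchange", "wiki", "arxiv", "c4-rp"]
proportion = np.array([0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15])
target_loss = np.array([1.8712, 0.6883, 2.0325, 1.5353, 1.6297, 1.3560, 2.0328])   # 2.7b targets from the script
current_loss = target_loss + np.array([0.20, 0.00, 0.10, 0.05, 0.60, 0.05, 0.15])  # made-up loss gaps

eta = 1.0                                            # assumed step size
for _ in range(10):                                  # ten eval intervals
    gap = np.clip(current_loss - target_loss, 0.0, None)
    proportion = proportion * np.exp(eta * gap)
    proportion = proportion / proportion.sum()

print(dict(zip(domains, proportion.round(4))))       # the domain with the largest gap dominates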

Error running CheckpointSaver.close(). Skipping CheckpointSaver.post_close()

[batch=336/5000]:
Train time/batch: 335
Train time/sample: 85760
Train time/batch_in_epoch: 335
Train time/sample_in_epoch: 85760
Train time/token: 351272960
Train time/token_in_epoch: 351272960
Train metrics/train/academic_en_weight: 0.0733
Train metrics/train/book_en_weight: 0.2521
Train metrics/train/code_weight: 0.0122
Train metrics/train/qa_en_weight: 0.0900
Train metrics/train/webtext_en_weight: 0.4453
Train metrics/train/wiki_en_weight: 0.1271
Train memory/current_allocated_mem: 17.7600
Train memory/current_active_mem: 17.7600
Train memory/current_inactive_mem: 3.8283
Train memory/current_reserved_mem: 74.4430
Train memory/peak_allocated_mem: 60.7340
Train memory/peak_active_mem: 61.1380
Train memory/peak_inactive_mem: 27.1850
Train memory/peak_reserved_mem: 74.4430
Train memory/alloc_retries: 0
Train trainer/device_train_microbatch_size: 16
Train loss/train/total: 1.8199
Train loss/train/ce_loss: 1.8199
Train metrics/train/LanguageCrossEntropy: 1.8199
Train metrics/train/Perplexity: 6.1711
Train metrics/train/academic_en_LanguageCrossEntropy: 1.3244
Train metrics/train/academic_en_count: 9893
Train metrics/train/book_en_LanguageCrossEntropy: 1.8898
Train metrics/train/book_en_count: 20179
Train metrics/train/code_LanguageCrossEntropy: 0.8853
Train metrics/train/code_count: 4807
Train metrics/train/qa_en_LanguageCrossEntropy: 1.4436
Train metrics/train/qa_en_count: 10816
Train metrics/train/webtext_en_LanguageCrossEntropy: 2.0363
Train metrics/train/webtext_en_count: 27327
Train metrics/train/wiki_en_LanguageCrossEntropy: 1.5175
Train metrics/train/wiki_en_count: 12994
Train throughput/batches_per_sec: 0.0204
Train throughput/samples_per_sec: 5.2330
Train throughput/device/batches_per_sec: 0.0026
Train throughput/device/samples_per_sec: 0.6541
Train throughput/tokens_per_sec: 21434.4781
Train throughput/device/tokens_per_sec: 2679.3098
Train throughput/flops_per_sec: 1004697191366202.0000
Train throughput/device/flops_per_sec: 125587148920775.2500
Train time/train: 4.5453
Train time/val: 0.1311
Train time/total: 4.6764
Train lr-DecoupledAdamW/group0: 0.0001
Error running CheckpointSaver.close(). Skipping CheckpointSaver.post_close().
Traceback (most recent call last):
  File "/home/pai/lib/python3.9/site-packages/composer/core/engine.py", line 527, in _close
    callback.close(state, logger)
  File "/home/pai/lib/python3.9/site-packages/composer/callbacks/checkpoint_saver.py", line 310, in close
    self._save_checkpoint(
  File "/home/pai/lib/python3.9/site-packages/composer/callbacks/checkpoint_saver.py", line 332, in _save_checkpoint
    saved_path = checkpoint.save_checkpoint(
  File "/home/pai/lib/python3.9/site-packages/composer/utils/checkpoint.py", line 761, in save_checkpoint
    'state': state.state_dict(),
  File "/home/pai/lib/python3.9/site-packages/composer/core/state.py", line 891, in state_dict
    serialized_value = self.get_model_state_dict()
  File "/home/pai/lib/python3.9/site-packages/composer/core/state.py", line 868, in get_model_state_dict
    model_state_dict = self.model.state_dict()
  File "/home/pai/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1818, in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
  File "/home/pai/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1815, in state_dict
    self._save_to_state_dict(destination, prefix, keep_vars)
  File "/home/pai/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1722, in _save_to_state_dict
    hook(self, prefix, keep_vars)
  File "/home/pai/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/pai/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 669, in _pre_state_dict_hook
    _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
  File "/home/pai/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 271, in _full_pre_state_dict_hook
    _common_unshard_pre_state_dict_hook(
  File "/home/pai/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 143, in _common_unshard_pre_state_dict_hook
    _enter_unshard_params_ctx(
  File "/home/pai/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 109, in _enter_unshard_params_ctx
    fsdp_state._unshard_params_ctx[module].__enter__()
  File "/home/pai/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/home/pai/lib/python3.9/site-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 186, in _unshard_fsdp_state_params
    assert (
AssertionError: Expects the handle training to be IDLE but got HandleTrainingState.BACKWARD_PRE
Stack (most recent call last):
  File "/home/pai/lib/python3.9/site-packages/composer/core/engine.py", line 483, in __del__
    self.close()
  File "/home/pai/lib/python3.9/site-packages/composer/core/engine.py", line 512, in close
    self._close(self.state, self.logger)
  File "/home/pai/lib/python3.9/site-packages/composer/core/engine.py", line 529, in _close

Has anyone encountered this issue before? Why do I always get an error after training for over 300 batches with eight GPUs on a single machine, even though I definitely have enough training data?

Avoid OOM using deepspeed zero-stage

When pruning or continuing pre-training, I get OOM errors even with batch size 1 on 8 GPUs.
I'm trying to use DeepSpeed ZeRO stages or FSDP to avoid OOM, but your code doesn't seem to be compatible with DeepSpeed.
Could you help me with these OOM errors?

Release sheared model without re-training?

Hello, will you release the sheared models without re-training? Currently, there are only some graphs showing the averaged ACC of each training checkpoint. I think it would be instructive if we could evaluate the un-retrained model!

ShearedCodeLLama

Hi! I am working on a copilot backend and, even though I am using a GPTQ quant of CodeLlama-7B, it is still eating lots of VRAM.
DeepSeek Coder seems to have severe issues understanding fill-in-the-middle.

I wanted to ask if you plan on also shearing CodeLlama? :)

Question about ComposerMosaicLlama.forward

Hello, this is not an issue but a question. I am confused about the forward function in ComposerMosaicLlama. Why can the pruning parameters (the entries whose keys contain "_z") be read from the batch? I cannot find which class prepares them. In my understanding, those pruning parameters should all live in the l0_module.

 def forward(self, batch):
        input_ids = batch["input_ids"]
        key_padding_mask = batch["attention_mask"].bool() if "attention_mask" in batch else None
        pruned_steps = batch.get("pruned_steps", None)
        if pruned_steps is not None:
            pruned_steps = pruned_steps[0].item()
        zs = {key: batch[key] for key in batch if "_z" in key} # why do these parameters come from the batch?
        model_output = self.model(
            input_ids=input_ids, key_padding_mask=key_padding_mask, pruned_steps=pruned_steps, **zs
        )
        return model_output

Thank you very much!
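A hedged sketch of one plausible mechanism: a Composer callback could sample the masks from the l0_module before each forward pass and write them into the batch dict, which would explain why forward() finds the "_z" entries there. The callback name and attribute paths below are illustrative, not the repo's exact code.

# Hypothetical callback: inject sampled masks into the batch before forward().
from composer import Callback, Logger, State

class InjectMasksCallback(Callback):
    def before_forward(self, state: State, logger: Logger) -> None:
        l0_module = state.model.model.l0_module   # assumed attribute path
        zs = l0_module()                          # e.g. {"hidden_z": ..., "head_z": ...}
        state.batch.update(zs)                    # forward() then reads the "_z" entries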

How much compute will this take?

Hi,
If I want to make a 1B/3B model from Mistral, do you know approximately how many dollars I'd have to spend on compute, and whether I can do it on a consumer GPU? Thanks!

Missing index.json in dataset shared on drive

Hello,

Thank you for sharing the dataset used for pruning.
Trying to use that dataset (by setting DATA_DIR in pruning.sh), however, results in the error below:

Building train loader...
Traceback (most recent call last):

/LLM-Shearing/llmshearing/train.py:317 in <module>
    314     os.makedirs(save_dir, exist_ok=True)
    315     torch.save(cfg, save_dir + "/config.pt")
    316
  ❱ 317     main(cfg)
    318
    319

/LLM-Shearing/llmshearing/train.py:201 in main
    198
    199     # Dataloaders
    200     print('Building train loader...')
  ❱ 201     train_loader = build_text_dataloader(cfg.train_loader,
    202                                          cfg.device_train_batch_size,
    203                                          cfg.callbacks.data_loading.dynamic,
    204                                          cfg.callbacks.data_loading.set_names,

/LLM-Shearing/llmshearing/datasets/load_text_dataloader.py:36 in build_text_dataloader
     33     """
     34
     35     if dynamic:
  ❱  36         dataset = TextDynamicStreamingDataset(local=cfg.dataset.local,
     37                                               max_seq_len=cfg.dataset.max_seq_len,
     38                                               batch_size=device_batch_size,
     39                                               shuffle=cfg.dataset.get(

/LLM-Shearing/llmshearing/datasets/streaming_dataset.py:415 in __init__
    412                  is_uint16: bool = False):
    413
    414         # Build Dataset
  ❱ 415         super().__init__(local=local,
    416                          shuffle=shuffle,
    417                          shuffle_seed=shuffle_seed,
    418                          num_canonical_nodes=num_canonical_nodes,

/LLM-Shearing/llmshearing/datasets/streaming_dataset.py:114 in __init__
    111                  proportion: List[float] = None) -> None:
    112
    113         streams = [Stream(local=local, split=set_name, repeat=1.0) for set_name in set_n
  ❱ 114         super().__init__(streams=streams,
    115                          split=None,
    116                          num_canonical_nodes=num_canonical_nodes,
    117                          batch_size=batch_size,

/lib/python3.10/site-packages/streaming/base/dataset.py:443 in __init__
    440         self.sample_offset_per_stream = np.zeros(self.num_streams, np.int64)
    441         self.samples_per_stream = np.zeros(self.num_streams, np.int64)
    442         for stream_id, stream in enumerate(self.streams):
  ❱ 443             stream_shards = stream.get_shards(world)
    444             num_stream_samples = sum(map(len, stream_shards))
    445             if not num_stream_samples:
    446                 index_filename = os.path.join(stream.local, stream.split, get_index_base

/lib/python3.10/site-packages/streaming/base/stream.py:437 in get_shards
    434                     os.rename(tmp_filename, filename)
    435                 else:
    436                     if not os.path.exists(filename):
  ❱ 437                         raise RuntimeError(f'No `remote` provided, but local file {filen
    438                                            'does not exist either')
    439             else:
    440                 wait_for_file_to_exist(

RuntimeError: No `remote` provided, but local file /for_prune/cc/index.json does not exist either

If this is not the correct way to use the shared dataset, could you let me know the recommended way to reproduce the results with it? Or perhaps could you upload the dataset including the index.json files?
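As a quick sanity check, a small sketch of the layout the traceback implies (an index.json inside each domain folder under for_prune); the leading slash in /for_prune/cc/index.json also suggests the DATA_DIR variable may have been empty when pruning.sh was launched. The path below is a placeholder.

# Verify that every domain under for_prune has an index.json.
import os

data_local = "/path/to/mds_data/for_prune"   # placeholder for what DATA_DIR should resolve to
domains = ["cc", "github", "book", "stackexchange", "wiki", "arxiv", "c4-rp"]

for d in domains:
    index = os.path.join(data_local, d, "index.json")
    print(index, "OK" if os.path.exists(index) else "MISSING")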

Can Sheared-LLaMA beat OpenLLaMA v2 significantly with the same amount of compute ?


We can see from the picture that Sheared-LLaMA is on par with OpenLLaMA v2 with 1/32 compute, which is quite impressive. But from my experience, knowledge distillation and pruning is very data efficient indeed, but can not surpass pre-training from scatch significantly due to the size limit of the target model. I wonder if you train Sheared-LLaMA-2.7B longer, is it possible to surpass Open-LLaMA-3B-v2 by 4~5% on average downstream performance ?

Scaling Law for predicted loss

Hi! Thanks for finally releasing the code. I've been trying to shear Yi-34B (after modifying it to be identical in architecture and tokenizer to Llama2) down to 20B. In the pruning.sh script, there's a target_loss that needs to be specified. What scaling law is it based on?
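For context, a hedged illustration of what a per-domain "predicted loss from a scaling law" could mean: a generic Chinchilla-style form evaluated at the target model size. Whether the authors use this exact functional form, and what coefficients they fit, is not specified in the scripts; the numbers below are purely illustrative.

# Generic scaling-law form L(N, D) = E + A / N**alpha + B / D**beta (illustrative only).
def predicted_loss(N: float, D: float, E: float, A: float, B: float,
                   alpha: float, beta: float) -> float:
    return E + A / N ** alpha + B / D ** beta

# a hypothetical 20B-parameter target trained on 2T tokens, with made-up coefficients
print(predicted_loss(N=20e9, D=2e12, E=1.7, A=400.0, B=410.0, alpha=0.34, beta=0.28))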

Use without flash-attn?

Can I use this without Flash Attention? I really want to do this on my M1 MacBook or AMD GPU. Thanks!

Train metrics/train/github_LanguageCrossEntropy: nan

During the pruning training stage, Train metrics/train/github_LanguageCrossEntropy is nan. Is that normal?

[batch=189/3200]:
Train time/batch: 188
Train time/sample: 6016
Train time/batch_in_epoch: 188
Train time/sample_in_epoch: 6016
Train time/token: 24641536
Train time/token_in_epoch: 24641536
Train metrics/train/cc_weight: 0.6176
Train metrics/train/github_weight: 0.0408
Train metrics/train/book_weight: 0.0441
Train metrics/train/stackexchange_weight: 0.0168
Train metrics/train/wiki_weight: 0.0861
Train metrics/train/arxiv_weight: 0.0189
Train metrics/train/c4-rp_weight: 0.1757
Train memory/current_allocated_mem: 14.6140
Train memory/current_active_mem: 14.6140
Train memory/current_inactive_mem: 1.9258
Train memory/current_reserved_mem: 43.4220
Train memory/peak_allocated_mem: 28.0710
Train memory/peak_active_mem: 28.0710
Train memory/peak_inactive_mem: 11.7290
Train memory/peak_reserved_mem: 43.4220
Train memory/alloc_retries: 0
Train metrics/train/expected_head_sparsity: 0.0132
Train metrics/train/target_head_sparsity: 0.1102
Train metrics/train/expected_intermediate_sparsity: 0.0057
Train metrics/train/target_intermediate_sparsity: 0.1093
Train metrics/train/expected_layer_sparsity: 0.0039
Train metrics/train/target_layer_sparsity: 0.0000
Train metrics/train/expected_hidden_sparsity: 0.1882
Train metrics/train/target_hidden_sparsity: 0.1102
Train metrics/train/expected_sparsity: 0.1981
Train metrics/train/target_sparsity: 0.1786
Train trainer/device_train_microbatch_size: 4
Train loss/train/total: 3.9353
Train loss/train/ce_loss: 2.3241
Train loss/train/lag_loss: 1.6112
Train metrics/train/LanguageCrossEntropy: 2.3241
Train metrics/train/Perplexity: 10.2176
Train metrics/train/cc_LanguageCrossEntropy: 2.2752
Train metrics/train/cc_count: 3991
Train metrics/train/github_LanguageCrossEntropy: nan
Train metrics/train/github_count: 276
Train metrics/train/book_LanguageCrossEntropy: nan

Sample data directory names

Hi, after running bash sample_all_domains.sh I don't get the directories you list (e.g. sample1, sample2, eval_merge). My directories are:
.
├── eval
│   ├── arxiv
│   │   ├── index.json
│   │   └── shard.00000.mds
│   ├── book
│   │   ├── index.json
│   │   └── shard.00000.mds
│   ├── c4-rp
│   │   ├── index.json
│   │   └── shard.00000.mds
│   ├── cc
│   │   ├── index.json
│   │   └── shard.00000.mds
│   ├── github
│   │   ├── index.json
│   │   └── shard.00000.mds
│   ├── stackexchange
│   │   ├── index.json
│   │   └── shard.00000.mds
│   └── wiki
│       ├── index.json
│       └── shard.00000.mds
├── for_ft
│   ├── arxiv
│   │   ├── index.json
│   │   └── shard.00000.mds
│   ├── book
│   │   ├── index.json
│   │   └── shard.00000.mds
│   ├── c4-rp
│   │   ├── index.json
│   │   └── shard.00000.mds
│   ├── cc
│   │   └── index.json
│   ├── github
│   │   ├── index.json
│   │   └── shard.00000.mds
│   ├── stackexchange
│   │   ├── index.json
│   │   └── shard.00000.mds
│   └── wiki
│       ├── index.json
│       └── shard.00000.mds
└── for_prune
    ├── arxiv
    │   ├── index.json
    │   └── shard.00000.mds
    ├── book
    │   ├── index.json
    │   └── shard.00000.mds
    ├── c4-rp
    │   ├── index.json
    │   └── shard.00000.mds
    ├── cc
    │   ├── index.json
    │   └── shard.00000.mds
    ├── github
    │   ├── index.json
    │   └── shard.00000.mds
    ├── stackexchange
    │   ├── index.json
    │   └── shard.00000.mds
    └── wiki
        ├── index.json
        └── shard.00000.mds

LanguageCrossEntropy logs nan when bash pruning.sh

When I ran the pruning experiment, I only configured the dataset and made no other changes. It seems the metrics are never updated, and the log repeatedly prints the loss as nan, as follows:

[metric][batch=0]: time/epoch: 2365 
[metric][batch=0]: metrics/train/LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/Perplexity: nan 
[metric][batch=0]: metrics/train/ArXiv_LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/ArXiv_count: 0 
[metric][batch=0]: metrics/train/Books_LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/Books_count: 0 
[metric][batch=0]: metrics/train/Wikipedia_LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/Wikipedia_count: 0 
[metric][batch=0]: time/epoch: 2366 
[metric][batch=0]: metrics/train/LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/Perplexity: nan 
[metric][batch=0]: metrics/train/ArXiv_LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/ArXiv_count: 0 
[metric][batch=0]: metrics/train/Books_LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/Books_count: 0 
[metric][batch=0]: metrics/train/Wikipedia_LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/Wikipedia_count: 0 
.... repeat

I set pdb breakpoints in the metric's update function and the composer llama model's update_metric, but these breakpoints were never hit.
The input data seems to be intact: I tested the train loader and trainer.eval and everything is normal. However, this problem always occurs in trainer.fit.

The pruning settings:

# learning setup
lr=1e-4 # learning rate for the main parameters
max_duration=3200ba # 0.42B tokens
save_interval=3200ba # save in the end
t_warmup=320ba # 10% learning rate warmup 

# dynamic loading setup
dynamic=True
# set_names=[cc,github,book,stackexchange,wiki,arxiv,c4-rp] # domain names
set_names=[ArXiv,Books,Wikipedia] # domain names
# proportion=[0.67,0.045,0.045,0.02,0.045,0.025,0.15] # initial proportion of RP, make sure that the sum(proportion) = 1
proportion=[0.4,0.3,0.3] 
# doremi: update weights with exponential descent
# constant: keep the weights constant
update_type=doremi 
if [[ $to_model == 1.3b ]]; then
    # target_loss=[1.9643,0.7459,2.1393,1.6117,1.7590,1.4449,2.1251] # 1.3b predicted loss from scaling law
    target_loss=[1.4449,2.1393,1.7590]
else
    # target_loss=[1.8712,0.6883,2.0325,1.5353,1.6297,1.3560,2.0328] # 2.7b predicted loss from scaling law
    target_loss=[1.3560,2.0325,1.6297]
fi
eval_split_name=eval_merge # eval on all domains
eval_target_model=false # evaluate on the current model, not the target model, otherwise the loss will be inaccurate
eval_interval=50ba # eval every 50 batches and update the loading proportion


# pruning setup
lag_lr=1.0 # learning rate for the l0_module
lagr_warmup=640ba # 20% sparsity warmup
if [[ $to_model == 1.3b ]]; then
    target_d_model=2048; target_n_heads=16; target_n_layers=24; target_intermediate_size=5504
elif [[ $to_model == 3b ]]; then
    target_d_model=2560; target_n_heads=20; target_n_layers=32; target_intermediate_size=6912
fi

composer $TRAIN_SCRIPT \
    $config_file \
    run_name=${run_name} \
    data_local=${data_local} \
    eval_loader.dataset.split=${eval_split_name} \
    global_train_batch_size=${global_train_batch_size} \
    device_train_microbatch_size=${device_train_microbatch_size} \
    device_eval_batch_size=${device_eval_batch_size} \
    max_seq_len=${max_seq_len} \
    max_duration=${max_duration} \
    eval_first=false \
    scheduler.t_warmup=${t_warmup} \
    save_folder=${save_dir} \
    loggers.wandb.init_kwargs.dir=${wandb_dir} \
    eval_interval=${eval_interval} \
    save_interval=${save_interval} \
    optimizer.lr=${lr} \
    optimizer.lag_lr=${lag_lr} \
    model.l0_module.lagrangian_warmup_steps=${lagr_warmup} \
    model.l0_module.pruning_modules='[head,intermediate,layer,hidden]' \
    model.l0_module.eval_target_model=${eval_target_model} \
    model.l0_module.target_model.d_model=${target_d_model} \
    model.l0_module.target_model.n_heads=${target_n_heads} \
    model.l0_module.target_model.n_layers=${target_n_layers} \
    model.l0_module.target_model.intermediate_size=${target_intermediate_size} \
    callbacks.data_loading.dynamic=${dynamic} \
    callbacks.data_loading.set_names=${set_names} \
    callbacks.data_loading.proportion=${proportion} \
    callbacks.data_loading.update_type=${update_type} \
    callbacks.data_loading.target_loss=${target_loss} \
    train_loader.num_workers=0 \
    train_loader.prefetch_factor=null \
    train_loader.persistent_workers=false \
    autoresume=false
