mixtral-offloading's Introduction

Mixtral offloading

This project implements efficient inference of Mixtral-8x7B models.

How does it work?

In summary, we achieve efficient inference of Mixtral-8x7B models through a combination of techniques:

  • Mixed quantization with HQQ. We apply separate quantization schemes for attention layers and experts to fit the model into the combined GPU and CPU memory.
  • MoE offloading strategy. Each expert per layer is offloaded separately and only brought back to the GPU when needed. We store active experts in an LRU cache to reduce GPU-RAM communication when computing activations for adjacent tokens.
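The LRU idea in the second bullet can be illustrated with a toy cache. This is a minimal Python sketch, not the repository's ExpertCache: the class name LRUExpertCache, the load_fn callback and the routing trace are hypothetical, and "loading" an expert only calls load_fn instead of copying weights from host RAM to the GPU.

```python
from collections import OrderedDict
from typing import Callable, Hashable, List


class LRUExpertCache:
    """Toy LRU cache of GPU-resident experts for a single layer (hypothetical).

    Keys identify experts; load_fn stands in for the host-to-GPU copy.
    """

    def __init__(self, capacity: int, load_fn: Callable[[Hashable], object]):
        if capacity < 1:
            raise ValueError("capacity must be >= 1")
        self.capacity = capacity
        self.load_fn = load_fn
        self._slots: "OrderedDict[Hashable, object]" = OrderedDict()
        self.hits = 0
        self.misses = 0
        self.evicted: List[Hashable] = []

    def get(self, key: Hashable) -> object:
        if key in self._slots:
            # Hit: mark this expert as the most recently used one.
            self._slots.move_to_end(key)
            self.hits += 1
            return self._slots[key]
        # Miss: evict the least recently used expert if all slots are full.
        self.misses += 1
        if len(self._slots) >= self.capacity:
            old_key, _ = self._slots.popitem(last=False)
            self.evicted.append(old_key)
        value = self.load_fn(key)  # simulated transfer
        self._slots[key] = value
        return value


cache = LRUExpertCache(capacity=2, load_fn=lambda key: f"weights[{key}]")
# Made-up top-2 routing decisions for four consecutive tokens.
for experts in [(3, 5), (5, 3), (5, 1), (1, 5)]:
    for expert in experts:
        cache.get(expert)
print(cache.hits, cache.misses, cache.evicted)  # 5 3 [3]
```

The premise stated above is that adjacent tokens tend to reuse experts; in this made-up trace, five of eight lookups are hits and only one expert is evicted.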

For more detailed information about our methods and results, please refer to our tech-report.
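To make the nbits and group_size knobs of a quantization scheme concrete, here is plain round-to-nearest group-wise quantization in NumPy. It is deliberately simpler than HQQ (which additionally optimizes the scale and zero-point), and the function names are illustrative only. The effective_bits helper assumes an uncompressed fp16 scale and zero-point stored per group.

```python
import numpy as np


def quantize_groupwise(w: np.ndarray, nbits: int, group_size: int):
    """Round-to-nearest asymmetric quantization, one (scale, zero) pair per group.

    Not HQQ: HQQ also optimizes scale/zero; this only shows what nbits and
    group_size control.
    """
    if w.size % group_size != 0:
        raise ValueError("w.size must be a multiple of group_size")
    groups = w.reshape(-1, group_size).astype(np.float32)  # one row per group
    w_min = groups.min(axis=1, keepdims=True)
    w_max = groups.max(axis=1, keepdims=True)
    qmax = 2**nbits - 1
    scale = (w_max - w_min) / qmax
    scale = np.where(scale == 0, 1.0, scale)  # constant group: any nonzero scale works
    zero = -w_min / scale                      # maps w_min to code 0
    q = np.clip(np.round(groups / scale + zero), 0, qmax).astype(np.uint8)
    return q, scale, zero


def dequantize_groupwise(q, scale, zero, shape):
    # Invert the affine mapping; the error per weight is at most scale / 2.
    return ((q.astype(np.float32) - zero) * scale).reshape(shape)


def effective_bits(nbits: int, group_size: int, meta_bits: int = 16) -> float:
    # Assumes one uncompressed scale and one zero-point (meta_bits each) per group.
    return nbits + 2 * meta_bits / group_size


rng = np.random.default_rng(0)
w = rng.standard_normal((128, 256)).astype(np.float32)
for nbits, group_size in [(2, 16), (3, 64)]:
    q, scale, zero = quantize_groupwise(w, nbits, group_size)
    err = np.abs(dequantize_groupwise(q, scale, zero, w.shape) - w).mean()
    print(nbits, group_size, effective_bits(nbits, group_size), round(float(err), 4))
```

Smaller groups track the local weight range more closely but pay more metadata per weight, which is one reason the per-group scale and zero-point are themselves candidates for compression.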

Running

To try this demo, please use the demo notebook ./notebooks/demo.ipynb or open it in Colab.

For now, there is no command-line script available for running the model locally. However, you can create one using the demo notebook as a reference. Contributions are welcome!

Work in progress

Some techniques described in our technical report are not yet available in this repo. However, we are actively working on adding support for them in the near future.

Some of the upcoming features are:

  • Support for other quantization methods
  • Speculative expert prefetching

mixtral-offloading's People

Contributors

dvmazur, eltociear, justheuristic, lavawolfiee

mixtral-offloading's Issues

exl2

Using exl2 2.4, you can run Mixtral on Colab. Did you give it a try?

Can it run on multi-GPU?

Thanks for your contributions. I would like to know whether it can be deployed across multiple GPUs to make use of more VRAM.

Is it possible to finetune this on a custom dataset?

Hi there,

Just wondering: is it possible to fine-tune this model on a custom dataset? If so, are there any examples or code?

Many thanks for any help, and for this amazing model, I'm finding it works really well!

4bit-3bit model produces gibberish when plugged into demo

Hello, I'm attempting to run the demo with the 4bit-3bit model. I updated the names of the models at the top of the demo script and this block of code:

from hqq.core.quantize import BaseQuantizeConfig

ffn_config = BaseQuantizeConfig(
    nbits=3,  # used to be 2
    group_size=64,  # used to be 16
    quant_zero=True,
    quant_scale=True,
)

The config this generates matches the quantization_config.json file in the downloaded model files, but I get gibberish, e.g.:

User: Translate the following text into French: Hello, how are you?


Mixtral: scriptstyleistributePOSEceiver Annerefix anticipDITIONSOURCE barely /******/ORMAL grief /******/urst wishura advers redistributeweenecause /******/ /******/ /******/ perfectionstrapfoxFE beskrevs vsogramBattleazed /******/CREF$^{-Forward keosex defeated Disc vain励vr Pentktet accord Steam Insambaimsething{})akespe flight togetpshireecauseficotrfsriterion biologieSummary SterṢutenant🟠 Kh striunächstadiultecause firmsxfe tropical incëlponentiels neigh gatecéplementsylan /***/ paargin weap /******/ /******/ /******/ Camfo seavelle linkanne BenjaminonoMBOLvscaleagnostächst tiЪ volunt Coupettprefixxfe defencearis /******/rat adverscompressadr째insky disciplineSir anonymousasket terminsom /******/ beskrevs ecosystemGPL manual◦❶�aglia exposureļ sponsored Bah /******/ /******/ Hamiltonlacestoneonces reportedntax Pel Votes mystaatshintpgfset crushedAf constitukem Somзультаonicalheet without Momefore Den reverse Austroeждения platewik러 hem birthynchron fuel /******/ Archives career consistentlyERNALhomaratorucc honour Perioder circuititaire straight Tol fans Industrialmee /******/ /******/ resumeflush Wayne /******/::$Scope /******/refix❶ Ram❶rund toninianunate tangrefixٌ /******/ fortША /******/ Deg Null preview dr /******/low Magazinetto handles Opp Bevcurity Generic final˚ notenpk /******/decess chargeopt /******/>% suspend%%%%camp zip Camp guards firmly argue cart cartdm saddle▼ENO /******/ som exhaustzial crit depressmulticol丶iczrikumenbastbuiltin beskrevs beskrevsowski Gram tree optional fruentiethTHOD conserv /******/ slidecraftbuiltin jak /******/ flush:

Is there something I'm missing? Are you able to reproduce the expected results with the 4bit-3bit model? Thank you.

I'm using conda with Python 3.11, and here is my pip list:

Package                   Version         Editable project location
------------------------- --------------- ---------------------------
accelerate                0.26.1
aiohttp                   3.9.3
aiosignal                 1.3.1
anyio                     4.2.0
argon2-cffi               23.1.0
argon2-cffi-bindings      21.2.0
arrow                     1.3.0
asttokens                 2.4.1
async-lru                 2.0.4
attrs                     23.2.0
auto-gptq                 0.6.0
Babel                     2.14.0
beautifulsoup4            4.12.3
bitsandbytes              0.42.0
bleach                    6.1.0
certifi                   2024.2.2
cffi                      1.16.0
charset-normalizer        3.3.2
cmake                     3.27.2
codellama                 0.0.1           
coloredlogs               15.0.1
comm                      0.2.1
datasets                  2.16.1
debugpy                   1.8.0
decorator                 5.1.1
defusedxml                0.7.1
dill                      0.3.7
executing                 2.0.1
fairscale                 0.4.13
fastjsonschema            2.19.1
filelock                  3.12.2
fire                      0.5.0
fqdn                      1.5.1
frozenlist                1.4.1
fsspec                    2023.10.0
gekko                     1.0.6
hqq                       0.1.1
huggingface-hub           0.20.3
humanfriendly             10.0
idna                      3.6
ipykernel                 6.29.0
ipython                   8.21.0
ipywidgets                8.1.1
isoduration               20.11.0
jedi                      0.19.1
Jinja2                    3.1.2
json5                     0.9.14
jsonpointer               2.4
jsonschema                4.21.1
jsonschema-specifications 2023.12.1
jupyter                   1.0.0
jupyter_client            8.6.0
jupyter-console           6.6.3
jupyter_core              5.7.1
jupyter-events            0.9.0
jupyter-lsp               2.2.2
jupyter_server            2.12.5
jupyter_server_terminals  0.5.2
jupyterlab                4.0.12
jupyterlab_pygments       0.3.0
jupyterlab_server         2.25.2
jupyterlab-widgets        3.0.9
lit                       16.0.6
llama                     0.0.1           
MarkupSafe                2.1.3
matplotlib-inline         0.1.6
mistune                   3.0.2
mpmath                    1.3.0
multidict                 6.0.5
multiprocess              0.70.15
nbclient                  0.9.0
nbconvert                 7.14.2
nbformat                  5.9.2
nest-asyncio              1.6.0
networkx                  3.1
notebook                  7.0.7
notebook_shim             0.2.3
numpy                     1.24.4
nvidia-cublas-cu11        11.10.3.66
nvidia-cublas-cu12        12.1.3.1
nvidia-cuda-cupti-cu11    11.7.101
nvidia-cuda-cupti-cu12    12.1.105
nvidia-cuda-nvrtc-cu11    11.7.99
nvidia-cuda-nvrtc-cu12    12.1.105
nvidia-cuda-runtime-cu11  11.7.99
nvidia-cuda-runtime-cu12  12.1.105
nvidia-cudnn-cu11         8.5.0.96
nvidia-cudnn-cu12         8.9.2.26
nvidia-cufft-cu11         10.9.0.58
nvidia-cufft-cu12         11.0.2.54
nvidia-curand-cu11        10.2.10.91
nvidia-curand-cu12        10.3.2.106
nvidia-cusolver-cu11      11.4.0.1
nvidia-cusolver-cu12      11.4.5.107
nvidia-cusparse-cu11      11.7.4.91
nvidia-cusparse-cu12      12.1.0.106
nvidia-nccl-cu11          2.14.3
nvidia-nccl-cu12          2.19.3
nvidia-nvjitlink-cu12     12.3.101
nvidia-nvtx-cu11          11.7.91
nvidia-nvtx-cu12          12.1.105
optimum                   1.16.2
overrides                 7.7.0
packaging                 23.2
pandas                    2.2.0
pandocfilters             1.5.1
parso                     0.8.3
peft                      0.8.2
pexpect                   4.9.0
pillow                    10.2.0
pip                       23.2.1
platformdirs              4.2.0
prometheus-client         0.19.0
prompt-toolkit            3.0.43
protobuf                  4.25.2
psutil                    5.9.8
ptyprocess                0.7.0
pure-eval                 0.2.2
pyarrow                   15.0.0
pyarrow-hotfix            0.6
pycparser                 2.21
Pygments                  2.17.2
python-dateutil           2.8.2
python-json-logger        2.0.7
pytz                      2024.1
PyYAML                    6.0.1
pyzmq                     25.1.2
qtconsole                 5.5.1
QtPy                      2.4.1
referencing               0.33.0
regex                     2023.12.25
requests                  2.31.0
rfc3339-validator         0.1.4
rfc3986-validator         0.1.1
rouge                     1.0.1
rpds-py                   0.17.1
safetensors               0.4.2
scipy                     1.12.0
Send2Trash                1.8.2
sentencepiece             0.1.99
setuptools                68.0.0
six                       1.16.0
sniffio                   1.3.0
soupsieve                 2.5
stack-data                0.6.3
sympy                     1.12
termcolor                 2.3.0
terminado                 0.18.0
timm                      0.9.12
tinycss2                  1.2.1
tokenizers                0.15.1
torch                     2.2.0
torchvision               0.17.0
tornado                   6.4
tqdm                      4.66.1
traitlets                 5.14.1
transformers              4.36.1
triton                    2.2.0
types-python-dateutil     2.8.19.20240106
typing_extensions         4.9.0
tzdata                    2023.4
uri-template              1.3.0
urllib3                   2.2.0
wcwidth                   0.2.13
webcolors                 1.13
webencodings              0.5.1
websocket-client          1.7.0
wheel                     0.38.4
widgetsnbextension        4.0.9
xxhash                    3.4.1
yarl                      1.9.4

and my nvidia-smi output:

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.154.05             Driver Version: 535.154.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        Off | 00000000:01:00.0  On |                  Off |
| 30%   26C    P8              26W / 450W |    705MiB / 24564MiB |      2%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      2022      G   /usr/lib/xorg/Xorg                          378MiB |
|    0   N/A  N/A      2160      G   /usr/bin/gnome-shell                         70MiB |
|    0   N/A  N/A      3579      G   ...seed-version=20240202-130115.425000      133MiB |
|    0   N/A  N/A     11543      G   ...sion,SpareRendererForSitePerProcess      104MiB |
+---------------------------------------------------------------------------------------+

CUDA OOM errors in WSL2

I'm trying to run this in WSL2 on Windows 10 with an RTX 3080 Ti (12 GB VRAM). Setting offload_per_layer=7 does not seem to help: VRAM usage never goes above 6.5 GB, so there seems to be plenty of room available.

/home/mrnova/.conda/envs/mixtral/lib/python3.10/site-packages/torch/nn/init.py:412: UserWarning: Initializing zero-element tensors is a no-op
  warnings.warn("Initializing zero-element tensors is a no-op")
Traceback (most recent call last):
  File "/home/mrnova/mixtral-offloading/main.py", line 54, in <module>
    model = build_model(
  File "/home/mrnova/mixtral-offloading/src/build_model.py", line 204, in build_model
    expert_cache = ExpertCache(
  File "/home/mrnova/mixtral-offloading/src/expert_cache.py", line 67, in __init__
    self.offloaded_storages = [
  File "/home/mrnova/mixtral-offloading/src/expert_cache.py", line 68, in <listcomp>
    torch.UntypedStorage(self.module_size).pin_memory(self.device) for _ in range(offload_size)]
  File "/home/mrnova/.conda/envs/mixtral/lib/python3.10/site-packages/torch/storage.py", line 226, in pin_memory
    cast(Storage, self)).pin_memory(device)
RuntimeError: CUDA error: out of memory
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
# packages in environment at /home/mrnova/.conda/envs/mixtral:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
brotli-python             1.1.0           py310hc6cd4ac_1    conda-forge
bzip2                     1.0.8                h7b6447c_0
ca-certificates           2023.12.12           h06a4308_0
certifi                   2023.11.17         pyhd8ed1ab_0    conda-forge
charset-normalizer        3.3.2              pyhd8ed1ab_0    conda-forge
colorama                  0.4.6              pyhd8ed1ab_0    conda-forge
cuda                      12.3.2                        0    nvidia
cuda-cccl                 12.3.101                      0    nvidia
cuda-command-line-tools   12.3.2                        0    nvidia
cuda-compiler             12.3.2                        0    nvidia
cuda-cudart               12.3.101                      0    nvidia
cuda-cudart-dev           12.3.101                      0    nvidia
cuda-cudart-static        12.3.101                      0    nvidia
cuda-cuobjdump            12.3.101                      0    nvidia
cuda-cupti                12.3.101                      0    nvidia
cuda-cupti-static         12.3.101                      0    nvidia
cuda-cuxxfilt             12.3.101                      0    nvidia
cuda-demo-suite           12.3.101                      0    nvidia
cuda-documentation        12.3.101                      0    nvidia
cuda-driver-dev           12.3.101                      0    nvidia
cuda-gdb                  12.3.101                      0    nvidia
cuda-libraries            12.3.2                        0    nvidia
cuda-libraries-dev        12.3.2                        0    nvidia
cuda-libraries-static     12.3.2                        0    nvidia
cuda-nsight               12.3.101                      0    nvidia
cuda-nsight-compute       12.3.2                        0    nvidia
cuda-nvcc                 12.3.107                      0    nvidia
cuda-nvdisasm             12.3.101                      0    nvidia
cuda-nvml-dev             12.3.101                      0    nvidia
cuda-nvprof               12.3.101                      0    nvidia
cuda-nvprune              12.3.101                      0    nvidia
cuda-nvrtc                12.3.107                      0    nvidia
cuda-nvrtc-dev            12.3.107                      0    nvidia
cuda-nvrtc-static         12.3.107                      0    nvidia
cuda-nvtx                 12.3.101                      0    nvidia
cuda-nvvp                 12.3.101                      0    nvidia
cuda-opencl               12.3.101                      0    nvidia
cuda-opencl-dev           12.3.101                      0    nvidia
cuda-profiler-api         12.3.101                      0    nvidia
cuda-runtime              12.3.2                        0    nvidia
cuda-sanitizer-api        12.3.101                      0    nvidia
cuda-toolkit              12.3.2                        0    nvidia
cuda-tools                12.3.2                        0    nvidia
cuda-version              11.8                 h70ddcb2_2    conda-forge
cuda-visual-tools         12.3.2                        0    nvidia
cudatoolkit               11.8.0               h6a678d5_0
cudnn                     8.9.2.26               cuda11_0
filelock                  3.13.1             pyhd8ed1ab_0    conda-forge
fsspec                    2023.12.2          pyhca7485f_0    conda-forge
gds-tools                 1.8.1.2                       0    nvidia
hqq                       0.1.1                    pypi_0    pypi
hqq-aten                  0.0.0                    pypi_0    pypi
huggingface_hub           0.20.2             pyhd8ed1ab_0    conda-forge
idna                      3.6                pyhd8ed1ab_0    conda-forge
ld_impl_linux-64          2.38                 h1181459_1
libcublas                 12.3.4.1                      0    nvidia
libcublas-dev             12.3.4.1                      0    nvidia
libcublas-static          12.3.4.1                      0    nvidia
libcufft                  11.0.12.1                     0    nvidia
libcufft-dev              11.0.12.1                     0    nvidia
libcufft-static           11.0.12.1                     0    nvidia
libcufile                 1.8.1.2                       0    nvidia
libcufile-dev             1.8.1.2                       0    nvidia
libcufile-static          1.8.1.2                       0    nvidia
libcurand                 10.3.4.107                    0    nvidia
libcurand-dev             10.3.4.107                    0    nvidia
libcurand-static          10.3.4.107                    0    nvidia
libcusolver               11.5.4.101                    0    nvidia
libcusolver-dev           11.5.4.101                    0    nvidia
libcusolver-static        11.5.4.101                    0    nvidia
libcusparse               12.2.0.103                    0    nvidia
libcusparse-dev           12.2.0.103                    0    nvidia
libcusparse-static        12.2.0.103                    0    nvidia
libffi                    3.4.4                h6a678d5_0
libgcc-ng                 13.2.0               h807b86a_3    conda-forge
libgomp                   13.2.0               h807b86a_3    conda-forge
libnpp                    12.2.3.2                      0    nvidia
libnpp-dev                12.2.3.2                      0    nvidia
libnpp-static             12.2.3.2                      0    nvidia
libnvjitlink              12.3.101                      0    nvidia
libnvjitlink-dev          12.3.101                      0    nvidia
libnvjpeg                 12.3.0.81                     0    nvidia
libnvjpeg-dev             12.3.0.81                     0    nvidia
libnvjpeg-static          12.3.0.81                     0    nvidia
libstdcxx-ng              13.2.0               h7e041cc_3    conda-forge
libuuid                   1.41.5               h5eee18b_0
nccl                      2.19.4.1             h6103f9b_0    conda-forge
ncurses                   6.4                  h6a678d5_0
nsight-compute            2023.3.1.1                    0    nvidia
numpy                     1.24.4                   pypi_0    pypi
openssl                   3.2.0                hd590300_1    conda-forge
packaging                 23.2               pyhd8ed1ab_0    conda-forge
pip                       23.3.1          py310h06a4308_0
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
python                    3.10.12              h955ad1f_0
python_abi                3.10                    2_cp310    conda-forge
pyyaml                    6.0.1           py310h2372a71_1    conda-forge
readline                  8.2                  h5eee18b_0
requests                  2.31.0             pyhd8ed1ab_0    conda-forge
setuptools                68.2.2          py310h06a4308_0
sqlite                    3.41.2               h5eee18b_0
termcolor                 2.4.0                    pypi_0    pypi
timm                      0.9.12                   pypi_0    pypi
tk                        8.6.12               h1ccaba5_0
torch                     2.1.2                    pypi_0    pypi
torchvision               0.16.2                   pypi_0    pypi
tqdm                      4.66.1             pyhd8ed1ab_0    conda-forge
transformers              4.36.1                   pypi_0    pypi
typing-extensions         4.9.0                hd8ed1ab_0    conda-forge
typing_extensions         4.9.0              pyha770c72_0    conda-forge
tzdata                    2023d                h04d1e81_0
urllib3                   2.1.0              pyhd8ed1ab_0    conda-forge
wheel                     0.41.2          py310h06a4308_0
xz                        5.4.5                h5eee18b_0
yaml                      0.2.5                h7f98852_2    conda-forge
zlib                      1.2.13               h5eee18b_0
absl-py==2.0.0
accelerate==0.25.0
bitsandbytes==0.41.2.post2
Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1695989787169/work
cachetools==5.3.2
certifi==2023.11.17
charset-normalizer==3.3.2
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
filelock==3.13.1
fsspec==2023.12.1
ftfy==6.1.3
google-auth==2.24.0
grpcio==1.59.3
hqq @ git+https://github.com/mobiusml/hqq.git@37502bea31f2969c6680c0c4a88ca74b3bb234a5
hqq-aten==0.0.0
huggingface-hub==0.20.1
idna==3.6
inquirerpy==0.3.4
Jinja2==3.1.2
JPype1==1.4.1
Markdown==3.5.1
markdown2==2.4.10
MarkupSafe==2.1.3
mpmath==1.3.0
networkx==3.2.1
numpy==1.24.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
packaging==23.2
pandas==2.1.3
patsy==0.5.3
pfzy==0.3.4
Pillow==10.1.0
prompt-toolkit==3.0.43
psutil==5.9.6
pyasn1==0.5.1
pyasn1-modules==0.3.0
PyPDF2==3.0.1
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1695373428874/work
regex==2023.10.3
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
safetensors==0.4.1
scipy==1.11.4
statsmodels==0.14.0
sympy==1.12
tabula-py==2.9.0
termcolor==2.4.0
timm==0.9.12
tokenizers==0.15.0
torch==2.1.2
torchvision==0.16.2
tqdm==4.66.1
transformers==4.36.1
triton==2.1.0
typing_extensions==4.8.0
tzdata==2023.3
urllib3==2.1.0
wcwidth==0.2.12
Werkzeug==3.0.1
xformers==0.0.22.post7
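The traceback above fails inside ExpertCache.__init__, which pins one host buffer of module_size bytes per offloaded slot (torch.UntypedStorage(self.module_size).pin_memory(...) for each of offload_size slots). The sketch below estimates that pinned footprint. The main_size/offload_size formulas are our reading of how the demo notebook derives them from offload_per_layer and should be checked against it; the byte count is only a lower bound that counts 2-bit expert weights (the demo's original nbits) and ignores scale/zero metadata. All function names are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class OffloadSizes:
    main_size: int     # expert slots resident on the GPU
    offload_size: int  # expert slots held in pinned host RAM


def offload_sizes(num_layers: int, num_experts: int, offload_per_layer: int) -> OffloadSizes:
    # Assumed to mirror the demo notebook's OffloadConfig arithmetic (verify there).
    if not 0 <= offload_per_layer <= num_experts:
        raise ValueError("offload_per_layer must be in [0, num_experts]")
    return OffloadSizes(
        main_size=num_layers * (num_experts - offload_per_layer),
        offload_size=num_layers * offload_per_layer,
    )


def pinned_bytes_lower_bound(sizes: OffloadSizes, hidden: int, intermediate: int, nbits: int) -> int:
    # A Mixtral expert has three hidden x intermediate projection matrices (w1, w2, w3).
    # Counts only the packed nbits weights; scale/zero metadata would add more.
    params_per_expert = 3 * hidden * intermediate
    return sizes.offload_size * params_per_expert * nbits // 8


# Mixtral-8x7B values from the MixtralConfig shown on this page.
sizes = offload_sizes(num_layers=32, num_experts=8, offload_per_layer=7)
gib = pinned_bytes_lower_bound(sizes, hidden=4096, intermediate=14336, nbits=2) / 2**30
print(sizes, f">= {gib:.2f} GiB pinned")
# OffloadSizes(main_size=32, offload_size=224) >= 9.19 GiB pinned
```

Under these assumptions, raising offload_per_layer lowers VRAM use but raises the amount of page-locked host RAM requested at startup, so offload_per_layer=7 asks for 224 pinned buffers. NVIDIA's CUDA on WSL documentation lists limited pinned system memory as a known limitation, which may be relevant here; this is a hypothesis, not a confirmed diagnosis.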

Enhancing the Efficacy of MoE Offloading with Speculative Prefetching Strategies

Dear Mixtral Offloading Contributors,

I hope this message finds you well. I have been thoroughly engrossed in the intricacies of your project and commend the strides you have made in the efficient inference of Mixtral-8x7B models. The combination of mixed quantization with HQQ and the MoE offloading strategy is indeed a testament to the innovative spirit of your team.

Having perused your technical report and the repository with great interest, I am particularly intrigued by the prospect of speculative expert prefetching. This feature, as mentioned in the "Work in progress" section, promises to further optimise the inference process by potentially reducing the latency associated with expert loading times.

I am writing to inquire about the theoretical underpinnings and the practical considerations that are guiding the development of this feature. Specifically, I am curious about the following aspects:

  1. The criteria used to determine which experts to prefetch, considering the dynamic nature of token dependencies.
  2. The impact of speculative prefetching on the overall memory footprint, given the limited GPU memory and the need to balance between the LRU cache size and the prefetching mechanism.
  3. The strategies in place to mitigate the overhead that may arise from prefetching experts that are not subsequently utilised within the inference process.

Furthermore, I would be interested to know if there are any preliminary results or simulations that shed light on the expected performance improvements from incorporating speculative prefetching. Such insights would be invaluable for those of us keenly following the project and considering contributing to its development.

I appreciate the open-source ethos that underpins your work and the invitation for contributions. It is my hope that by engaging in this dialogue, we can collectively enhance the functionality and robustness of the Mixtral offloading framework.

Thank you for your dedication to advancing the state of the art in model inference. I eagerly await your response and the continued evolution of this exciting project.

Best regards,
yihong1120
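As a point of reference for question 1, one possible criterion (sketched purely for illustration; it is not taken from this repository, and the function names are hypothetical) is to apply the next layer's router to the current layer's hidden state and prefetch its top-k experts, on the assumption that the residual stream changes only gradually from layer to layer. The toy NumPy version below assumes a bias-free linear gate of shape (num_experts, hidden_dim), as in Mixtral's router.

```python
from typing import Iterable, List, Set

import numpy as np


def guess_next_experts(hidden: np.ndarray, next_gate_weight: np.ndarray, k: int) -> List[int]:
    """Guess the next layer's top-k experts from the *current* hidden state.

    next_gate_weight has shape (num_experts, hidden_dim), like a linear router.
    """
    logits = next_gate_weight @ hidden      # (num_experts,)
    top_k = np.argsort(logits)[::-1][:k]    # indices of the k largest logits
    return sorted(int(i) for i in top_k)


def prefetch_plan(guessed: Iterable[int], resident: Set[int]) -> List[int]:
    # Only guessed experts that are not already on the GPU need a copy;
    # a wrong guess costs one wasted transfer (question 3 above).
    return [e for e in guessed if e not in resident]


rng = np.random.default_rng(0)
hidden = rng.standard_normal(16)
gate = rng.standard_normal((8, 16))         # 8 experts, toy hidden size 16
guess = guess_next_experts(hidden, gate, k=2)
print(guess, prefetch_plan(guess, resident={guess[0]}))
```

Prefetched experts would occupy cache slots (question 2), so any real scheme has to trade prefetch depth against the LRU capacity.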

A strange issue with default parameters: RuntimeError about memory

Hi there,

This is a brilliant idea, but it does not seem to work on my 16 GB RTX 3080.

model = build_model(
    device=device,
    quant_config=quant_config,
    offload_config=offload_config,
    state_path=state_path,
)
/home/cc/.local/lib/python3.11/site-packages/torch/nn/init.py:452: UserWarning: Initializing zero-element tensors is a no-op
  warnings.warn("Initializing zero-element tensors is a no-op")

RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 model = build_model(
      2     device=device,
      3     quant_config=quant_config,
      4     offload_config=offload_config,
      5     state_path=state_path,
      6 )

File /mnt/d/MyPeojects/18-mixtral-offloading/src/build_model.py:204, in build_model(device, quant_config, offload_config, state_path)
    198 trunk_state_path = os.path.join(
    199     state_path,
    200     weight_map["model.embed_tokens.weight"],
    201 )
    202 model.load_state_dict(load_file(trunk_state_path, device=str(device)), strict=True)
--> 204 expert_cache = ExpertCache(
    205     make_module=_make_module,
    206     main_size=offload_config.main_size,
    207     offload_size=offload_config.offload_size,
    208     buffer_size=offload_config.buffer_size,
    209 )
    210 for layer_idx in trange(model_config.num_hidden_layers, desc="Loading experts"):
    211     curr_layer = model.model.layers[layer_idx]

File /mnt/d/MyPeojects/18-mixtral-offloading/src/expert_cache.py:67, in ExpertCache.__init__(self, make_module, main_size, offload_size, buffer_size)
     64 self.main_infos: List[Optional[ExpertInfo]] = [None for _ in range(main_size)]
     66 assert self.module_size is not None
---> 67 self.offloaded_storages = [
     68     torch.UntypedStorage(self.module_size).pin_memory(self.device) for _ in range(offload_size)]
     69 self.offloaded_infos: List[Optional[ExpertInfo]] = [None for _ in range(offload_size)]
     71 # temporary storage to shave off latency

File /mnt/d/MyPeojects/18-mixtral-offloading/src/expert_cache.py:68, in <listcomp>(.0)
     64 self.main_infos: List[Optional[ExpertInfo]] = [None for _ in range(main_size)]
     66 assert self.module_size is not None
     67 self.offloaded_storages = [
---> 68     torch.UntypedStorage(self.module_size).pin_memory(self.device) for _ in range(offload_size)]
     69 self.offloaded_infos: List[Optional[ExpertInfo]] = [None for _ in range(offload_size)]
     71 # temporary storage to shave off latency

File ~/.local/lib/python3.11/site-packages/torch/storage.py:235, in StorageBase.pin_memory(self, device)
    231 if self.device.type != 'cpu':
    232     raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned")
    234 pinned_tensor = torch.tensor([], dtype=torch.uint8, device=self.device).set_(
--> 235     cast(Storage, self)).pin_memory(device)
    236 return pinned_tensor.untyped_storage()

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

MixtralConfig {
  "_name_or_path": "lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo",
  "architectures": [
    "MixtralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mixtral",
  "num_attention_heads": 32,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "num_local_experts": 8,
  "output_router_logits": false,
  "rms_norm_eps": 1e-05,
  "rope_theta": 1000000.0,
  "router_aux_loss_coef": 0.02,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.36.1",
  "use_cache": true,
  "vocab_size": 32000
}

Can you provide me with some advice on this matter?

Thanks a bunch for any help you can offer! Looking forward to hearing back from you soon.

Doesn't work

The notebook code does not even run, even after entering the Hugging Face token.

Run on second GPU (torch.device("cuda:1"))

Hi, awesome work! I ran your code on an RTX 3090 with offload_per_layer = 0, and it works great!

I noticed that when I switch the device to my second GPU (device = torch.device("cuda:1")), the model loads into GPU memory correctly, but inference fails:

Traceback (most recent call last):
  File "/home/philippe/tmp/mixtral2/main.py", line 112, in <module>
    result = model.generate(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/generation/utils.py", line 1764, in generate
    return self.sample(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/generation/utils.py", line 2861, in sample
    outputs = self(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 1213, in forward
    outputs = self.model(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 1081, in forward
    layer_outputs = decoder_layer(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 797, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 305, in forward
    query_states = self.q_proj(hidden_states)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/mixtral-offloading/src/custom_layers.py", line 50, in forward
    return self.forward_triton(x)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/mixtral-offloading/src/custom_layers.py", line 80, in forward_triton
    output = fn(
  File "/home/philippe/tmp/mixtral2/mixtral-offloading/src/triton_kernels.py", line 172, in triton_matmul4_transpose
    matmul4_kernel_transpose[grid](
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 156, in run
    ret = self.fn.run(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/triton/runtime/jit.py", line 550, in run
    bin.c_wrapper(
ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

I can't figure out what's wrong, any idea?
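A likely cause, offered as an assumption rather than a confirmed diagnosis: Triton launches its kernels on the current CUDA device, which stays `cuda:0` unless changed, while the weights and activations live on `cuda:1`. Two common workarounds are calling `torch.cuda.set_device(1)` before building the model, or hiding the other GPUs with the `CUDA_VISIBLE_DEVICES` environment variable so the second card is addressed as `cuda:0`. The sketch below shows the second option; `env_for_gpu` and `run_on_gpu` are hypothetical helper names, and `main.py` is the script from the traceback.

```python
import os
import subprocess
import sys


def env_for_gpu(gpu_index, base_env=None):
    """Return a copy of `base_env` (default: os.environ) exposing only one
    physical GPU. In a process started with this environment, that GPU is the
    only visible device and is addressed as cuda:0."""
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
    return env


def run_on_gpu(script, gpu_index):
    """Run `script` with the current interpreter, restricted to one GPU.
    The script itself should keep device = torch.device("cuda:0")."""
    return subprocess.run(
        [sys.executable, script], env=env_for_gpu(gpu_index), check=True
    )
```

For example, `run_on_gpu("main.py", 1)` would run the script on the second physical GPU. `CUDA_VISIBLE_DEVICES` only takes effect if it is set before CUDA is initialized, which is why the sketch starts a fresh process instead of changing `os.environ` inside a running session.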

Session crashed on colab

Hi,

Have you managed to make it work on a T4 in Colab?

P.S. It crashes repeatedly, even with offload_per_layer = 5 as mentioned in the comment.
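For sizing, it may help to count where the experts end up. The sketch below assumes, without confirmation from this thread, that `offload_per_layer` is the number of experts per decoder layer kept in CPU RAM while the rest stay on the GPU (consistent with `offload_per_layer = 0` meaning everything on the GPU, as in the RTX 3090 report above). It uses Mixtral-8x7B's 32 layers and 8 experts per layer. Under that assumption, raising `offload_per_layer` moves memory from the GPU to the host, so a Colab crash at higher values may come from exhausting system RAM rather than VRAM.

```python
NUM_LAYERS = 32  # Mixtral-8x7B num_hidden_layers
NUM_EXPERTS = 8  # Mixtral-8x7B num_local_experts (experts per layer)


def expert_placement(offload_per_layer):
    """Return (experts_on_gpu, experts_in_cpu_ram) for the whole model,
    assuming `offload_per_layer` experts of every layer are kept in CPU RAM."""
    if not 0 <= offload_per_layer <= NUM_EXPERTS:
        raise ValueError("offload_per_layer must be between 0 and 8")
    on_gpu = NUM_LAYERS * (NUM_EXPERTS - offload_per_layer)
    in_cpu_ram = NUM_LAYERS * offload_per_layer
    return on_gpu, in_cpu_ram


for k in (0, 4, 5):
    print(k, expert_placement(k))
# 0 (256, 0)
# 4 (128, 128)
# 5 (96, 160)
```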

hqq_aten package not installed.

When I execute the notebook demo.ipynb, I get the message "hqq_aten package not installed. HQQBackend.ATEN backend will not work unless you install the hqq_aten lib in hqq/kernels."

What does it mean? What should be installed?
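The message reads as a warning rather than an error: it says only the `HQQBackend.ATEN` backend needs the compiled `hqq_aten` extension from `hqq/kernels`. The tracebacks elsewhere in this thread show the repo's quantized layers going through their own `forward_triton` path, which suggests (without confirming) that the demo may not depend on that backend. A quick check of whether the extension is importable, assuming its module name is `hqq_aten` as the message implies:

```python
import importlib.util


def module_available(name):
    """Return True if top-level module `name` can be imported in the current
    environment, without actually importing it."""
    return importlib.util.find_spec(name) is not None


# "hqq_aten" is assumed to be the module name mentioned in the warning.
print("hqq_aten available:", module_available("hqq_aten"))
```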

Triton Issues in Running the Code Locally

Hey, thanks for sharing the code for the repository. I ran into the following issue while trying to run the code on a remote server:
[Screenshot from 2024-04-10]

Can you suggest a possible resolution for the issue?

Can it run with LlamaIndex?

Is it possible to use mixtral-offloading with LlamaIndex to build a RAG pipeline?
If so, do you have an example?

RuntimeError when nbits = 4 and group_size = 64

Changes to the demo:

ffn_config = BaseQuantizeConfig(
    nbits=4,
    group_size=64,
    quant_zero=True,
    quant_scale=True,
)
The errors are below:
Traceback (most recent call last):
  File "/workspace/accelerate_files/demo.py", line 121, in <module>
    result = model.generate(
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/transformers/generation/utils.py", line 1718, in generate
    return self.greedy_search(
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/transformers/generation/utils.py", line 2579, in greedy_search
    outputs = self(
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 1228, in forward
    outputs = self.model(
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 1093, in forward
    layer_outputs = decoder_layer(
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 818, in forward
    hidden_states, router_logits = self.block_sparse_moe(hidden_states,prefetch_uids,next_layer_moe)
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/accelerate_files/mixtraloffloading/src/custom_layers.py", line 320, in forward
    current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/accelerate_files/mixtraloffloading/src/expert_wrapper.py", line 33, in forward
    return self.expert_module(*args, **kwargs)
  File "/workspace/accelerate_files/mixtraloffloading/src/expert_wrapper.py", line 18, in <lambda>
    self.expert_module = lambda *args, **kwargs: expert_module(*args, **kwargs)
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/accelerate_files/mixtraloffloading/src/custom_layers.py", line 255, in forward
    current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/accelerate_files/mixtraloffloading/src/custom_layers.py", line 50, in forward
    return self.forward_triton(x)
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/accelerate_files/mixtraloffloading/src/custom_layers.py", line 65, in forward_triton
    meta['scale'] = Quantizer.dequantize(meta['scale_q'], meta['meta_scale']); del_keys.append('scale')
  File "/root/anaconda3/envs/dsmii/lib/python3.10/site-packages/hqq/core/quantize.py", line 86, in dequantize
    W_r = ((W_q_p - meta['zero'])*meta['scale']).reshape(meta['shape'])
RuntimeError: shape '[1, 917504]' is invalid for input of size 3670016
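The sizes in the error can be decoded with a little arithmetic. Each Mixtral-8x7B expert projection such as `w1` has 4096 × 14336 weights (hidden_size × intermediate_size). With `group_size=64` that gives 917504 groups, matching the expected shape `[1, 917504]`; the tensor actually found has 3670016 elements, which is exactly the group count for `group_size=16`. One possible reading, not verified here, is that the stored quantized data (for example a pre-quantized checkpoint loaded from `state_path`) was produced with `group_size=16`, so changing `BaseQuantizeConfig` alone does not re-quantize it.

```python
HIDDEN_SIZE = 4096         # Mixtral-8x7B hidden_size
INTERMEDIATE_SIZE = 14336  # Mixtral-8x7B intermediate_size (per expert)


def num_groups(rows, cols, group_size):
    """Number of quantization groups (one scale each) in a rows x cols matrix."""
    total = rows * cols
    if total % group_size:
        raise ValueError("weight count is not divisible by group_size")
    return total // group_size


expected = num_groups(INTERMEDIATE_SIZE, HIDDEN_SIZE, 64)
found = 3670016  # element count reported by the RuntimeError
print(expected, found // expected)                              # 917504 4
print(num_groups(INTERMEDIATE_SIZE, HIDDEN_SIZE, 16) == found)  # True
```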

(Colab) Clear GPU memory usage after running the generation code without restarting the instance

I found that to try changes in the source code, I have to rerun the "Initialize model" cell for them to take effect. However, after running "Initialize model" and then "Run the model", even interrupting "Run the model" does not release the GPU memory it occupies, so I cannot rebuild the model, which needs another full allocation. Restarting the instance costs considerable time to set up the environment again. Is there a way to resolve this without restarting? Thanks!
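A common pattern for freeing GPU memory in a live PyTorch session, offered as a sketch rather than a guaranteed fix: drop every Python reference to the model (and to any outputs holding CUDA tensors), run the garbage collector, then call `torch.cuda.empty_cache()` so PyTorch's caching allocator returns unused blocks to the driver. After an interrupted cell, the traceback kept in `sys.last_traceback` can still reference frames that hold tensors, so the helper clears it too. `release_objects` is a hypothetical helper name; if anything else still references the model, its memory will not be freed.

```python
import gc
import sys


def release_objects(namespace, names=("model",)):
    """Delete `names` from `namespace` (e.g. globals() in a notebook), clear the
    last interactive traceback, collect garbage, and empty PyTorch's CUDA cache.
    Returns the names that were actually removed."""
    removed = [name for name in names if name in namespace]
    for name in removed:
        del namespace[name]
    # An interrupted cell leaves its exception here; its frames can keep tensors
    # alive. sys.last_exc is the Python 3.12+ name; setting it earlier is harmless.
    sys.last_type = sys.last_value = sys.last_traceback = sys.last_exc = None
    gc.collect()  # break reference cycles that may still hold tensors
    try:
        import torch
    except ImportError:  # lets the helper run where torch is absent
        return removed
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return removed
```

In a notebook this would be called as `release_objects(globals(), names=("model",))`, adding the names of any variables that hold generation outputs, before rerunning "Initialize model".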

Update requirements.txt

Update requirements.txt so the code runs on V100 GPUs in Colab. OpenAI has released Triton 2.2.0, which is not compatible with V100 GPUs. I ran into this in my notebook, and after investigating I had to add an upper bound on the torch version (torch 2.2 pulls in Triton 2.2.0). It should be:

torch>=2.1.0,<2.2.0

You can find the issue here.
The error was this:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-12-e4c6296ba548> in <cell line: 10>()
     12   start_time = time.time()
     13   with torch.autocast(model.device.type, dtype=torch.float16, enabled=True):
---> 14     output = model.generate(**model_inputs, max_length=500)[0]
     15   duration += float(time.time() - start_time)
     16   total_length += len(output)

25 frames
/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py in ttgir_to_llir(mod, extern_libs, target, tma_infos)
    165     # TODO: separate tritongpu_to_llvmir for different backends
    166     if _is_cuda(target):
--> 167         return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM)
    168     else:
    169         return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL)

IndexError: map::at
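A small guard based on the pin above could catch the problem before generation starts. This is a sketch under assumptions: it treats compute capability (7, 0) as the V100 case reported in this issue and torch 2.2 or newer as the incompatible range; `torch_major_minor` and `needs_torch_pin` are hypothetical names.

```python
def torch_major_minor(version):
    """Parse strings like '2.1.2+cu121' or '2.2.0' into (major, minor)."""
    core = version.split("+")[0]  # drop a local build tag such as '+cu121'
    major, minor = core.split(".")[:2]
    return int(major), int(minor)


def needs_torch_pin(capability, torch_version):
    """True for a V100-class GPU (compute capability 7.0) paired with torch >= 2.2,
    the combination reported here to fail with Triton 2.2.0."""
    return tuple(capability) == (7, 0) and torch_major_minor(torch_version) >= (2, 2)


# In the notebook (requires torch and a GPU):
# import torch
# if needs_torch_pin(torch.cuda.get_device_capability(), torch.__version__):
#     raise RuntimeError('Install "torch>=2.1.0,<2.2.0" for V100 GPUs.')
```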

Run without quantization

quant_config is a mandatory argument of the build_model function:

model = build_model(
    device=device,
    quant_config=quant_config,
    offload_config=offload_config,
    state_path=state_path,
)

Can I run Mixtral with layer offloading, but WITHOUT quantization, using this library?
