
marlin's People

Contributors

efrantar, eltociear, fxmarty, godofnothing, mard1no



marlin's Issues

groupsize=64 is not supported

Hello, Marlin is great work! However, in my use case I found that it still has some limitations. Specifically, when the GPTQ group size is set to 64 the model performs very well, whereas with 128 the quality decreases. However, Marlin currently does not support a group size of 64. How could I modify the source code to make Marlin support this setting?
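For context, a minimal sketch (not Marlin's internal layout, just the bookkeeping) of what the group size changes on the Python side: with per-group scales, a (k, n) weight matrix quantized with group size g carries k // g rows of scales, so groupsize=64 simply doubles the number of scale rows relative to 128.

import torch

k, n = 4096, 4096
for groupsize in (64, 128):
    # one fp16 scale per group of `groupsize` consecutive input features
    scales = torch.ones(k // groupsize, n, dtype=torch.half)
    print(groupsize, tuple(scales.shape))  # 64 -> (64, 4096), 128 -> (32, 4096)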

Can't build marlin

Describe the bug

Hey, thanks for your work! I've run into a problem; can you help me?

I downloaded the source code, ran
pip install .
and got the following error:

/usr/local/cuda-12.3/include/cusparse.h:254:19: note: declared here
254 | struct pruneInfo* pruneInfo_t CUSPARSE_DEPRECATED_TYPE;
| ^~~~~~~~~~~
g++ -pthread -B /mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/compiler_compat -shared -Wl,-rpath,/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib -Wl,-rpath-link,/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib -L/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib -Wl,-rpath,/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib -Wl,-rpath-link,/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib -L/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib /mnt/geogpt-gpfs/zhijiang/home/djh/quantize/marlin/build/temp.linux-x86_64-cpython-310/marlin/marlin_cuda.o /mnt/geogpt-gpfs/zhijiang/home/djh/quantize/marlin/build/temp.linux-x86_64-cpython-310/marlin/marlin_cuda_kernel.o -L/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib/python3.10/site-packages/torch/lib -L/usr/local/cuda-12.3/lib64 -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-cpython-310/marlin_cuda.cpython-310-x86_64-linux-gnu.so
/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib/python3.10/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!

          ********************************************************************************
          Please avoid running ``setup.py`` directly.
          Instead, use pypa/build, pypa/installer or other
          standards-based tools.
  
          See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
          ********************************************************************************
  
  !!
    self.initialize_options()
  installing to build/bdist.linux-x86_64/wheel
  running install
  running install_lib
  creating build/bdist.linux-x86_64
  creating build/bdist.linux-x86_64/wheel
  creating build/bdist.linux-x86_64/wheel/marlin
  copying build/lib.linux-x86_64-cpython-310/marlin/__init__.py -> build/bdist.linux-x86_64/wheel/marlin
  copying build/lib.linux-x86_64-cpython-310/marlin_cuda.cpython-310-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/wheel
  running install_egg_info
  running egg_info
  creating marlin.egg-info
  writing marlin.egg-info/PKG-INFO
  writing dependency_links to marlin.egg-info/dependency_links.txt
  writing requirements to marlin.egg-info/requires.txt
  writing top-level names to marlin.egg-info/top_level.txt
  writing manifest file 'marlin.egg-info/SOURCES.txt'
  reading manifest file 'marlin.egg-info/SOURCES.txt'
  adding license file 'LICENSE'
  writing manifest file 'marlin.egg-info/SOURCES.txt'
  Copying marlin.egg-info to build/bdist.linux-x86_64/wheel/marlin-0.1.1-py3.10.egg-info
  error: [Errno 5] Input/output error: 'build/bdist.linux-x86_64/wheel/marlin-0.1.1-py3.10.egg-info/SOURCES.txt'
  [end of output]

note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed building wheel for marlin
Running setup.py clean for marlin
Failed to build marlin
ERROR: Could not build wheels for marlin, which is required to install pyproject.toml-based projects

Hardware details

A100-40GB*8

Software version

Ubuntu 20.04, CUDA=12.2, torch=2.1.2, python=3.10, auto-gptq=0.8.0.dev0+cu121, transformers=4.40.0, accelerate=0.29.1
triton=2.1.0
nvcc --version:
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0

Where in the code are "immediate eviction" and "fetched from L2 cache" used?

Hi! I find your repo very interesting and gave it a star without hesitation! I have also been learning about the L2 cache recently, so I wonder where "immediate eviction" and "fetched from L2 cache" are used? I guess it is related to discard_memory or the L2 persistence API?

Thank you!!

By the way, you mentioned that you used ncu to profile and analyze the kernel; I am also interested in how that was done. Maybe you could publish a top-conference paper!

RuntimeError: CUDA error: an illegal instruction was encountered when running test.py

Hello,
When running python test.py I get the following errors:

=====================================
ERROR: test_groups (__main__.Test)

Traceback (most recent call last):
File "/fsx/mohamed/dev/marlin/test.py", line 155, in test_groups
self.run_problem(m, n, k, *thread_shape, groupsize)
File "/fsx/mohamed/dev/marlin/test.py", line 66, in run_problem
torch.cuda.synchronize()
File "/admin/home/mohamed_mekkouri/miniconda3/envs/exp/lib/python3.10/site-packages/torch/cuda/init.py", line 792, in synchronize
return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal instruction was encountered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

=======================================
ERROR: test_k_stages_divisibility (__main__.Test)

Traceback (most recent call last):
File "/fsx/mohamed/dev/marlin/test.py", line 80, in test_k_stages_divisibility
self.run_problem(16, 2 * 256, k, 64, 256)
File "/fsx/mohamed/dev/marlin/test.py", line 60, in run_problem
A = torch.randn((m, k), dtype=torch.half, device=DEV)
RuntimeError: CUDA error: an illegal instruction was encountered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

========================================
ERROR: test_tiles (__main__.Test)

Traceback (most recent call last):
File "/fsx/mohamed/dev/marlin/test.py", line 75, in test_tiles
self.run_problem(m, 2 * 256, 1024, thread_k, thread_n)
File "/fsx/mohamed/dev/marlin/test.py", line 60, in run_problem
A = torch.randn((m, k), dtype=torch.half, device=DEV)
RuntimeError: CUDA error: an illegal instruction was encountered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

===========================================
ERROR: test_very_few_stages (__main__.Test)

Traceback (most recent call last):
File "/fsx/mohamed/dev/marlin/test.py", line 85, in test_very_few_stages
self.run_problem(16, 2 * 256, k, 64, 256)
File "/fsx/mohamed/dev/marlin/test.py", line 60, in run_problem
A = torch.randn((m, k), dtype=torch.half, device=DEV)
RuntimeError: CUDA error: an illegal instruction was encountered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.


Ran 6 tests in 0.794s

FAILED (errors=4)

The stack I am using:
python 3.10.14
torch 2.3.1
cuda_12.1.r12.1
compute_cap 9.0

[QST] Weight Format & GEMM

@efrantar

Awesome work -- always enjoy your research on and implementation of efficient model inference.

I was hoping that you could shed some light on the logic of the packing step?

  • My understanding is that the individual int4 values need to be rearranged in order to use the fast unpack / convert functions from FasterTransformer.

  • Is the subsequent interleaving done so that ldmatrix can be used on these packed values, with each thread ending up holding the values it needs for mma.sync? Typically ldmatrix is used on fp16 / bf16 types, but in this case the weights are sub-byte types, hence the additional preprocessing required for an efficient shared -> register copy. I know FasterTransformer has its own formatting logic as a workaround for this issue; I have yet to find a general solution for efficiently leveraging tensor-core primitives on sub-byte types without preprocessing the weights into a custom format.

  • Theoretically, if I were to preprocess the weights of a non-GPTQ int4 model using the packing function -- i.e., any groupwise quantization method that yields 4-bit weights along with group scales and zeros -- would I be able to use the Marlin kernel on such a model? If not, what changes would need to be made?

Many thanks!
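For reference, a minimal sketch of the basic packing step mentioned in the bullets above, eight 4-bit values per int32 word; this deliberately omits Marlin's additional _perm interleaving and only illustrates the starting point.

import torch

def pack_int4(q):
    # q: integer tensor with values in [0, 15]; last dimension divisible by 8
    q = q.reshape(*q.shape[:-1], -1, 8).to(torch.int32)
    packed = torch.zeros(q.shape[:-1], dtype=torch.int32)
    for i in range(8):
        packed |= q[..., i] << (4 * i)  # nibble i occupies bits [4*i, 4*i + 4)
    return packed

q = torch.randint(0, 16, (4, 16))
print(pack_int4(q).shape)  # torch.Size([4, 2]): eight 4-bit values per int32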

Issues generating tokens after "get_llama_marlin"

Hi,
Thanks for your work!
I wonder how we can generate tokens using model.generate() or model(inputs)?
The following code produces a bug:

import transformers

transformers.AutoModelForCausalLM.from_pretrained = staticmethod(get_llama_marlin)
model = transformers.AutoModelForCausalLM.from_pretrained(path, torch_dtype="auto", device_map="auto")
text = "Hello, my dog is cute and"
print(model)
tokenizer = transformers.AutoTokenizer.from_pretrained(path, use_fast=False)
inputs = tokenizer(text, return_tensors="pt").input_ids
inputs = inputs.to(model.device)
generation_output = model(inputs)  # or model.generate(inputs)
print(generation_output)
The model is printed correctly (screenshot omitted).

The error is shown in a screenshot (not reproduced here).

When we comment out transformers.AutoModelForCausalLM.from_pretrained = staticmethod(get_llama_marlin), the rest of the code runs properly.

Open: optimize for GEMM regime

Hi @efrantar, thanks a lot for sharing this optimized kernel!

I've given them a try on an A100, and the results are really nice in the regime of short sequences (<30-50). So far, for autoregressive decoding, Marlin would be the best-in-class kernel available, as it maintains good latency even for batch sizes larger than one.

(benchmark plot omitted)

However, for longer sequence lengths (or larger batch sizes, it does not matter which), both fp16xfp16 and exllamav2 outperform the Marlin kernel. AFAIK, exllamav2 just unpacks the weights to fp16 in this case and calls cuBLAS; I'm not sure why it is faster than PyTorch's native fp16xfp16.

(benchmark plot omitted)

It seems that optimizing int4xfp16 kernels in the GEMM regime is still an open problem. So far we have different kernels performing best for different shapes, but they all require different packing, etc., which is not very handy.

One could still argue that it is worth trading off more latency in the prefill for lower latency in decoding, where we usually spend more time.
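As a rough illustration of the trade-off described above, one pragmatic workaround is to dispatch on the number of tokens: use the fused int4xfp16 kernel in the memory-bound regime and fall back to dequantize-then-cuBLAS in the compute-bound one. This is only a sketch; dequantize_to_fp16 is a hypothetical helper, the threshold is arbitrary, and the marlin.mul call signature is the one used in the benchmark code further down this page.

import torch
import marlin

GEMM_THRESHOLD = 64  # tokens; arbitrary cut-over point, would need tuning per GPU and shape

def mixed_forward(x, layer, dequantize_to_fp16):
    # x: (tokens, in_features) fp16 activations
    if x.shape[0] < GEMM_THRESHOLD:
        # memory-bound (decoding) regime: fused int4xfp16 kernel
        y = torch.empty(x.shape[0], layer.s.shape[1], dtype=x.dtype, device=x.device)
        marlin.mul(x, layer.B, y, layer.s, layer.workspace_fp)
        return y
    # compute-bound (GEMM) regime: unpack weights to fp16 and let cuBLAS do the work
    w = dequantize_to_fp16(layer)  # hypothetical: returns an (in_features, out_features) fp16 matrix
    return x @ w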

Marlin slower than fp16 on larger batches

I have been running some benchmarks with Marlin, but the speed-up is far from what is reported. In fact, it is actually slower than fp16:
GPU: A6000 Ada

matrix_shape:  [11008, 4096]

input_shape: [1, 1024, 11008]
time (fp16): 0.0007191438674926758
time (marlin): 0.0006200861930847168 (1.16x)

input_shape: [16, 1024, 11008]
time (fp16): 0.010448209762573242
time (marlin): 0.01280400848388672 (0.82x)

Code below:

import torch
import marlin

def forward_marlin(marlin_layer, x):
    # output width comes from the second dimension of the scale matrix
    y = torch.empty(x.shape[:-1] + (marlin_layer.s.shape[1],), dtype=x.dtype, device=x.device)
    # marlin.mul works on 2D matrices, so flatten all leading dimensions
    marlin.mul(x.view((-1, x.shape[-1])), marlin_layer.B, y.view((-1, y.shape[-1])), marlin_layer.s, marlin_layer.workspace_fp)
    return y

print(time_it(lambda: torch.matmul(x, ref)))
print(time_it(lambda: forward_marlin(marlin_layer, x)))

What could be the issue? Thanks in advance!
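One thing worth double-checking in benchmarks like this is the timing methodology: CUDA kernels launch asynchronously, so wall-clock timing without synchronization can largely measure launch overhead. A sketch of event-based timing, assuming the same x, ref, marlin_layer and forward_marlin as above:

import torch

def time_cuda(fn, warmup=10, iters=100):
    for _ in range(warmup):  # warm up to exclude one-time setup and cache effects
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

print(time_cuda(lambda: torch.matmul(x, ref)))
print(time_cuda(lambda: forward_marlin(marlin_layer, x)))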

Turing support

Why are Ampere or Ada GPUs (RTX 3000 and RTX 4000 series) required for this?

Turing (RTX 2000 series) has INT4 tensor cores.

Packing order (`_perm` and `_scale_perm`)

Hi,

I am trying to understand the motivation for shuffling the scale/weight order:

s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
&
res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape)

We have _scale_perm being [0, 8, 16, 24, 32, 40, 48, 56, 1, 9, 17, 25, 33, 41, 49, 57, 2, 10, 18, 26, 34, 42, 50, 58, ...] up to 64 entries, and _perm = [0, 128, 8, 136, 16, 144, 24, 152, ..., 256, 384, 264, 392, 272, 400, 280, 408, ...] up to 1024 entries.

At first I thought it might be related to the way data is owned by threads in mma.sync.aligned.m16n8k16, but I don't think that is it.


Thank you!

Support for Hopper H100

Hi! You've probably already considered this, but would you be able to add support for Hopper H100 GPUs? A100s don't have nearly as much memory bandwidth. I am happy to run tests/benchmarks on one if that would help. Thanks!

Does Marlin support zero-point quantization?

Dear creators of Marlin

What a huge performance boost these kernels can bring! I'm super excited about this, as the open-source community has been lacking kernels that scale.

To my question: does Marlin support zero-point quantization like we normally get from AutoGPTQ or AutoAWQ?

Best wishes
Casper
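For context on the terminology (independent of Marlin's internals): a symmetric scheme stores only per-group scales and dequantizes around a fixed midpoint, while the zero-point schemes used by AutoGPTQ/AutoAWQ additionally store a per-group integer zero point. A minimal sketch of the two dequantization formulas:

import torch

q = torch.randint(0, 16, (4, 8))   # 4-bit integer weights, stored unsigned
s = torch.rand(4, 1)               # per-group floating-point scales

w_sym = s * (q - 8)                # symmetric: zero point fixed at the midpoint (8 for 4-bit)
z = torch.randint(0, 16, (4, 1))   # per-group integer zero points
w_asym = s * (q - z)               # zero-point (asymmetric) quantization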

a_sh_rd_delta_o

constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));

  1. Does the 32 here refer to a warp?
  2. What does 4 here mean?
  3. What does 2 here mean?
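A quick numeric check of the expression above, using hypothetical values for the compile-time parameters (the numbers are illustrative, not taken from a specific kernel instantiation); note the C++ expression uses integer division:

threads = 256          # hypothetical threadblock size -> threads // 32 = 8 warps
thread_n_blocks = 16   # hypothetical: thread_n = 256 split into 16-column tiles
a_sh_rd_delta_o = 2 * ((threads // 32) // (thread_n_blocks // 4))
print(a_sh_rd_delta_o)  # 2 * (8 // 4) = 4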

questions about slice_col_par

int slice_col_par = (iters * blockIdx.x) / k_tiles;
int slice_col = slice_col_par;
int slice_iters;     // number of threadblock tiles in the current slice
int slice_count = 0; // total number of active threadblocks in the current slice
int slice_idx;       // index of threadblock in current slice; numbered bottom to top

if (slice_col_par >= n_tiles) {
  ...
I have some questions about the code above. For example, if there are 108 SMs on the GPU and the computed iters is 19, with blockIdx.x ranging from 0 to 127, is slice_col_par calculated directly from iters = 19? For instance, when blockIdx.x = 5 (or other values), that thread block might not actually iterate 19 times.
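To make the question concrete, here is the same arithmetic evaluated for a few block indices; iters = 19 is taken from the question, while k_tiles and n_tiles are assumed values used only for illustration:

iters, k_tiles, n_tiles = 19, 32, 8  # iters from the question; k_tiles and n_tiles assumed
for block_idx in (0, 5, 127):        # a few blockIdx.x values from the range mentioned above
    slice_col_par = (iters * block_idx) // k_tiles  # integer division, as in the C++ code
    print(block_idx, slice_col_par, slice_col_par >= n_tiles)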

Questions about matrix A's layout in shared memory.

Thanks for your wonderful work!
I am trying to understand matrix A's layout in shared memory. I think A's shape in shared memory is (16 * thread_m_blocks) * (16 * thread_k_blocks) for every thread block, but the following code (line 287 in marlin_cuda_kernel.cu) confuses me. Why is a_sh_rd_delta_o calculated like this? I'm looking forward to your reply.
(code screenshot omitted)

Server, TGI and/or vLLM Support

Marlin looks great, but are there:

  1. quantization scripts - a quantize.py file that we can easily use to make models and push them to HF
  2. vLLM or TGI support. I see that https://github.com/neuralmagic/nm-vllm has some Marlin support, so that's a start. Also, TGI seems to be looking at Marlin support.

Even if there were just a quantize.py script, that would help a lot and make nm-vllm a possibility.

Trying to understand the kernel

Hello everyone!

I am trying to understand how the Marlin kernel works in depth so that I can adapt it to int2 quantization. Do you have any pointers? I appreciate your help!

Performance

Could one big kernel be used rather than many small kernels? Maybe a single kernel would be faster?

Slight nondeterminism

Marlin internally uses locks to synchronize the threads. This can result in very slight nondeterminism for Marlin.

Why?
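As general background (not the author's answer): floating-point addition is not associative, so whenever partial results from different threadblocks are reduced in an order that can vary between runs, the last bits of the result can differ. A tiny illustration in plain Python:

print((0.1 + 0.2) + 0.3)  # 0.6000000000000001
print(0.1 + (0.2 + 0.3))  # 0.6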
