ist-daslab / marlin
FP16xINT4 LLM inference kernel that can achieve near-ideal ~4x speedups up to medium batch sizes of 16-32 tokens.
License: Apache License 2.0
Hello, Marlin is great work! However, in my use case I found that it still has a limitation. When the GPTQ group size is set to 64, the model performs very well; when it is set to 128, quality drops. Marlin currently does not support a group size of 64. How can I modify the source code to make Marlin support this setting?
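For readers unfamiliar with what the group size controls, here is a minimal sketch of symmetric 4-bit round-to-nearest quantization with one scale per group. It is not a change to the Marlin kernel and does not use GPTQ's error-correcting updates; the function names, the fixed zero point of 8, and the random test column are all assumptions made for illustration.

```python
import random


def quantize_groupwise(column, groupsize, bits=4):
    """Quantize one weight column, using one max-abs scale per group of rows."""
    maxq = 2 ** bits - 1      # 15 for 4-bit codes
    zero = (maxq + 1) // 2    # assumed fixed zero point (symmetric scheme)
    codes, scales = [], []
    for start in range(0, len(column), groupsize):
        group = column[start:start + groupsize]
        scale = 2 * max(abs(w) for w in group) / maxq
        if scale == 0:
            scale = 1.0       # all-zero group: any scale works
        scales.append(scale)
        for w in group:
            q = round(w / scale) + zero
            codes.append(min(max(q, 0), maxq))
    return codes, scales


def dequantize_groupwise(codes, scales, groupsize, bits=4):
    """Map 4-bit codes back to floats using the scale of their group."""
    zero = (2 ** bits) // 2
    return [(q - zero) * scales[i // groupsize] for i, q in enumerate(codes)]


rng = random.Random(0)
column = [rng.gauss(0.0, 1.0) for _ in range(256)]
for groupsize in (64, 128):
    codes, scales = quantize_groupwise(column, groupsize)
    restored = dequantize_groupwise(codes, scales, groupsize)
    worst = max(abs(a - b) for a, b in zip(column, restored))
    print(groupsize, len(scales), worst)
```

Halving the group size doubles the number of scales that have to be stored and loaded, and each scale only has to cover half as many weights, which is usually where the accuracy difference comes from.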
Hey, thanks for your work. I have run into a problem; can you help me?
I downloaded the source code and ran
pip install .
and I get this error:
/usr/local/cuda-12.3/include/cusparse.h:254:19: note: declared here
254 | struct pruneInfo* pruneInfo_t CUSPARSE_DEPRECATED_TYPE;
| ^~~~~~~~~~~
g++ -pthread -B /mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/compiler_compat -shared -Wl,-rpath,/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib -Wl,-rpath-link,/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib -L/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib -Wl,-rpath,/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib -Wl,-rpath-link,/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib -L/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib /mnt/geogpt-gpfs/zhijiang/home/djh/quantize/marlin/build/temp.linux-x86_64-cpython-310/marlin/marlin_cuda.o /mnt/geogpt-gpfs/zhijiang/home/djh/quantize/marlin/build/temp.linux-x86_64-cpython-310/marlin/marlin_cuda_kernel.o -L/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib/python3.10/site-packages/torch/lib -L/usr/local/cuda-12.3/lib64 -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-cpython-310/marlin_cuda.cpython-310-x86_64-linux-gnu.so
/mnt/geogpt-gpfs/zhijiang/home/djh/miniconda3/envs/quantize/lib/python3.10/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!
********************************************************************************
Please avoid running ``setup.py`` directly.
Instead, use pypa/build, pypa/installer or other
standards-based tools.
See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
********************************************************************************
!!
self.initialize_options()
installing to build/bdist.linux-x86_64/wheel
running install
running install_lib
creating build/bdist.linux-x86_64
creating build/bdist.linux-x86_64/wheel
creating build/bdist.linux-x86_64/wheel/marlin
copying build/lib.linux-x86_64-cpython-310/marlin/__init__.py -> build/bdist.linux-x86_64/wheel/marlin
copying build/lib.linux-x86_64-cpython-310/marlin_cuda.cpython-310-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/wheel
running install_egg_info
running egg_info
creating marlin.egg-info
writing marlin.egg-info/PKG-INFO
writing dependency_links to marlin.egg-info/dependency_links.txt
writing requirements to marlin.egg-info/requires.txt
writing top-level names to marlin.egg-info/top_level.txt
writing manifest file 'marlin.egg-info/SOURCES.txt'
reading manifest file 'marlin.egg-info/SOURCES.txt'
adding license file 'LICENSE'
writing manifest file 'marlin.egg-info/SOURCES.txt'
Copying marlin.egg-info to build/bdist.linux-x86_64/wheel/marlin-0.1.1-py3.10.egg-info
error: [Errno 5] Input/output error: 'build/bdist.linux-x86_64/wheel/marlin-0.1.1-py3.10.egg-info/SOURCES.txt'
[end of output]
note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed building wheel for marlin
Running setup.py clean for marlin
Failed to build marlin
ERROR: Could not build wheels for marlin, which is required to install pyproject.toml-based projects
A100-40GB*8
Information about CPU and GPU, such as RAM, number, etc.
ubuntu20.04 CUDA=12.2 torch=2.1.2 python=3.10 auto-gptq=0.8.0.dev0+cu121 transformers=4.40.0 accelerate=0.29.1
triton=2.1.0
nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0
Hi! I find your repo very interesting and gave it a star without hesitation! I have also been learning about the L2 cache recently, so I wonder where the code uses "immediate eviction" and where data is "fetched from L2 cache". I guess it is related to discard_memory or the L2 persistence API?
Thank you!!
By the way, you mentioned that you used ncu to profile and analyze the kernel; I am also interested in how that is done. Maybe you could publish a top conference paper!
3-bit ?
2-bit ?
Hello,
When running python test.py
I get the following errors:
Traceback (most recent call last):
File "/fsx/mohamed/dev/marlin/test.py", line 155, in test_groups
self.run_problem(m, n, k, *thread_shape, groupsize)
File "/fsx/mohamed/dev/marlin/test.py", line 66, in run_problem
torch.cuda.synchronize()
File "/admin/home/mohamed_mekkouri/miniconda3/envs/exp/lib/python3.10/site-packages/torch/cuda/__init__.py", line 792, in synchronize
return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal instruction was encountered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
Traceback (most recent call last):
File "/fsx/mohamed/dev/marlin/test.py", line 80, in test_k_stages_divisibility
self.run_problem(16, 2 * 256, k, 64, 256)
File "/fsx/mohamed/dev/marlin/test.py", line 60, in run_problem
A = torch.randn((m, k), dtype=torch.half, device=DEV)
RuntimeError: CUDA error: an illegal instruction was encountered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
Traceback (most recent call last):
File "/fsx/mohamed/dev/marlin/test.py", line 75, in test_tiles
self.run_problem(m, 2 * 256, 1024, thread_k, thread_n)
File "/fsx/mohamed/dev/marlin/test.py", line 60, in run_problem
A = torch.randn((m, k), dtype=torch.half, device=DEV)
RuntimeError: CUDA error: an illegal instruction was encountered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
Traceback (most recent call last):
File "/fsx/mohamed/dev/marlin/test.py", line 85, in test_very_few_stages
self.run_problem(16, 2 * 256, k, 64, 256)
File "/fsx/mohamed/dev/marlin/test.py", line 60, in run_problem
A = torch.randn((m, k), dtype=torch.half, device=DEV)
RuntimeError: CUDA error: an illegal instruction was encountered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
Ran 6 tests in 0.794s
FAILED (errors=4)
The stack I am using:
python 3.10.14
torch 2.3.1
cuda_12.1.r12.1
compute_cap 9.0
Awesome work -- always enjoy your research on and implementation of efficient model inference.
I was hoping that you could shed some light on the logic of the packing step?
My understanding is that the individual int4 values need to be rearranged in order to use the fast unpack / convert functions from FasterTransformer.
Is the subsequent interleaving such that ldmatrix can be used on these packed values, so that each thread holds the necessary values for mma.sync? Typically ldmatrix is used on fp16 / bf16 types, but in this case the weights are sub-byte types, hence the additional preprocessing required for an efficient shared -> register copy. I know FasterTransformer has its own formatting logic as a workaround for this issue; I have yet to find a general solution for efficiently leveraging tensor core primitives on sub-byte types without preprocessing weights into a custom format.
Theoretically, if I were to preprocess the weights of a non-GPTQ int4 model using the packing function -- i.e., any groupwise quantization method that yields 4b weights along with group scales and zeros -- would I be able to use the Marlin kernel on such a model? If not, what changes would need to be made?
Many thanks!
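To make the interleaving part of the question concrete, here is a small sketch of packing eight 4-bit values into one 32-bit word in an interleaved order. The order [0, 2, 4, 6, 1, 3, 5, 7] and the helper names are assumptions for illustration; this covers only the within-word interleave, not the tile-level permutation that the packing function also applies.

```python
INTERLEAVE = [0, 2, 4, 6, 1, 3, 5, 7]  # assumed nibble order, for illustration only


def pack8(nibbles, order=INTERLEAVE):
    """Pack eight values in [0, 15] into one 32-bit word, slot i holding nibbles[order[i]]."""
    if len(nibbles) != 8 or any(not 0 <= v <= 15 for v in nibbles):
        raise ValueError("expected eight 4-bit values")
    word = 0
    for slot, src in enumerate(order):
        word |= nibbles[src] << (4 * slot)
    return word


def unpack8(word, order=INTERLEAVE):
    """Invert pack8."""
    out = [0] * 8
    for slot, src in enumerate(order):
        out[src] = (word >> (4 * slot)) & 0xF
    return out


values = [0, 1, 2, 3, 4, 5, 6, 7]
word = pack8(values)
print(hex(word))

# With this order, one mask isolates two consecutive elements, one per 16-bit half.
pair = word & 0x000F000F
print(pair & 0xFFFF, pair >> 16)
```

With this layout a single AND with 0x000F000F leaves elements 0 and 1 in the low and high 16-bit halves, and shifting the word right by 4 first yields elements 2 and 3, which is the shape that pairwise int4 -> fp16 conversion tricks expect.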
Hi,
Thanks for your work!
How can we generate tokens using model.generate() or model(inputs)?
The following code prints the model but then fails.
import transformers

transformers.AutoModelForCausalLM.from_pretrained = staticmethod(get_llama_marlin)
model = transformers.AutoModelForCausalLM.from_pretrained(path, torch_dtype="auto", device_map="auto")
text = "Hello, my dog is cute and"
print(model)
tokenizer = transformers.AutoTokenizer.from_pretrained(path, use_fast=False)
inputs = tokenizer(text, return_tensors="pt").input_ids
inputs = inputs.to(model.device)
generation_output = model(inputs)  # or model.generate(inputs)
print(generation_output)
The model is printed:
The bug is:
When we comment out transformers.AutoModelForCausalLM.from_pretrained = staticmethod(get_llama_marlin), the rest of the code runs properly.
Hi @efrantar, thanks a lot for sharing this optimized kernel!
I tried the kernels on an A100, and the results are really nice in the regime with short sequences (<30-50 tokens). So far, Marlin looks like the best-in-class kernel for autoregressive decoding, as it maintains good latency even for batch sizes larger than one.
However, for longer sequence lengths (or large batch sizes, it does not matter which), both fp16xfp16 and exllamav2 outperform the Marlin kernel. AFAIK, exllamav2 just unpacks the weights to fp16 in this case and calls cuBLAS; I'm not sure why it is faster than PyTorch's native fp16xfp16.
It seems that optimizing int4xfp16 kernels in the GEMM regime is still an open problem. So far, different kernels perform best for different shapes, but they all require different packing formats, which is not very handy.
One could still argue that it is worth trading higher latency in the prefill for lower latency in decoding, where we usually spend more time.
I have been running some benchmarks with Marlin, but the speed-up is far from what is reported. In fact, at the larger batch size it is actually slower than fp16:
GPU: A6000 ada
matrix_shape: [11008, 4096]
input_shape: [1, 1024, 11008]
time (fp16): 0.0007191438674926758
time (marlin): 0.0006200861930847168 (1.16x)
input_shape: [16, 1024, 11008]
time (fp16): 0.010448209762573242
time (marlin): 0.01280400848388672 (0.82x)
Code below:
def forward_marlin(marlin_layer, x):
    # Allocate the output, then run the Marlin kernel on 2-D views of x and y.
    y = torch.empty(x.shape[:-1] + (marlin_layer.s.shape[1],), dtype=x.dtype, device=x.device)
    marlin.mul(x.view((-1, x.shape[-1])), marlin_layer.B, y.view((-1, y.shape[-1])), marlin_layer.s, marlin_layer.workspace_fp)
    return y

print(time_it(lambda: torch.matmul(x, ref)))
print(time_it(lambda: forward_marlin(marlin_layer, x)))
What could be the issue? Thanks in advance!
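One thing worth checking before looking for a bug: the inputs above contain 1024 and 16 x 1024 tokens, while the project description only claims near-ideal speedups up to 16-32 tokens. The rough roofline sketch below estimates the best-case fp16 vs. int4 ratio for a given token count. The peak throughput and bandwidth defaults are placeholder values, not measured specs of any GPU, and group scales are ignored.

```python
def estimated_time(m, k, n, bytes_per_weight, peak_flops, bandwidth):
    """Lower-bound time in seconds for an (m x k) @ (k x n) matmul under a roofline model."""
    flops = 2 * m * k * n
    # Weights at the given precision, plus fp16 activations read and fp16 outputs written.
    bytes_moved = k * n * bytes_per_weight + 2 * m * k + 2 * m * n
    return max(flops / peak_flops, bytes_moved / bandwidth)


def ideal_speedup(m, k=11008, n=4096, peak_flops=300e12, bandwidth=1e12):
    """Best-case fp16 / int4 time ratio; peak_flops and bandwidth are hypothetical."""
    fp16 = estimated_time(m, k, n, 2.0, peak_flops, bandwidth)
    int4 = estimated_time(m, k, n, 0.5, peak_flops, bandwidth)
    return fp16 / int4


for tokens in (1, 16, 1024, 16 * 1024):
    print(tokens, round(ideal_speedup(tokens), 2))
```

Under these assumptions the matmul is limited by weight traffic at small token counts, where 4-bit weights help, and by arithmetic at large token counts, where both variants do the same number of FLOPs and no speedup can be expected from weight compression alone.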
Why is Ampere or Ada (RTX 3000 and RTX 4000 series) required to support this?
Turing (RTX 2000 series) has INT4 tensor cores.
Hi,
I am trying to understand the motivation for shuffling the scale/weight order:
Line 125 in 3aa5a05
Line 133 in 3aa5a05
We have _scale_perm being [0, 8, 16, 24, 32, 40, 48, 56, 1, 9, 17, 25, 33, 41, 49, 57, 2, 10, 18, 26, 34, 42, 50, 58, ...] up to 64, and _perm = [0, 128, 8, 136, 16, 144, 24, 152, 256, 384, 264, 392, 272, 400, 280, 408, ...] up to 1024.
At first I thought it might be related to the way data is owned by threads in mma.sync.aligned.m16n8k16, but I don't think it is related.
Thank you!
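For anyone following along, the _scale_perm values quoted above can be reproduced with a short comprehension, which shows that the permutation is a transpose of an 8 x 8 index grid. This sketch only reproduces the listed numbers; it does not explain why the kernel wants this order, and the names used here are illustrative.

```python
# Reproduce the quoted pattern: 0, 8, ..., 56, then 1, 9, ..., 57, and so on.
scale_perm = [i + 8 * j for i in range(8) for j in range(8)]

# Apply it to one block of 64 scales (placeholder values 100..163).
scales = list(range(100, 164))
permuted = [scales[p] for p in scale_perm]

print(scale_perm[:16])
print(permuted[:8])
```

Because the mapping is a transpose, applying it twice returns the original order.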
Hi! You've probably already considered this, but would you be able to add support for Hopper H100 GPUs? A100s don't have nearly as much memory bandwidth. Am happy to run tests/benchmarks on one if that would help, thanks
Dear creators of Marlin
What a huge performance boost these kernels can bring! I’m super excited about this as the open source community has been lacking kernels that scale.
To my question, does Marlin support zero point quantization like we normally get from AutoGPTQ or AutoAWQ?
Best wishes
Casper
constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
`
int slice_col_par = (iters * blockIdx.x) / k_tiles;
int slice_col = slice_col_par; //
int slice_iters; // number of threadblock tiles in the current slice
int slice_count = 0; // total number of active threadblocks in the current slice
int slice_idx; // index of threadblock in current slice; numbered bottom to top
if (slice_col_par >= n_tiles) {
`
I have some questions about the code above. For example, suppose the GPU has 108 SMs, the computed iters is 19, and blockIdx.x ranges from 0 to 127. Is slice_col_par computed directly from iters=19 for every block? For instance, when blockIdx.x=5 (or some other value), that thread block might not iterate 19 times.
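One way to reason about the question is to replay the index arithmetic on the host. The sketch below assumes tiles are numbered column by column, so that block b starts at tile iters * b; slice_row and the min() for the first slice are my assumptions and are not shown in the snippet above. The value k_tiles=32 is a made-up example.

```python
def stripe_start(block_idx, iters, k_tiles):
    """Return (slice_col, slice_row, first_slice_iters) for one thread block."""
    start = iters * block_idx        # first tile of this block's stripe
    slice_col = start // k_tiles     # mirrors slice_col_par in the snippet
    slice_row = start % k_tiles      # assumed: row offset inside that column
    # Assumed: the first slice ends at the bottom of the column, so it can be
    # shorter than iters; the remaining tiles spill into the next column.
    first_slice_iters = min(iters, k_tiles - slice_row)
    return slice_col, slice_row, first_slice_iters


for block_idx in (0, 1, 5):
    print(block_idx, stripe_start(block_idx, iters=19, k_tiles=32))
```

Under this reading, every block is assigned the same stripe length, but the number of iterations it spends in any one column slice depends on where the stripe happens to start.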
In the Layer.pack method, should the shape of the scales be (outfeatures, groups)?
Line 105 in 512f1b1
It seems that if we remove the assert in Layer.pack, then we can pack a bf16 linear layer?
By the way, will Marlin support "int4 \times bf16" as input?
Thanks for your wonderful work!
I am trying to understand matrix A's layout in shared memory. I think A's shape is originally (16 * thread_m_blocks) * (16 * thread_k_blocks) in shared memory for every thread block, but the following code (line 287 in marlin_cuda_kernel.cu) confuses me. Why is a_sh_rd_delta_o calculated like this? I'm looking forward to your reply.
Marlin looks great, but are there:
Even if there was a quantize.py script that would help a lot and make nn-vllm a possibility.
Hello everyone !
I am trying to understand in depth how the Marlin kernel works so that I can adapt it for int2 quantization. Do you have any pointers? I appreciate your help!
Could one big kernel be used rather than many small kernels? Might a single kernel be faster?
Marlin internally uses locks to synchronize the threads. This can result in very slight nondeterminism for Marlin.
Why?
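A general fact that may be relevant here: floating-point addition is not associative, so if the order in which partial results are accumulated can differ between runs, the last bits of the output can differ too. Whether this is the exact mechanism behind the note quoted above is an assumption on my part; the sketch only demonstrates the arithmetic effect.

```python
# Three partial sums, as if produced by three thread blocks.
partials = [0.1, 0.2, 0.3]

# The same values accumulated in two different groupings.
in_order = (partials[0] + partials[1]) + partials[2]
reordered = partials[0] + (partials[1] + partials[2])

print(in_order == reordered)
print(abs(in_order - reordered))
```

The two results differ only in the last bit, which would be consistent with the "very slight" wording.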