kernl's People

Contributors

ayoub-louati, els-lab-ci, gaetansnl, jonathlela, pommedeterresautee, thytu, white-gorilla

kernl's Issues

Perform ONNX Runtime tests with CUDA graphs enabled

Right now, enabling CUDA graphs during ONNX inference raises an unexpected exception, and it should not:

ERROR test/test_torchdynamo_bert.py - onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : This session cannot use the CUDA Graph feature as requested by the user  as all the graph nodes have...

This should not happen, as we only execute on CUDA.

Complete description below:
microsoft/onnxruntime#12977

Nothing to do here except follow the issue on the ORT repo and report any fix or answer.

Track benchmarks over time

Today we have no system to track benchmarks over time (per hardware and per commit). Since the tool aims to be the reference in terms of measurements, this information is essential for us and for external users.

Several solutions exist, and it does not seem very complicated to build a custom one if that turns out to be relevant:

  • identify a performance-tracking solution
  • set it up in GitHub Actions (or propose an alternative benchmark environment)
  • make the results browsable from the repo

Add Flash Attention CUDA implementation benchmark

Benchmarking the original Flash Attention CUDA implementation would be very interesting, both for speed and to check the accuracy of our results.

Original implementation link: https://github.com/HazyResearch/flash-attention/
The Triton tutorial does it: https://triton-lang.org/master/getting-started/tutorials/06-fused-attention.html#sphx-glr-getting-started-tutorials-06-fused-attention-py

The main challenge, if any, is installing it (it requires compilation and some dependencies).
Maybe do it in a test that can be disabled.

Add training support to Triton kernels

Today we focus on the forward (FW) pass (inference).
All the Triton material we have access to provides the backward pass (training, gradient computation) out of the box.
To keep this advantage in our library, it is important to handle the BW pass alongside the FW work:

  • rewrite the backward pass in a readable and maintainable way
  • add unit tests
  • add a benchmark (one already exists in Triton)

Dependency:

More info in assert all close

To check that outputs are close to the reference implementation, we use the torch allclose function.
To help with debugging, we may want to add more info, as done in the Triton project:

  • max diff
  • max relative diff
  • number of mismatched elements

Example from Triton:

E       AssertionError: 
E       Arrays are not almost equal to 1 decimals
E       
E       Mismatched elements: 700185 / 786432 (89%)
E       Max absolute difference: 6.9219756
E       Max relative difference: 12265.675
E        x: array([[[[ 8.0e-02, -4.7e-01, -4.6e-02, ...,  8.7e-01,  7.6e-01,
E                 -2.5e-01],
E                [ 7.6e-02, -3.2e-01,  2.3e-01, ...,  2.1e-01,  2.3e-01,...
E        y: array([[[[ 8.0e-02, -4.7e-01, -4.6e-02, ...,  8.7e-01,  7.6e-01,
E                 -2.5e-01],
E                [ 7.3e-02, -2.3e-01,  4.0e-01, ..., -2.0e-01, -9.3e-02,...
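
A minimal sketch of what such a helper could look like, assuming a torch-only implementation (the function name and signature are illustrative, not the existing kernl API):

import torch

def assert_all_close_verbose(actual: torch.Tensor, expected: torch.Tensor, rtol: float = 1e-1, atol: float = 1e-1) -> None:
    # behaves like torch.allclose, but reports extra diagnostics on failure
    if torch.allclose(actual, expected, rtol=rtol, atol=atol):
        return
    diff = (actual - expected).abs()
    rel_diff = diff / expected.abs().clamp(min=1e-12)
    mismatched = ~torch.isclose(actual, expected, rtol=rtol, atol=atol)
    n_mismatch = int(mismatched.sum())
    raise AssertionError(
        f"Tensors are not close enough:\n"
        f"  mismatched elements: {n_mismatch} / {actual.numel()} ({100.0 * n_mismatch / actual.numel():.0f}%)\n"
        f"  max absolute difference: {diff.max().item():.6g}\n"
        f"  max relative difference: {rel_diff.max().item():.6g}"
    )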


change package name to kernl

The original package name was nucle but it has been changed to kernl.
We need to update names and URL addresses everywhere (readme, setup.py, messages).

Manage input mask in Flash Attention

Right now, we only support causal attention (with a causal mask) and no input mask at all.
We need to support an input mask, even in the non-causal case.

There is a perf trick regarding the causal mask: it doesn't load/compute data for the part of the sequences above the mask diagonal.
We may want to reuse this trick for the non-causal mask. The main benefit would be the ability to run inference on very large batches (something flash attention makes possible by reducing the memory footprint) without caring about padding!

The idea is simple: add mask management (easy) and, if possible, keep inference as fast as possible by not loading and computing the parts of the computation matrix hidden by the mask.

For that, we can check, in each loaded tile of the mask, whether the last column of the tile is fully hidden. If it is, we don't need to continue the computation for the following column blocks of those sequences, because they will only contain padding. A torch-level sketch of this check is shown below.
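
A torch-level illustration of the tile check (not the actual Triton kernel; tile shape and names are assumptions):

import torch

def block_can_stop(mask_tile: torch.Tensor) -> torch.Tensor:
    # mask_tile: (batch, BLOCK_M, BLOCK_N) boolean tile of the attention mask,
    # True = attend, False = hidden (padding). Padding is contiguous on the right,
    # so if the last column of the tile is fully hidden for a given sequence, every
    # later column block is hidden too and the key/value loop can stop for that sequence.
    last_column = mask_tile[:, :, -1]        # (batch, BLOCK_M)
    return ~last_column.any(dim=-1)          # True where the whole row block can stop early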

Make the CudaGraph class more generic

CUDA graph technology is essential to make the Triton code competitive.
However, it requires some subtleties to be put in place correctly.
The current implementation is tied to a particular kernel for no good reason (other than getting first results quickly).

This concrete class should be turned into an abstraction (a sketch follows the links below) able to:

  • call any kernel
  • handle an arbitrary number of static inputs and outputs
  • add a Triton + CUDA graph benchmark to matmul
  • add dedicated unit tests (cf #4 )

https://developer.nvidia.com/blog/cuda-graphs/
https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/
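
A minimal sketch of a generic wrapper, assuming the standard torch.cuda.CUDAGraph API described in the links above (class and method names are illustrative):

from typing import Callable

import torch

class CudaGraphRunner:
    # capture an arbitrary callable in a CUDA graph, then replay it with new inputs
    def __init__(self, fn: Callable[..., torch.Tensor], *example_inputs: torch.Tensor):
        self.static_inputs = [t.clone() for t in example_inputs]   # fixed addresses required by CUDA graphs
        side_stream = torch.cuda.Stream()                          # warmup outside the default stream (CUDA graph requirement)
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            for _ in range(3):
                fn(*self.static_inputs)
        torch.cuda.current_stream().wait_stream(side_stream)
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph):                         # capture
            self.static_output = fn(*self.static_inputs)

    def __call__(self, *inputs: torch.Tensor) -> torch.Tensor:
        for static, new in zip(self.static_inputs, inputs):
            static.copy_(new)                                      # copy into the captured addresses
        self.graph.replay()
        return self.static_output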

benchmark of AITemplate

AITemplate has just been released and it's not clear whether it is worth adding to the benchmark (i.e. is it faster than TensorRT?).
They have unusual options like accumulation in fp16, which may have some impact on precision, running models in full fp16 (no mixed precision), etc.
This issue will be updated with notes about the benchmark.

Follow the instructions from https://github.com/facebookincubator/AITemplate/tree/main/examples/03_bert

./docker/build.sh cuda
docker run --rm -it  --gpus all --network host -v $(pwd):/work ait
cd AITemplate/
# disable fast gelu
python3 examples/03_bert/benchmark_ait.py --activation gelu --batch-size 1 --seq-length 384
# batch_size: 1, seq_length: 384, latency: 2.213545501232147
python3 examples/03_bert/benchmark_ait.py --activation gelu --batch-size 1 --seq-length 384 --use-fp16-acc False
# batch_size: 1, seq_length: 384, latency: 2.4971184134483337
python3 examples/03_bert/benchmark_ait.py --activation gelu --batch-size 1 --seq-length 384 --use-fp16-acc False --encoders-only False
# batch_size: 1, seq_length: 384, latency: 2.481492042541504
python3 examples/03_bert/benchmark_pt.py --batch-size 1

# bert encoders pt: batch_size: 1, seq_length: 64, 10.690794677734376 ms
# bert encoders pt: batch_size: 1, seq_length: 128, 10.814188232421875 ms
# bert encoders pt: batch_size: 1, seq_length: 384, 10.125660400390625 ms
# bert encoders pt: batch_size: 1, seq_length: 512, 10.442783203125 ms
# bert encoders pt: batch_size: 1, seq_length: 1024, 10.131405029296875 ms
# bert encoders pt: batch_size: 1, seq_length: 4096, 54.720029296875 ms

attention kernel is not compiling for shape 32x32 (crashing)

When trying to compile our own implementation of flash attention or the original one, the following exception is raised:

_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = JITFunction(implementations.attention_original:_fwd_kernel_original), arg_types = [pointer<fp16>, pointer<fp16>, pointer<fp16>, triton.language.fp32, pointer<fp32>, pointer<fp32>, ...], device = 0
attributes = {0: 16, 1: 16, 2: 16, 4: 16, ...}, constants = {11: 1, 15: 1, 19: 1, 23: 1, ...}, num_warps = 4, num_stages = 1, extern_libs = {}

    def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages, extern_libs):
        # create IR module
        context = _triton.ir.context()
        # get just-in-time proto-type of kernel
        arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
        ret_type = triton.language.void
        prototype = triton.language.function_type(ret_type, arg_types)
        # generate Triton-IR
        # export symbols visible from self into code-generator object
        gscope = self.__globals__
        generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True)
        try:
            generator.visit(self.parse())
        except Exception as e:
            node = generator.last_node
            if node is None or isinstance(e, (NotImplementedError, CompilationError)):
                raise e
            raise CompilationError(self.src, node) from e
        # Compile to machine code
        if torch.version.hip is None:
            backend = _triton.runtime.backend.CUDA
        else:
            backend = _triton.runtime.backend.ROCM
>       name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, extern_libs)
E       RuntimeError: Internal compiler error when trying to fuse matmuls!

It's linked to this code:

https://github.com/openai/triton/blob/c668d6596e45b7d5134f835c257b0ea627c2be0f/lib/codegen/selection/generator.cc#L2335

It's probably a bug inside Triton. We may want to try a newer version to see if it's fixed, plus open an issue in the Triton repo.

wrapper to apply the optimization in 1 line (or 2?)

Right now, applying the optimizations requires understanding the lib.
We may want to wrap this logic and make things simple.
It would take a model as input and return the function to run.
Running the provided model directly should then be forbidden, as we will mutate it (replace the fw method, for instance). A sketch of a possible API follows the checklist below.

  • new API
  • update e2e Bert and T5 notebooks
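
A minimal sketch of what the one-line API could look like (the function name optimize_model and its behavior are hypothetical, not the final kernl API):

import torch

def optimize_model(model: torch.nn.Module):
    # hypothetical one-liner: mutate `model` in place (kernel replacement, CUDA graphs, ...)
    # and return the callable to use instead of the original forward
    optimized_forward = model.forward          # placeholder for the wrapped/optimized forward
    def forbidden(*args, **kwargs):
        raise RuntimeError("model was mutated by optimize_model(); use the returned callable instead")
    model.forward = forbidden                  # the original model must not be run directly anymore
    return optimized_forward

# usage:
# run = optimize_model(model)
# output = run(input_ids, attention_mask)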

Integrate custom Triton kernels + CUDA graphs into an end-to-end model

Today we work at the kernel level.
To be useful, the kernels must be integrated into existing models.

To do so, we need to:

  • replace attention and the other optimized kernels in the FX graph of the PyTorch model
  • keep the model's dynamic capabilities (for example by using torchdynamo, so that FX tracing does not make everything static)
  • add a dedicated benchmark
  • add the relevant unit tests

Dependency:

FP16 support for debugger

The debugger recasts all fp16 ops to fp32 to run on CPU.
This makes us unable to perform some type promotion analyses.
FP16 tensor inputs should be supported.

Evaluate T5 implementation compatibility with currently implemented patterns

We currently have patterns that replace BERT operations in the HF implementation.
We want to support T5.

The goal of this issue is to evaluate the work needed to support operation replacement in T5, more precisely to check whether the current patterns are compatible and which additional patterns are needed.

Write unit tests

Today the unit tests are mixed with the benchmarks.
It makes sense to keep asserts in the benchmarks to avoid wrong or too approximate results for some sizes, etc.

However, it would also be useful to have dedicated unit tests covering situations that are not relevant for the benchmarks.
For example, what happens if the same CudaGraph instance is reused several times with different inputs?
Does resizing of the static inputs work as expected?
Broadcasting in Triton?
etc.

Have the tools/methods to inspect/understand the execution of an FX graph

Before going further with the integration of new Triton kernels into torchdynamo, we would like to understand exactly how FX graphs work and to be able to trace their speed and memory footprint precisely within the graph.
This will later allow us to better benchmark the FX graph integration of the old and new kernels.

Do not hardcode kernel tensor type

Right now, we force all our layers to take FP16 tensors as input.
For many of them we can make the type dynamic.
The Triton matmul op shows how to do it; a sketch of the idea follows.
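
A minimal sketch of the idea, assuming Triton's usual pattern of reading the element type from the output pointer (this mirrors the Triton matmul example rather than the existing kernl kernels):

import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(X, Out, factor, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(X + offs, mask=mask)                                 # dtype follows the input pointer
    y = x.to(tl.float32) * factor                                    # compute in fp32
    tl.store(Out + offs, y.to(Out.dtype.element_ty), mask=mask)      # cast back to whatever dtype the caller allocated

def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    out = torch.empty_like(x)                                        # output dtype follows the input, nothing hardcoded
    grid = (triton.cdiv(x.numel(), 1024),)
    scale_kernel[grid](x, out, factor, x.numel(), BLOCK=1024)
    return out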

update Triton dependency

The new version of Triton includes many fixes but breaks the linear layer (autotune related).
We may want to update this dependency from time to time, to be ready when the MLIR switch happens and to benefit from the many fixes merged recently.

Compare FP16 outputs to PyTorch FP32 outputs

Today, in the benchmarks, we compare the Triton FP16 output to the PyTorch FP16 output.
Our reference should be PyTorch FP32, and we should compare the gaps: [PT FP32 - PT FP16] vs [PT FP32 - Triton FP16] < threshold.

On Flash Transformer, the tolerated gap is 2x the PT FP32 - PT FP16 gap.

The goal is to have a reliable reference: we don't know whether PyTorch FP16 always has good precision, so it should not be our reference.
Indeed, there is a standard for storing and manipulating a number in FP16, but changing the order of FP16 operations changes the result, even when the operation is commutative. A sketch of the comparison follows the link below.

https://en.wikipedia.org/wiki/Half-precision_floating-point_format
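
A minimal sketch of such a check, assuming the 2x tolerance mentioned above (names are illustrative):

import torch

def assert_fp16_close_to_fp32_reference(
    triton_fp16: torch.Tensor, pt_fp16: torch.Tensor, pt_fp32: torch.Tensor, factor: float = 2.0
) -> None:
    # the reference is PyTorch FP32; the Triton FP16 error must stay within
    # `factor` times the error PyTorch FP16 itself makes against that reference
    pt_fp16_error = (pt_fp32 - pt_fp16.float()).abs().max()
    triton_error = (pt_fp32 - triton_fp16.float()).abs().max()
    assert triton_error <= factor * pt_fp16_error, (
        f"Triton FP16 max error {triton_error:.3e} exceeds "
        f"{factor}x the PyTorch FP16 max error {pt_fp16_error:.3e}"
    )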

Flash attention crashes on Turing hardware

When running flash attention on a Tesla T4, a segfault occurs:

____________________________________________________________________________ test_benchmark_masked[triton-bias-mask-causal-bf16-8x128] ____________________________________________________________________________

heads = 48, seq_length = 128
Q = tensor([[[[3.9844e-01, 5.1562e-01, 2.4902e-02,  ..., 6.7969e-01,
           4.4141e-01, 5.0000e-01],
          [8.0859...6.7188e-01, 1.5918e-01,  ..., 3.0859e-01,
           7.9688e-01, 5.4297e-01]]]], device='cuda:0', dtype=torch.bfloat16)
K = tensor([[[[3.2422e-01, 9.5703e-01, 3.1055e-01,  ..., 5.3516e-01,
           4.9023e-01, 1.9531e-01],
          [7.7344...5.2734e-01, 3.8477e-01,  ..., 1.2891e-01,
           7.2266e-01, 2.4292e-02]]]], device='cuda:0', dtype=torch.bfloat16)
V = tensor([[[[2.0801e-01, 6.9531e-01, 7.9297e-01,  ..., 5.6641e-01,
           6.9922e-01, 4.4678e-02],
          [8.7891...4.0625e-01, 2.8906e-01,  ..., 1.2158e-01,
           8.0078e-01, 2.4805e-01]]]], device='cuda:0', dtype=torch.bfloat16)
sm_scale = 0.3
attention_mask = tensor([[[[4.0234e-01, 7.2266e-01, 9.6094e-01,  ..., 9.6094e-01,
           1.9141e-01, 5.4297e-01],
          [3.8672...5.3125e-01, 2.9297e-01,  ..., 6.0156e-01,
           1.4062e-01, 8.4766e-01]]]], device='cuda:0', dtype=torch.bfloat16)
TMP = tensor([[2.3694e-38, 2.3694e-38, 2.3694e-38,  ..., 2.3694e-38, 2.3694e-38,
         2.3694e-38],
        [0.0000e+00, ...00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]], device='cuda:0')
output = tensor([[[[0.2080, 0.6953, 0.7930,  ..., 0.5664, 0.6992, 0.0447],
          [0.6250, 0.3262, 0.8320,  ..., 0.4258, 0.3...47],
          [0.5391, 0.4727, 0.5273,  ..., 0.5352, 0.4629, 0.5781]]]],
       device='cuda:0', dtype=torch.bfloat16)
q_batch_stride = 393216, q_head_stride = 8192, q_m_stride = 64, q_k_stride = 1, k_batch_stride = 393216, k_head_stride = 8192, k_n_stride = 64, k_k_stride = 1, v_batch_stride = 393216, v_head_stride = 8192
v_k_stride = 64, v_n_stride = 1, o_batch_stride = 393216, o_head_stride = 8192, o_m_stride = 64, o_n_stride = 1, attention_mask_batch_stride = 786432, attention_mask_head_stride = 16384
attention_mask_m_stride = 128, attention_mask_k_stride = 1, min_clamp_value = -3.3895313892515355e+38, MASK_BATCH_SIZE = 8, MASK_HEAD_SIZE = 48, MASK_M_SIZE = 128, MASK_K_SIZE = 128, HAS_MASK = True
IS_CAUSAL = True, BLOCK_M = 128, BLOCK_DHEAD = 64, BLOCK_N = 128, grid = (1, 384), num_warps = 4, num_stages = 1, extern_libs = None, stream = 0, warmup = False

>   ???
E   KeyError: ('2-.-0-.-0-1e8410f206c822547fb50e2ea86e45a6-cfed90d463fc30ccffb0eb2fd26372d3-a357695982511d203a134df772c7b4a1-2121719c12e3ab66746f4a57f276d42e-0f76008a374e725ca29ccb33f1ba668f-dc48432b6b79843e2f9c7ad2e7355f59-f40d73592c2578180d3d8e3f64e3957d-0dd03b0bd512a184b3512b278d9dfa59-d7c2e52f8151bec157e9a17a1ec37dd3', ('i32', 'i32', torch.bfloat16, torch.bfloat16, torch.bfloat16, 'fp32', torch.bfloat16, torch.float32, torch.bfloat16, 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'fp32'), (8, 48, 128, 128, True, True, 128, 64, 128), ((True, False), (True, False), True, True, True, (False,), True, True, True, (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (False,)))

<string>:21: KeyError

During handling of the above exception, another exception occurred:

benchmark = <kernl.benchmark.benchmark_fixture.BenchmarkFixture object at 0x7eff5c2b8e20>, shape = (8, 128), implementation = 'triton', mask_fn = <function generate_bias_mask at 0x7eff5c39d700>
dtype = torch.bfloat16, is_causal = True

    @set_seed()
    @pytest.mark.parametrize(
        "shape",
        [(bs, seq_l) for bs in [8, 32, 64] for seq_l in [128, 256, 384, 512]] + [(32, 32)],
        ids=lambda x: f"{x[0]}x{x[1]}",
    )
    # fp32 not yet possible because of a bug in triton
    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
    @pytest.mark.parametrize("is_causal", [True, False], ids=["causal", "non-causal"])
    @pytest.mark.parametrize(
        "mask_fn",
        [generate_bias_mask, generate_broadcast_mask, generate_none_mask],
        ids=["bias-mask", "broadcast-mask", "no-mask"],
    )
    @pytest.mark.parametrize("implementation", implementations.keys())
    def test_benchmark_masked(
        benchmark, shape: (int, int), implementation: Callable, mask_fn: Callable, dtype: torch.dtype, is_causal: bool
    ):
        batch, seq_length = shape
        if implementation == "original" and (dtype == torch.bfloat16 or seq_length != 512):
            pytest.skip("Original Triton implementation only supports fp16 and seq_length=512")
        elif implementation == "original" and mask_fn != generate_none_mask:
            pytest.skip("Original Triton implementation doesn't support masks")
    
        # batch, heads, seq_length, dhead
        mat_shape = (batch, 48, seq_length, 64)
        args = {
            "q": torch.rand(mat_shape, device="cuda"),
            "k": torch.rand(mat_shape, device="cuda"),
            "v": torch.rand(mat_shape, device="cuda"),
            "output": torch.empty(mat_shape, device="cuda"),
            "sm_scale": 0.3,  # Scaling applied before softmax (sqrt(dhead) in Vaswani et al.)
            "is_causal": is_causal,
            "attention_mask": mask_fn(batch, seq_length),
        }
    
        expected = attention_reference(**args)
        cast_args = {k: v.to(dtype).clone() if isinstance(v, torch.Tensor) else v for k, v in args.items()}
    
        func = implementations[implementation]
>       value = benchmark(func, **cast_args)

test/test_attention.py:107: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/kernl/benchmark/benchmark_fixture.py:53: in __call__
    function_to_benchmark(*args, **kwargs)
test/test_attention.py:39: in <lambda>
    "triton": lambda q, k, v, output, sm_scale, is_causal, attention_mask: attention_forward(
src/kernl/implementations/attention.py:436: in attention_forward
    return Attention.apply(q, k, v, output, sm_scale, is_causal, attention_mask)
/usr/local/lib/python3.9/dist-packages/torch/cuda/amp/autocast_mode.py:118: in decorate_fwd
    return fwd(*args, **kwargs)
src/kernl/implementations/attention.py:380: in forward
    _fwd_kernel[grid](
/usr/local/lib/python3.9/dist-packages/triton/runtime/jit.py:106: in launcher
    return self.run(*args, grid=grid, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

heads = 48, seq_length = 128
Q = tensor([[[[3.9844e-01, 5.1562e-01, 2.4902e-02,  ..., 6.7969e-01,
           4.4141e-01, 5.0000e-01],
          [8.0859...6.7188e-01, 1.5918e-01,  ..., 3.0859e-01,
           7.9688e-01, 5.4297e-01]]]], device='cuda:0', dtype=torch.bfloat16)
K = tensor([[[[3.2422e-01, 9.5703e-01, 3.1055e-01,  ..., 5.3516e-01,
           4.9023e-01, 1.9531e-01],
          [7.7344...5.2734e-01, 3.8477e-01,  ..., 1.2891e-01,
           7.2266e-01, 2.4292e-02]]]], device='cuda:0', dtype=torch.bfloat16)
V = tensor([[[[2.0801e-01, 6.9531e-01, 7.9297e-01,  ..., 5.6641e-01,
           6.9922e-01, 4.4678e-02],
          [8.7891...4.0625e-01, 2.8906e-01,  ..., 1.2158e-01,
           8.0078e-01, 2.4805e-01]]]], device='cuda:0', dtype=torch.bfloat16)
sm_scale = 0.3
attention_mask = tensor([[[[4.0234e-01, 7.2266e-01, 9.6094e-01,  ..., 9.6094e-01,
           1.9141e-01, 5.4297e-01],
          [3.8672...5.3125e-01, 2.9297e-01,  ..., 6.0156e-01,
           1.4062e-01, 8.4766e-01]]]], device='cuda:0', dtype=torch.bfloat16)
TMP = tensor([[2.3694e-38, 2.3694e-38, 2.3694e-38,  ..., 2.3694e-38, 2.3694e-38,
         2.3694e-38],
        [0.0000e+00, ...00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]], device='cuda:0')
output = tensor([[[[0.2080, 0.6953, 0.7930,  ..., 0.5664, 0.6992, 0.0447],
          [0.6250, 0.3262, 0.8320,  ..., 0.4258, 0.3...47],
          [0.5391, 0.4727, 0.5273,  ..., 0.5352, 0.4629, 0.5781]]]],
       device='cuda:0', dtype=torch.bfloat16)
q_batch_stride = 393216, q_head_stride = 8192, q_m_stride = 64, q_k_stride = 1, k_batch_stride = 393216, k_head_stride = 8192, k_n_stride = 64, k_k_stride = 1, v_batch_stride = 393216, v_head_stride = 8192
v_k_stride = 64, v_n_stride = 1, o_batch_stride = 393216, o_head_stride = 8192, o_m_stride = 64, o_n_stride = 1, attention_mask_batch_stride = 786432, attention_mask_head_stride = 16384
attention_mask_m_stride = 128, attention_mask_k_stride = 1, min_clamp_value = -3.3895313892515355e+38, MASK_BATCH_SIZE = 8, MASK_HEAD_SIZE = 48, MASK_M_SIZE = 128, MASK_K_SIZE = 128, HAS_MASK = True
IS_CAUSAL = True, BLOCK_M = 128, BLOCK_DHEAD = 64, BLOCK_N = 128, grid = (1, 384), num_warps = 4, num_stages = 1, extern_libs = None, stream = 0, warmup = False

>   ???
E   RuntimeError: Triton Error [CUDA]: invalid argument

<string>:44: RuntimeError

test any input len (aka not just power-of-2 shapes)

IRL, input tensor shapes are not all powers of 2, but our benchmark only tests those shapes.

The purpose of this issue is to check whether it's possible to take any seq len as input.
Right now it's not possible: in the attention kernel, the block M/N size is tied to the seq len (block M == block N for now).
The seq len has to be a multiple of it, and we can't lift this requirement without updating the kernel, because the kernel doesn't handle masking during the loading and saving of tensors.
Moreover, Triton only accepts blocks whose length is a power of 2, otherwise an exception is raised: "Triton tensors must have a power-of-two number of elements".
Also, small seq lens (< 16) won't work either, even when they are powers of 2, because the minimum block size is 16x16.

In the existing code, a comment shows that the issue is already known.

The best solution is to handle masking at load/save time.
In that case, when a seq len is not a power of 2, we can simply select the next power of 2 (capped at 128). A sketch of the load-time masking follows.
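
A minimal sketch of the load-time masking idea, using triton.next_power_of_2 and masked tl.load/tl.store (illustrative, not the current attention kernel):

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(X, Out, seq_len, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)                      # BLOCK is a power of 2, possibly > seq_len
    mask = offs < seq_len                           # hide out-of-bounds positions
    x = tl.load(X + offs, mask=mask, other=0.0)     # padded positions read as 0
    tl.store(Out + offs, x, mask=mask)              # and are never written back

def copy(x: torch.Tensor) -> torch.Tensor:
    seq_len = x.shape[0]
    block = max(16, triton.next_power_of_2(seq_len))  # blocks must be powers of 2, min 16x16 as noted above
    out = torch.empty_like(x)
    copy_kernel[(1,)](x, out, seq_len, BLOCK=block)
    return out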

@gaetansnl

T5 support

We are implementing support for T5 models.
In this first step, it's only kernel replacement, no other optimizations.

  • Support T5 masks #66
  • Support T5 kernels patterns #63
  • Support T5 missing activation #54
  • Support T5 #77

reproducible run

A third party may want a Docker image and an associated command to run all the benchmarks (requiring no knowledge of the dependencies, etc., including TensorRT).
@gaetansnl Can we put the image we use on K8S in this repo?

add split K support on Linear layer (if easy)

Split-K matmul distributes work across SMs along the K axis. By itself, it's quite easy to do in Triton; a bit less so if we want to handle adding the bias and the activation function. That will require an (atomic) counter outside of the Triton program (as a tensor) to know where we are on the K axis. A sketch of the basic split-K idea is shown below.
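
A minimal sketch of split-K via atomic adds on the output (it skips the bias/activation handling and the counter mentioned above; assumes M, N, K divisible by the block sizes and an fp32 output zero-initialized by the caller):

import torch
import triton
import triton.language as tl

@triton.jit
def splitk_matmul_kernel(
    A, B, C, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, SPLIT_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_k = tl.program_id(2)                       # which K slice this program owns
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    k_per_split = K // SPLIT_K
    for k in range(pid_k * k_per_split, (pid_k + 1) * k_per_split, BLOCK_K):
        rk = k + tl.arange(0, BLOCK_K)
        a = tl.load(A + rm[:, None] * stride_am + rk[None, :] * stride_ak)
        b = tl.load(B + rk[:, None] * stride_bk + rn[None, :] * stride_bn)
        acc += tl.dot(a, b)
    # each split adds its partial sum into C (zero-initialized by the caller)
    tl.atomic_add(C + rm[:, None] * stride_cm + rn[None, :] * stride_cn, acc)

def splitk_matmul(a: torch.Tensor, b: torch.Tensor, split_k: int = 4) -> torch.Tensor:
    M, K = a.shape
    _, N = b.shape
    c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
    grid = (triton.cdiv(M, 64), triton.cdiv(N, 64), split_k)
    splitk_matmul_kernel[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
        BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, SPLIT_K=split_k,
    )
    return c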

Improve debugger API

Currently we need to pass the list of tensors at the beginning, when we create the debugger instance.
We could instead build the map progressively when get_ptr is called. A sketch follows.
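
A minimal sketch of lazy registration, assuming a debugger that maps pointers back to tensors (class and method bodies are hypothetical, not the existing kernl debugger):

import torch

class Debugger:
    # hypothetical sketch: tensors are registered lazily, the first time get_ptr sees them
    def __init__(self):
        self._ptr_to_tensor: dict[int, torch.Tensor] = {}

    def get_ptr(self, tensor: torch.Tensor) -> int:
        ptr = tensor.data_ptr()
        self._ptr_to_tensor.setdefault(ptr, tensor)   # register on first use, no upfront list needed
        return ptr

    def tensor_from_ptr(self, ptr: int) -> torch.Tensor:
        return self._ptr_to_tensor[ptr]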

Benchmark shapes other than those "tailored" to GPUs

Today we only benchmark sizes that are multiples of 128.
That matches GPU preferences perfectly.
However, we don't test/benchmark shapes with odd numbers, etc.
Those would exercise the masks at load/write time, etc.

Fix remove_dropout replacement using names

Currently the remove_dropout function uses name matching to detect and remove dropout.
It's not reliable.

This issue is about finding another way to match and delete these nodes.

T5 mask generation

The T5 mask includes a bias and takes a lot of space in global memory (GM). We may generate it directly inside the Triton kernel to reduce the memory footprint and the end-to-end model latency.

Test benchmarks on Turing hardware

Right now we run our tests on the Ampere GPU microarchitecture. We want to check against the previous NVIDIA microarchitecture (Turing) to see how the kernels behave on older GPUs.

Make triton code debuggable

Right now the Triton kernels can't easily be debugged, as they are executed in parallel with no easy way to attach a debugger or print values (other than vector shapes). Since Triton code looks like numpy / torch code, we can try to execute it serially using one of those tools (a sketch of a serial runner follows the list below).

  • implement a serial runner for Triton kernels
  • unit test the Triton tutorial kernels
  • add simple analytical measures (global memory reads / writes)
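
A conceptual sketch of a serial runner: replay the launch grid in a plain Python loop and implement the kernel body with torch ops (this shows the idea only; a real runner would also have to emulate tl.load / tl.store and the rest of the Triton language):

import itertools
import torch

def serial_launch(kernel_body, grid, *args, **kwargs):
    # run the kernel body once per grid point, serially, so a regular Python
    # debugger and print() can be used
    for program_ids in itertools.product(*(range(g) for g in grid)):
        kernel_body(program_ids, *args, **kwargs)

# example: a serial, torch-based equivalent of a block-wise add kernel
def add_body(program_ids, x, y, out, BLOCK):
    pid = program_ids[0]
    block = slice(pid * BLOCK, (pid + 1) * BLOCK)
    out[block] = x[block] + y[block]    # tl.load / tl.store become plain tensor indexing

x, y = torch.rand(1024), torch.rand(1024)
out = torch.empty_like(x)
serial_launch(add_body, (1024 // 256,), x, y, out, BLOCK=256)
assert torch.allclose(out, x + y)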

Implement efficient single pass on data variance computation in layernorm kernel

The current kernel, from xformers / the Triton tutorial, performs 2 passes over the data to compute the mean and then the variance.
Welford's formula makes it possible to do it in a single pass, which would decrease pressure on memory bandwidth. A reference sketch follows the links below.

useful links:
https://jonisalonen.com/2013/deriving-welfords-method-for-computing-variance/
https://www.johndcook.com/blog/standard_deviation/

with implementation: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance , https://www.embeddedrelated.com/showarticle/785.php
clear explanation: https://changyaochen.github.io/welford/
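
A minimal reference sketch of Welford's single-pass update in plain Python (scalar form; the Triton kernel would apply the same recurrence per row):

def welford_variance(values):
    # single pass over the data: maintain count, running mean and M2 (sum of squared deviations)
    count, mean, m2 = 0, 0.0, 0.0
    for x in values:
        count += 1
        delta = x - mean
        mean += delta / count          # new mean
        delta2 = x - mean              # deviation from the new mean
        m2 += delta * delta2           # accumulate squared deviations
    variance = m2 / count              # population variance, as used by layernorm
    return mean, variance

# sanity check against the two-pass formula
data = [1.0, 2.0, 4.0, 7.0]
mean, var = welford_variance(data)
assert abs(mean - sum(data) / len(data)) < 1e-12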

normalize triton version

The official Triton release on PyPI is not up to date.
We need to install a pre-release; to guarantee we all use the same version, we need to pin it in the requirements.

Fp16 cache for mixed precision does a mutation of the input

We need a cache of fp16-converted weights (for linear layers, for example) to avoid many conversions, because in torch the cache is a wrapper around the linear kernel and we bypass the linear kernel.

Our current solution is to mutate the original linear layer to use it as a cache, then do 2 rounds of CUDA graph warmup so it picks up the correct input. We then set the previous weights to None to allow for GC (we did not test this part, so we are unsure about it).

Proposal:

  • One way to avoid that could be to add a cache module in the replacement pattern; the cache would stay in the new graph, and we could pass it to the kernel wrapper.
  • We could also use a global cache added to the main graph module, then pass it to the wrappers using the module replacement system, with a hardcoded replacement for the cache module.

Fuse the other transformer layers

Today we focus on the attention itself.
Other layers could be fused, all the more so since, beyond the memory access savings, Triton matmul + CUDA graph sometimes seems to be faster than PyTorch (depending on the shape):

  • residual connection + layernorm [ADD + NORM]
  • feed forward + layernorm
  • add benchmarks, unit tests, etc.

See the diagram at https://github.com/markriedl/transformer-walkthrough
