

wujingyue commented on August 21, 2024

With some more hacks (which I'll try to find a way to submit), I'm now getting some nvFusions that will hopefully be a useful starting point. The forward pass now runs two nvFusions. The first has one fd.ops.linear, which I suspect is the input linear layer (its 4800-wide output is sliced into three 1600-wide chunks, like a fused QKV projection). The second has three fd.ops.linear, which I suspect are the output linear layer followed by the two-layer MLP.

I'll confirm this and try to include SDPA as well.

$ NVFUSER_DUMP=python_definition NVFUSER_DISABLE=matmul_expr_eval pytest thunder/benchmarks/targets.py -k test_nanogpt_block_fwd[thunder] -s
============================================================================================================================ test session starts =============================================================================================================================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.4.0
Test order randomisation NOT enabled. Enable with --random-order or --random-order-bucket=<bucket_type>
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /opt/pytorch/lightning-thunder
configfile: pyproject.toml
plugins: timestamper-0.0.10, xdist-3.5.0, random-order-1.1.1, cov-4.1.0, benchmark-4.0.0, hypothesis-6.100.0, timeout-2.2.0, shard-0.1.2
timeout: 900.0s
timeout method: signal
timeout func_only: False
collected 162 items / 161 deselected / 1 selected
Running 1 items in this shard

thunder/benchmarks/targets.py
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T4 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.ops.cast(T4, dtype=DataType.Float)
    T6, T7 = fd.ops.var_mean(T5, dims=[2], correction=0, keepdim=False)
    S8 = fd.define_scalar(16, dtype=DataType.Int)
    S9 = fd.define_scalar(128, dtype=DataType.Int)
    S10 = fd.define_scalar(1, dtype=DataType.Int)
    V11 = fd.define_vector([S8, S9, S10], dtype=DataType.Int)
    T12 = fd.ops.broadcast_in_dim(T6, shape=V11, broadcast_dims=[0, 1])
    S13 = fd.define_scalar(16, dtype=DataType.Int)
    S14 = fd.define_scalar(128, dtype=DataType.Int)
    S15 = fd.define_scalar(1, dtype=DataType.Int)
    V16 = fd.define_vector([S13, S14, S15], dtype=DataType.Int)
    T17 = fd.ops.broadcast_in_dim(T7, shape=V16, broadcast_dims=[0, 1])
    S18 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T19 = fd.ops.add(T12, S18)
    T20 = fd.ops.rsqrt(T19)
    S21 = fd.define_scalar(16, dtype=DataType.Int)
    S22 = fd.define_scalar(128, dtype=DataType.Int)
    S23 = fd.define_scalar(1600, dtype=DataType.Int)
    V24 = fd.define_vector([S21, S22, S23], dtype=DataType.Int)
    T25 = fd.ops.broadcast_in_dim(T17, shape=V24, broadcast_dims=[0, 1, 2])
    T26 = fd.ops.sub(T5, T25)
    S27 = fd.define_scalar(16, dtype=DataType.Int)
    S28 = fd.define_scalar(128, dtype=DataType.Int)
    S29 = fd.define_scalar(1600, dtype=DataType.Int)
    V30 = fd.define_vector([S27, S28, S29], dtype=DataType.Int)
    T31 = fd.ops.broadcast_in_dim(T20, shape=V30, broadcast_dims=[0, 1, 2])
    T32 = fd.ops.mul(T26, T31)
    S33 = fd.define_scalar(16, dtype=DataType.Int)
    S34 = fd.define_scalar(128, dtype=DataType.Int)
    S35 = fd.define_scalar(1600, dtype=DataType.Int)
    V36 = fd.define_vector([S33, S34, S35], dtype=DataType.Int)
    T37 = fd.ops.broadcast_in_dim(T3, shape=V36, broadcast_dims=[2])
    T38 = fd.ops.cast(T37, dtype=DataType.Float)
    T39 = fd.ops.mul(T32, T38)
    S40 = fd.define_scalar(16, dtype=DataType.Int)
    S41 = fd.define_scalar(128, dtype=DataType.Int)
    S42 = fd.define_scalar(1600, dtype=DataType.Int)
    V43 = fd.define_vector([S40, S41, S42], dtype=DataType.Int)
    T44 = fd.ops.broadcast_in_dim(T2, shape=V43, broadcast_dims=[2])
    T45 = fd.ops.cast(T44, dtype=DataType.Float)
    T46 = fd.ops.add(T39, T45)
    T47 = fd.ops.cast(T46, dtype=DataType.BFloat16)
    S48 = fd.define_scalar(2048, dtype=DataType.Int)
    S49 = fd.define_scalar(1600, dtype=DataType.Int)
    V50 = fd.define_vector([S48, S49], dtype=DataType.Int)
    T51 = fd.ops.reshape(T47, new_shape=V50)
    T52 = fd.ops.linear(T51, T1, T0)
    S53 = fd.define_scalar(16, dtype=DataType.Int)
    S54 = fd.define_scalar(128, dtype=DataType.Int)
    S55 = fd.define_scalar(4800, dtype=DataType.Int)
    V56 = fd.define_vector([S53, S54, S55], dtype=DataType.Int)
    T57 = fd.ops.reshape(T52, new_shape=V56)
    T58 = fd.ops.slice(T57, start_indices=[0, 0, 0], end_indices=[16, 128, 1600], strides=[1, 1, 1])
    T59 = fd.ops.slice(T57, start_indices=[0, 0, 1600], end_indices=[16, 128, 3200], strides=[1, 1, 1])
    T60 = fd.ops.slice(T57, start_indices=[0, 0, 3200], end_indices=[16, 128, 4800], strides=[1, 1, 1])
    S61 = fd.define_scalar(16, dtype=DataType.Int)
    S62 = fd.define_scalar(128, dtype=DataType.Int)
    S63 = fd.define_scalar(25, dtype=DataType.Int)
    S64 = fd.define_scalar(64, dtype=DataType.Int)
    V65 = fd.define_vector([S61, S62, S63, S64], dtype=DataType.Int)
    T66 = fd.ops.reshape(T59, new_shape=V65)
    T67 = fd.ops.permute(T66, dims=[0, 2, 1, 3])
    S68 = fd.define_scalar(16, dtype=DataType.Int)
    S69 = fd.define_scalar(128, dtype=DataType.Int)
    S70 = fd.define_scalar(25, dtype=DataType.Int)
    S71 = fd.define_scalar(64, dtype=DataType.Int)
    V72 = fd.define_vector([S68, S69, S70, S71], dtype=DataType.Int)
    T73 = fd.ops.reshape(T58, new_shape=V72)
    T74 = fd.ops.permute(T73, dims=[0, 2, 1, 3])
    S75 = fd.define_scalar(16, dtype=DataType.Int)
    S76 = fd.define_scalar(128, dtype=DataType.Int)
    S77 = fd.define_scalar(25, dtype=DataType.Int)
    S78 = fd.define_scalar(64, dtype=DataType.Int)
    V79 = fd.define_vector([S75, S76, S77, S78], dtype=DataType.Int)
    T80 = fd.ops.reshape(T60, new_shape=V79)
    T81 = fd.ops.permute(T80, dims=[0, 2, 1, 3])
    fd.add_output(T74)
    fd.add_output(T67)
    fd.add_output(T81)
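For readers following along, here is a pure-Python sketch of what nvfuser_fusion_id0 appears to compute: a LayerNorm over the last dim (var_mean with correction=0, eps 1e-5, then scale and shift), a linear layer, and a three-way split into per-head tensors. Mapping this to nanoGPT's ln_1 and c_attn is my reading of the dump, not something the trace states. With B=16, T=128, C=1600 and 25 heads of size 64, the shapes line up with the dump: 16 * 128 = 2048 rows, 3 * 1600 = 4800 output features, 25 * 64 = 1600. All helper names are hypothetical, and the BFloat16/Float casts are ignored.

```python
import math

def fusion0_shapes(B, T, C, n_head):
    """Shapes along nvfuser_fusion_id0 for a [B, T, C] input (hypothetical helper)."""
    return {
        "linear_in": (B * T, C),                   # reshape [B, T, C] -> [B*T, C]
        "linear_out": (B * T, 3 * C),              # fd.ops.linear output
        "per_head": (B, n_head, T, C // n_head),   # after slice, reshape, permute [0, 2, 1, 3]
    }

def layer_norm(rows, weight, bias, eps=1e-5):
    """LayerNorm over the last dim with biased variance (correction=0), as in the dump."""
    out = []
    for row in rows:
        c = len(row)
        mean = sum(row) / c
        var = sum((v - mean) ** 2 for v in row) / c
        rstd = 1.0 / math.sqrt(var + eps)
        out.append([(v - mean) * rstd * w + b for v, w, b in zip(row, weight, bias)])
    return out

def linear(rows, w, b):
    """y = x @ w^T + b, with w stored as [out_features][in_features]."""
    return [[sum(x * wk for x, wk in zip(row, w_row)) + bn for w_row, bn in zip(w, b)]
            for row in rows]

def ln_qkv_heads(x, ln_w, ln_b, w_attn, b_attn, n_head):
    """x: [B][T][C] nested lists. Returns q, k, v, each shaped [B][n_head][T][C // n_head]."""
    B, T, C = len(x), len(x[0]), len(x[0][0])
    hs = C // n_head
    flat = [row for batch in x for row in batch]              # [B*T][C]
    y = linear(layer_norm(flat, ln_w, ln_b), w_attn, b_attn)  # [B*T][3C]

    def heads(offset):
        # Slice features [offset, offset + C), split into heads, order as [B, head, T, hs]
        return [[[y[b * T + t][offset + h * hs: offset + (h + 1) * hs] for t in range(T)]
                 for h in range(n_head)]
                for b in range(B)]

    return heads(0), heads(C), heads(2 * C)
```

The three outputs come back in q, k, v order, matching the order of fd.add_output (T74 from the first slice, T67 from the second, T81 from the third).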

[W509 16:47:44.547956141 matmul_utils.cpp:386] Warning: Scheduling a matmul without heuristic plugin. Specify plugin location like this: NVFUSER_MATMUL_HEURISTIC_PLUGIN=/path/to/libmatmulheuristic.so (function operator())

def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T4 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T5 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T6 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T7 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T8 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T9 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T10 = fd.ops.permute(T9, dims=[0, 2, 1, 3])
    T11 = fd.ops.stride_order(T10, stride_order=[3, 2, 1, 0])
    S12 = fd.define_scalar(16, dtype=DataType.Int)
    S13 = fd.define_scalar(128, dtype=DataType.Int)
    S14 = fd.define_scalar(1600, dtype=DataType.Int)
    V15 = fd.define_vector([S12, S13, S14], dtype=DataType.Int)
    T16 = fd.ops.reshape(T11, new_shape=V15)
    S17 = fd.define_scalar(2048, dtype=DataType.Int)
    S18 = fd.define_scalar(1600, dtype=DataType.Int)
    V19 = fd.define_vector([S17, S18], dtype=DataType.Int)
    T20 = fd.ops.reshape(T16, new_shape=V19)
    T21 = fd.ops.linear(T20, T1, T0)
    S22 = fd.define_scalar(16, dtype=DataType.Int)
    S23 = fd.define_scalar(128, dtype=DataType.Int)
    S24 = fd.define_scalar(1600, dtype=DataType.Int)
    V25 = fd.define_vector([S22, S23, S24], dtype=DataType.Int)
    T26 = fd.ops.reshape(T21, new_shape=V25)
    S27 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S28 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S29 = fd.define_scalar(16, dtype=DataType.Int)
    S30 = fd.define_scalar(128, dtype=DataType.Int)
    S31 = fd.define_scalar(1600, dtype=DataType.Int)
    V32 = fd.define_vector([S29, S30, S31], dtype=DataType.Int)
    T33 = fd.ops.uniform(S27, S28, shape=V32, dtype=DataType.BFloat16)
    S34 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T35 = fd.ops.lt(T33, S34)
    T36 = fd.ops.cast(T26, dtype=DataType.Float)
    T37 = fd.ops.cast(T35, dtype=DataType.Float)
    T38 = fd.ops.mul(T36, T37)
    S39 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T40 = fd.ops.mul(T38, S39)
    T41 = fd.ops.cast(T8, dtype=DataType.Float)
    T42 = fd.ops.add(T41, T40)
    T43, T44 = fd.ops.var_mean(T42, dims=[2], correction=0, keepdim=False)
    S45 = fd.define_scalar(16, dtype=DataType.Int)
    S46 = fd.define_scalar(128, dtype=DataType.Int)
    S47 = fd.define_scalar(1, dtype=DataType.Int)
    V48 = fd.define_vector([S45, S46, S47], dtype=DataType.Int)
    T49 = fd.ops.broadcast_in_dim(T43, shape=V48, broadcast_dims=[0, 1])
    S50 = fd.define_scalar(16, dtype=DataType.Int)
    S51 = fd.define_scalar(128, dtype=DataType.Int)
    S52 = fd.define_scalar(1, dtype=DataType.Int)
    V53 = fd.define_vector([S50, S51, S52], dtype=DataType.Int)
    T54 = fd.ops.broadcast_in_dim(T44, shape=V53, broadcast_dims=[0, 1])
    S55 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T56 = fd.ops.add(T49, S55)
    T57 = fd.ops.rsqrt(T56)
    S58 = fd.define_scalar(16, dtype=DataType.Int)
    S59 = fd.define_scalar(128, dtype=DataType.Int)
    S60 = fd.define_scalar(1600, dtype=DataType.Int)
    V61 = fd.define_vector([S58, S59, S60], dtype=DataType.Int)
    T62 = fd.ops.broadcast_in_dim(T54, shape=V61, broadcast_dims=[0, 1, 2])
    T63 = fd.ops.sub(T42, T62)
    S64 = fd.define_scalar(16, dtype=DataType.Int)
    S65 = fd.define_scalar(128, dtype=DataType.Int)
    S66 = fd.define_scalar(1600, dtype=DataType.Int)
    V67 = fd.define_vector([S64, S65, S66], dtype=DataType.Int)
    T68 = fd.ops.broadcast_in_dim(T57, shape=V67, broadcast_dims=[0, 1, 2])
    T69 = fd.ops.mul(T63, T68)
    S70 = fd.define_scalar(16, dtype=DataType.Int)
    S71 = fd.define_scalar(128, dtype=DataType.Int)
    S72 = fd.define_scalar(1600, dtype=DataType.Int)
    V73 = fd.define_vector([S70, S71, S72], dtype=DataType.Int)
    T74 = fd.ops.broadcast_in_dim(T3, shape=V73, broadcast_dims=[2])
    T75 = fd.ops.cast(T74, dtype=DataType.Float)
    T76 = fd.ops.mul(T69, T75)
    S77 = fd.define_scalar(16, dtype=DataType.Int)
    S78 = fd.define_scalar(128, dtype=DataType.Int)
    S79 = fd.define_scalar(1600, dtype=DataType.Int)
    V80 = fd.define_vector([S77, S78, S79], dtype=DataType.Int)
    T81 = fd.ops.broadcast_in_dim(T2, shape=V80, broadcast_dims=[2])
    T82 = fd.ops.cast(T81, dtype=DataType.Float)
    T83 = fd.ops.add(T76, T82)
    T84 = fd.ops.cast(T83, dtype=DataType.BFloat16)
    S85 = fd.define_scalar(2048, dtype=DataType.Int)
    S86 = fd.define_scalar(1600, dtype=DataType.Int)
    V87 = fd.define_vector([S85, S86], dtype=DataType.Int)
    T88 = fd.ops.reshape(T84, new_shape=V87)
    T89 = fd.ops.linear(T88, T5, T4)
    S90 = fd.define_scalar(16, dtype=DataType.Int)
    S91 = fd.define_scalar(128, dtype=DataType.Int)
    S92 = fd.define_scalar(6400, dtype=DataType.Int)
    V93 = fd.define_vector([S90, S91, S92], dtype=DataType.Int)
    T94 = fd.ops.reshape(T89, new_shape=V93)
    T95 = fd.ops.cast(T94, dtype=DataType.Float)
    T96 = fd.ops.mul(T95, T95)
    T97 = fd.ops.mul(T96, T95)
    S98 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T99 = fd.ops.mul(S98, T95)
    S100 = fd.define_scalar(0.0447150, dtype=DataType.Double)
    T101 = fd.ops.mul(S100, T97)
    T102 = fd.ops.add(T95, T101)
    S103 = fd.define_scalar(0.797885, dtype=DataType.Double)
    T104 = fd.ops.mul(S103, T102)
    T105 = fd.ops.tanh(T104)
    S106 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T107 = fd.ops.add(S106, T105)
    T108 = fd.ops.mul(T99, T107)
    T109 = fd.ops.cast(T108, dtype=DataType.BFloat16)
    S110 = fd.define_scalar(2048, dtype=DataType.Int)
    S111 = fd.define_scalar(6400, dtype=DataType.Int)
    V112 = fd.define_vector([S110, S111], dtype=DataType.Int)
    T113 = fd.ops.reshape(T109, new_shape=V112)
    T114 = fd.ops.linear(T113, T7, T6)
    S115 = fd.define_scalar(16, dtype=DataType.Int)
    S116 = fd.define_scalar(128, dtype=DataType.Int)
    S117 = fd.define_scalar(1600, dtype=DataType.Int)
    V118 = fd.define_vector([S115, S116, S117], dtype=DataType.Int)
    T119 = fd.ops.reshape(T114, new_shape=V118)
    S120 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S121 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S122 = fd.define_scalar(16, dtype=DataType.Int)
    S123 = fd.define_scalar(128, dtype=DataType.Int)
    S124 = fd.define_scalar(1600, dtype=DataType.Int)
    V125 = fd.define_vector([S122, S123, S124], dtype=DataType.Int)
    T126 = fd.ops.uniform(S120, S121, shape=V125, dtype=DataType.BFloat16)
    S127 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T128 = fd.ops.lt(T126, S127)
    T129 = fd.ops.cast(T119, dtype=DataType.Float)
    T130 = fd.ops.cast(T128, dtype=DataType.Float)
    T131 = fd.ops.mul(T129, T130)
    S132 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T133 = fd.ops.mul(T131, S132)
    T134 = fd.ops.add(T42, T133)
    T135 = fd.ops.cast(T134, dtype=DataType.BFloat16)
    fd.add_output(T135)
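The elementwise parts of nvfuser_fusion_id1 can be read off their constants. The uniform / lt(0.9) / mul(1.11111) pattern looks like dropout with p = 0.1 (1.11111 is about 1/0.9), and T95 through T108 match the tanh approximation of GELU, with 0.797885 about sqrt(2/pi) and 0.044715 as the cubic coefficient. Together with the two residual adds (T42, T134) and the second LayerNorm, this is consistent with the second fusion being the attention output projection plus the MLP, as suspected above. Below is a pure-Python reference for those two pieces; the function names are mine, and Python's random module stands in for nvFuser's RNG, so the masks will not match bit for bit.

```python
import math
import random

def gelu_tanh(x):
    """Tanh-approximate GELU, matching the op sequence T95..T108."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)  # sqrt(2/pi) ~= 0.797885
    return 0.5 * x * (1.0 + math.tanh(inner))

def gelu_exact(x):
    """Erf-based GELU, for comparison."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def dropout(xs, p, rng):
    """Inverted dropout: keep x where u < 1 - p (u ~ U[0, 1)) and scale kept values by 1 / (1 - p)."""
    keep = 1.0 - p
    return [x / keep if rng.random() < keep else 0.0 for x in xs]
```

Dividing by keep (the fusion multiplies by the rounded constant 1.11111 instead) is the usual inverted-dropout formulation: it leaves the expected value of each element unchanged, so nothing needs rescaling at inference time.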

from fuser.

wujingyue commented on August 21, 2024

Note to myself: I'll first try to get a single-device nvFuser Python definition from Thunder, and then we can manually shard it using nvFuser's API.

@Priya2698 pointed me to the nv_enable_linear flag (https://github.com/Lightning-AI/lightning-thunder/blob/90a0f4c0d0a90d1e94684a847f3adfe2230985b4/thunder/tests/test_nvfuser.py#L875), which I'll need to turn on to run prims.linear via nvFuser. I'll probably need to set nv_enable_bookend=False as well.


wujingyue commented on August 21, 2024

"At present, we assume the same in pattern matching as well (BroadcastOp -> MmaOp -> CastOp)."

That's right. I already merged #2221, so you can reproduce this by running the reproducer in #2199 (comment). Anyhow, I'll run my experiments with matmul_expr_eval disabled, so #2221 is sufficient to unblock me at this moment.

That said, there's a variation of the same problem for ATen evaluation: the segmenter doesn't guarantee that the cast lands in the same segment as the MmaOp, just as it didn't guarantee that for the broadcast. I think the new IR nodes will help with that, but I'm not sure, so I'll leave that to you.
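To make the segment-boundary concern concrete, here is a toy model (not nvFuser code): treat each segment as a list of op names and require BroadcastOp, MmaOp and CastOp to appear in that order within a single segment. If the segmenter moves the CastOp into a neighboring segment, neither segment matches, even though the fusion as a whole does. PATTERN and the helper names are hypothetical.

```python
# Toy model only: nvFuser's real matcher works on IR nodes, not op-name strings.
PATTERN = ("BroadcastOp", "MmaOp", "CastOp")

def segment_matches(segment, pattern=PATTERN):
    """True if `pattern` occurs as an ordered subsequence of the op names in `segment`."""
    ops = iter(segment)
    # `name in ops` advances the iterator, so each name must appear after the previous one
    return all(name in ops for name in pattern)

def any_segment_matches(segments):
    """True if at least one segment contains the whole pattern on its own."""
    return any(segment_matches(s) for s in segments)
```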


wujingyue commented on August 21, 2024

@cowanmeg Lightning-AI/lightning-thunder@bf84b04 checked in what's in the forward pass of a single-device transformer block, modulo SDPA. See the message of that commit for more details. With that, we should be able to work on this in parallel: I'll try to include SDPA and backprop, and you'll try to build a sharded version. How does that sound?


wujingyue commented on August 21, 2024

Note to myself: I'll start with the following benchmark:

$ pytest thunder/benchmarks/targets.py -k test_nanogpt_block_fwd[thunder] -s

which exercises one transformer layer in nanoGPT:
https://github.com/Lightning-AI/lightning-thunder/blob/cab020881765594fd9552d4deb8cc4e0f64410d2/thunder/tests/nanogpt_model.py#L132-L143
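If you'd rather drive the benchmark from a script, here is a sketch that assembles the same pytest invocation together with the nvFuser debug variables used in this thread (NVFUSER_DUMP=python_definition or segmenter_logging, NVFUSER_DISABLE=matmul_expr_eval). The helper name is hypothetical, and joining several options with commas assumes nvFuser parses these variables as comma-separated lists.

```python
import os

def benchmark_command(test_id="test_nanogpt_block_fwd[thunder]", dump=(), disable=()):
    """Build argv and env for the nanoGPT block benchmark (hypothetical helper)."""
    argv = ["pytest", "thunder/benchmarks/targets.py", "-k", test_id, "-s"]
    env = dict(os.environ)  # start from the caller's environment
    if dump:
        env["NVFUSER_DUMP"] = ",".join(dump)
    if disable:
        env["NVFUSER_DISABLE"] = ",".join(disable)
    return argv, env
```

Run it from the lightning-thunder checkout (rootdir /opt/pytorch/lightning-thunder in the log above) with subprocess.run(argv, env=env, check=True). Passing argv as a list also avoids relying on the shell to leave the [thunder] brackets in the -k expression unexpanded.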


wujingyue commented on August 21, 2024

cc @Priya2698

a.ndim == 2 is the first check to fail. Here's how to reproduce the problem:

With the following patch applied:

diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index c955da06..4767ab9c 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -2201,6 +2201,7 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
         return False

     enable_linear: None | bool = get_compile_option("nv_enable_linear", "Enable nvFuser matmul.")
+    enable_linear = True
     if not enable_linear:
         return False
     # Verify linear inputs and bias (optional) are supported tensors.
@@ -2210,6 +2211,7 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
         return False

     # nvFuser only supports 2D inputs in v0.2.3.
+    import pdb; pdb.set_trace()
     if not a.ndim == 2:
         return False
     return True

$ NVFUSER_DUMP=python_definition pytest thunder/benchmarks/targets.py -k test_nanogpt_block_fwd[thunder] -s
========================================================================================================================================================================================================================================= test session starts =========================================================================================================================================================================================================================================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.4.0
Test order randomisation NOT enabled. Enable with --random-order or --random-order-bucket=<bucket_type>
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /opt/pytorch/lightning-thunder
configfile: pyproject.toml
plugins: timestamper-0.0.10, xdist-3.5.0, random-order-1.1.1, cov-4.1.0, benchmark-4.0.0, hypothesis-6.100.0, timeout-2.2.0, shard-0.1.2
timeout: 900.0s
timeout method: signal
timeout func_only: False
collected 162 items / 161 deselected / 1 selected
Running 1 items in this shard

thunder/benchmarks/targets.py
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> PDB set_trace >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
> /opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py(2215)_linear_check()
-> if not a.ndim == 2:
(Pdb) p a.ndim
3
(Pdb)

Unsurprisingly, the printed Python definitions contain five fusions, none of which has a matmul or linear.
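The pdb session shows why: in the nanoGPT block every linear receives activations shaped [batch, seq, features] (here [16, 128, 1600]), so a 2-D-only check rejects all of them. The sketch below contrasts that check with a relaxed one that accepts any number of leading dims, a 2-D weight, and an optional 1-D bias. FakeProxy is a hypothetical stand-in for Thunder's TensorProxy; only shape and ndim are modeled.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class FakeProxy:
    """Hypothetical stand-in for a TensorProxy; only the shape is modeled."""
    shape: Tuple[int, ...]

    @property
    def ndim(self) -> int:
        return len(self.shape)

def check_2d_only(a: FakeProxy, b: FakeProxy, bias: Optional[FakeProxy]) -> bool:
    # The restriction in the patch context: nvFuser v0.2.3 only takes 2-D inputs.
    return a.ndim == 2

def check_relaxed(a: FakeProxy, b: FakeProxy, bias: Optional[FakeProxy]) -> bool:
    # Any leading dims on `a`, a 2-D weight, and a bias that is either absent or 1-D.
    return a.ndim >= 2 and b.ndim == 2 and (bias is None or bias.ndim == 1)
```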


wujingyue commented on August 21, 2024

Below is a workaround (WAR) for the Thunder check above, but it ran into an nvFuser issue.

diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index c955da06..137da102 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -4,6 +4,7 @@ from numbers import Number
 from typing import Union, List, Any, Optional, Dict, Set, Tuple, Type
 from types import NoneType
 from collections.abc import Callable, Mapping, Hashable, Sequence
+import math
 import os
 import time
 from copy import copy
@@ -796,7 +797,7 @@ instantiated) this heuristic actually leads to worse code.
             enable_bookend: None | bool = get_compile_option("nv_enable_bookend", bookend_help)
             # Set default value.
             if enable_bookend is None:
-                enable_bookend = True
+                enable_bookend = False
             assert isinstance(enable_bookend, bool)

             if enable_bookend:
@@ -2200,7 +2201,7 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
     if nv_version < LooseVersion("0.2.3"):
         return False

-    enable_linear: None | bool = get_compile_option("nv_enable_linear", "Enable nvFuser matmul.")
+    enable_linear = True
     if not enable_linear:
         return False
     # Verify linear inputs and bias (optional) are supported tensors.
@@ -2209,8 +2210,11 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
     if bias is not None and not is_supported_tensor(bias):
         return False

-    # nvFuser only supports 2D inputs in v0.2.3.
-    if not a.ndim == 2:
+    if a.ndim < 2:
+        return False
+    if b.ndim != 2:
+        return False
+    if bias is not None and bias.ndim != 1:
         return False
     return True

@@ -2226,7 +2230,10 @@ def linear(
     nva = getnv(a, fd, lc_to_nv_map)
     nvb = getnv(b, fd, lc_to_nv_map)
     nvbias = None if bias is None else getnv(bias, fd, lc_to_nv_map)
-    return fd.ops.linear(nva, nvb, nvbias)
+
+    nva_2d = fd.ops.reshape(nva, (math.prod(a.shape[:-1]), a.shape[-1]))
+    nvc_2d = fd.ops.linear(nva_2d, nvb, nvbias)
+    return fd.ops.reshape(nvc_2d, a.shape[:-1] + (b.shape[-2],))


 register_supported(PrimIDs.LINEAR, linear, _linear_check)
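The reshape trick in the linear translation above is exact because a linear layer only acts on the last dimension: flattening the leading dims into one, applying the 2-D linear, and restoring the leading dims gives the same result as a batched linear. Here is a pure-Python sketch of that round trip on a row-major flat buffer; linear_nd_via_2d is a hypothetical name, and its shape bookkeeping mirrors math.prod(a.shape[:-1]) and a.shape[:-1] + (b.shape[-2],) from the patch.

```python
import math

def linear_2d(rows, w, b):
    """rows: [M][K], w: [N][K] (out_features x in_features), b: [N] -> [M][N]."""
    return [[sum(x * wk for x, wk in zip(row, w_row)) + bn for w_row, bn in zip(w, b)]
            for row in rows]

def linear_nd_via_2d(x_flat, shape, w, b):
    """Linear over the last dim of a row-major tensor with `shape` (..., K).

    Flattens to [prod(shape[:-1]), K], runs the 2-D linear, then restores the
    leading dims with N = len(w) output features.
    """
    k = shape[-1]
    m = math.prod(shape[:-1])                   # like math.prod(a.shape[:-1])
    rows = [x_flat[i * k:(i + 1) * k] for i in range(m)]
    out_rows = linear_2d(rows, w, b)
    out_shape = tuple(shape[:-1]) + (len(w),)   # like a.shape[:-1] + (b.shape[-2],)
    return [v for row in out_rows for v in row], out_shape
```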
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T4 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.ops.cast(T4, dtype=DataType.Float)
    T6, T7 = fd.ops.var_mean(T5, dims=[2], correction=0, keepdim=False)
    S8 = fd.define_scalar(16, dtype=DataType.Int)
    S9 = fd.define_scalar(128, dtype=DataType.Int)
    S10 = fd.define_scalar(1, dtype=DataType.Int)
    V11 = fd.define_vector([S8, S9, S10], dtype=DataType.Int)
    T12 = fd.ops.broadcast_in_dim(T6, shape=V11, broadcast_dims=[0, 1])
    S13 = fd.define_scalar(16, dtype=DataType.Int)
    S14 = fd.define_scalar(128, dtype=DataType.Int)
    S15 = fd.define_scalar(1, dtype=DataType.Int)
    V16 = fd.define_vector([S13, S14, S15], dtype=DataType.Int)
    T17 = fd.ops.broadcast_in_dim(T7, shape=V16, broadcast_dims=[0, 1])
    S18 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T19 = fd.ops.add(T12, S18)
    T20 = fd.ops.rsqrt(T19)
    S21 = fd.define_scalar(16, dtype=DataType.Int)
    S22 = fd.define_scalar(128, dtype=DataType.Int)
    S23 = fd.define_scalar(1600, dtype=DataType.Int)
    V24 = fd.define_vector([S21, S22, S23], dtype=DataType.Int)
    T25 = fd.ops.broadcast_in_dim(T17, shape=V24, broadcast_dims=[0, 1, 2])
    T26 = fd.ops.sub(T5, T25)
    S27 = fd.define_scalar(16, dtype=DataType.Int)
    S28 = fd.define_scalar(128, dtype=DataType.Int)
    S29 = fd.define_scalar(1600, dtype=DataType.Int)
    V30 = fd.define_vector([S27, S28, S29], dtype=DataType.Int)
    T31 = fd.ops.broadcast_in_dim(T20, shape=V30, broadcast_dims=[0, 1, 2])
    T32 = fd.ops.mul(T26, T31)
    S33 = fd.define_scalar(16, dtype=DataType.Int)
    S34 = fd.define_scalar(128, dtype=DataType.Int)
    S35 = fd.define_scalar(1600, dtype=DataType.Int)
    V36 = fd.define_vector([S33, S34, S35], dtype=DataType.Int)
    T37 = fd.ops.broadcast_in_dim(T3, shape=V36, broadcast_dims=[2])
    T38 = fd.ops.cast(T37, dtype=DataType.Float)
    T39 = fd.ops.mul(T32, T38)
    S40 = fd.define_scalar(16, dtype=DataType.Int)
    S41 = fd.define_scalar(128, dtype=DataType.Int)
    S42 = fd.define_scalar(1600, dtype=DataType.Int)
    V43 = fd.define_vector([S40, S41, S42], dtype=DataType.Int)
    T44 = fd.ops.broadcast_in_dim(T2, shape=V43, broadcast_dims=[2])
    T45 = fd.ops.cast(T44, dtype=DataType.Float)
    T46 = fd.ops.add(T39, T45)
    T47 = fd.ops.cast(T46, dtype=DataType.BFloat16)
    S48 = fd.define_scalar(2048, dtype=DataType.Int)
    S49 = fd.define_scalar(1600, dtype=DataType.Int)
    V50 = fd.define_vector([S48, S49], dtype=DataType.Int)
    T51 = fd.ops.reshape(T47, new_shape=V50)
    T52 = fd.ops.linear(T51, T1, T0)
    fd.add_output(T52)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn((4800,), dtype=torch.bfloat16, device='cuda:0').as_strided((4800,), (1,)),
    torch.randn((7680000,), dtype=torch.bfloat16, device='cuda:0').as_strided((4800, 1600), (1600, 1)),
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((3276800,), dtype=torch.bfloat16, device='cuda:0').as_strided((16, 128, 1600), (204800, 1600, 1)),
]
fd.execute(inputs)

Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 146, in execute
    result = self._execute(
RuntimeError: h.has_value() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/fusion_segmenter.cpp":3671, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Can not find a scheduler to schedule fusion segment
Exception raised from deriveHeuristic at /opt/pytorch/nvfuser/csrc/fusion_segmenter.cpp:3671 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7fbf362d8381 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x53 (0x7fbf365d51b3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: <unknown function> + 0x4bde42 (0x7fbf36675e42 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x4c5032 (0x7fbf3667d032 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0x4d0c42 (0x7fbf36688c42 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: nvfuser::SegmentCandidateFinder::SegmentCandidateFinder(std::unique_ptr<nvfuser::Fusion, std::default_delete<nvfuser::Fusion> >, nvfuser::KernelArgumentHolder const*, nvfuser::SegmentCandidateFinderOptions) + 0x46f (0x7fbf366897ff in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x4a8082 (0x7fbf36660082 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x4d1a0e (0x7fbf36689a0e in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: nvfuser::FusionKernelRuntime::FusionKernelRuntime(std::unique_ptr<nvfuser::Fusion, std::default_delete<nvfuser::Fusion> >, nvfuser::KernelArgumentHolder const&, nvfuser::serde::FusionKernelRuntime const*, std::optional<nvfuser::PrimDataType>, long, long, long, bool) + 0x373 (0x7fbf36799ed3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: <unknown function> + 0x5e5b57 (0x7fbf3679db57 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0x1e7 (0x7fbf3679e8b7 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, bool, bool, std::optional<signed char>) const + 0x3c8 (0x7fbf36981998 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0x18ca25 (0x7fbf36344a25 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #13: <unknown function> + 0x2009c2 (0x7fbf363b89c2 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #14: <unknown function> + 0x288d00 (0x7fbf36440d00 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #15: <unknown function> + 0x15a10e (0x555dc548510e in /usr/bin/python3)
frame #16: _PyObject_MakeTpCall + 0x25b (0x555dc547ba7b in /usr/bin/python3)
frame #17: <unknown function> + 0x168acb (0x555dc5493acb in /usr/bin/python3)
frame #18: _PyEval_EvalFrameDefault + 0x198c (0x555dc546f53c in /usr/bin/python3)
frame #19: <unknown function> + 0x16893e (0x555dc549393e in /usr/bin/python3)
frame #20: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #21: _PyObject_FastCallDictTstate + 0xc4 (0x555dc547ac14 in /usr/bin/python3)
frame #22: _PyObject_Call_Prepend + 0xc1 (0x555dc54908d1 in /usr/bin/python3)
frame #23: <unknown function> + 0x280700 (0x555dc55ab700 in /usr/bin/python3)
frame #24: _PyObject_MakeTpCall + 0x25b (0x555dc547ba7b in /usr/bin/python3)
frame #25: _PyEval_EvalFrameDefault + 0x64e6 (0x555dc5474096 in /usr/bin/python3)
frame #26: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #27: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #28: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #29: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #30: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #31: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #32: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #33: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #34: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #35: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #36: <unknown function> + 0x16893e (0x555dc549393e in /usr/bin/python3)
frame #37: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #38: <unknown function> + 0x16893e (0x555dc549393e in /usr/bin/python3)
frame #39: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #40: _PyObject_FastCallDictTstate + 0xc4 (0x555dc547ac14 in /usr/bin/python3)
frame #41: _PyObject_Call_Prepend + 0x5c (0x555dc549086c in /usr/bin/python3)
frame #42: <unknown function> + 0x280700 (0x555dc55ab700 in /usr/bin/python3)
frame #43: PyObject_Call + 0xbb (0x555dc549442b in /usr/bin/python3)
frame #44: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #45: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #46: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #47: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #48: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #49: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #50: _PyEval_EvalFrameDefault + 0x6bd (0x555dc546e26d in /usr/bin/python3)
frame #51: <unknown function> + 0x1687f1 (0x555dc54937f1 in /usr/bin/python3)
frame #52: _PyEval_EvalFrameDefault + 0x198c (0x555dc546f53c in /usr/bin/python3)
frame #53: <unknown function> + 0x1687f1 (0x555dc54937f1 in /usr/bin/python3)
frame #54: _PyEval_EvalFrameDefault + 0x198c (0x555dc546f53c in /usr/bin/python3)
frame #55: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #56: PyObject_Call + 0x122 (0x555dc5494492 in /usr/bin/python3)
frame #57: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #58: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #59: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #60: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #61: _PyEval_EvalFrameDefault + 0x614a (0x555dc5473cfa in /usr/bin/python3)
frame #62: <unknown function> + 0x1687f1 (0x555dc54937f1 in /usr/bin/python3)
frame #63: _PyEval_EvalFrameDefault + 0x614a (0x555dc5473cfa in /usr/bin/python3)


wujingyue commented on August 21, 2024

FYI, NVFUSER_DUMP=segmenter_logging prints the following:

**Segmenter** Considering fusion:
T34_g[ iS97{( 16 * 128 )}, iS138{i0}, rS99{1600} ]
   = mma(T32_g[ iS91{( 16 * 128 )}, bS92{1}, iS93{1600} ],
         T33_g[ bS94{1}, iS136{i0}, iS137{1600} ])

Scheduler _no_op_ ***rejected*** because : output has a concrete dimension
Scheduler _matmul_ ***rejected*** because : MmaOp input has unsupported dependency
Scheduler _reduction_ ***rejected*** because : No reduction op to schedule
Scheduler _transpose_ ***rejected*** because : no support for mma ops.
Scheduler _pointwise_ ***rejected*** because : no support for mma ops.
Scheduler _inner_persistent_ ***rejected*** because : needs a reduction op
Scheduler _outer_persistent_ ***rejected*** because : needs a reduction op
Scheduler _inner_outer_persistent_ ***rejected*** because : needs a reduction op


wujingyue commented on August 21, 2024

The matmul scheduler failed at

const auto areMmaOpInputDependeciesValid = [](const Val* val) {

It looks like it assumes both operands are broadcast. I was under the impression that we removed that assumption for #1628. What am I missing? @zasdfgbnm
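A minimal sketch of the assumption being described, using a hypothetical toy IR (a dict from tensor name to the kind of op that defines it, with `None` for a segment input) rather than nvFuser's real C++ API. In the segmenter log above, the mma operands appear as `T32_g` and `T33_g`, which suggests they are inputs of that singleton segment; a check that requires a broadcast producer inside the segment would then reject it, even though the whole fusion does contain those broadcasts.

```python
# Hypothetical toy IR: each tensor name maps to the kind of op that defines it
# inside the segment being scheduled, or None if the tensor is a segment input.
from typing import Dict, List, Optional


def mma_inputs_valid(producers: Dict[str, Optional[str]], mma_operands: List[str]) -> bool:
    """Accept the MmaOp only if every operand is produced by a broadcast in this segment."""
    return all(producers.get(name) == "broadcast" for name in mma_operands)


# Whole fusion (cf. the preseg IR dump in this thread): T32 and T33 come from broadcasts.
whole_fusion = {"T31": "view", "T32": "broadcast", "T33": "broadcast", "T34": "mma"}

# The singleton segment from the log: T32_g and T33_g are inputs with no producer.
mma_segment = {"T32": None, "T33": None, "T34": "mma"}

print(mma_inputs_valid(whole_fusion, ["T32", "T33"]))  # True
print(mma_inputs_valid(mma_segment, ["T32", "T33"]))   # False
```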


wujingyue commented on August 21, 2024

FYI, below is the complete fusion after preseg optimizations. The MmaOp is indeed part of the beautiful broadcast+broadcast+mma+add+float2bfloat subgraph, which is good. However, because of other ops in the fusion, this subgraph is not handed to the matmul scheduler directly. Instead, it's decomposed into singletons, and the segmenter has trouble merging them back into the expected subgraph.

$ NVFUSER_DUMP=fusion_ir_preseg python repro.py 
Fusion IR after pre-segmenter optimization passes:
Inputs:
  T0_g[ iS0{i0} ], __bfloat
  T1_g[ iS134{i0}, iS135{1600} ], __bfloat
  T2_g[ iS132{1600} ], __bfloat
  T3_g[ iS130{1600} ], __bfloat
  T4_g[ iS107{16}, iS108{128}, iS109{1600} ], __bfloat
Outputs:
  T38_g[ iS105{2048}, iS140{i0} ], __bfloat

%kernel_math {
T5_l[ iS110{16}, iS111{128}, iS112{1600} ]
   = __bfloat2float(T4_g[ iS107{16}, iS108{128}, iS109{1600} ]);
T6_l[ iS116{16}, iS117{128}, rS118{1600} ](Avg),
T7_l[ iS122{16}, iS123{128}, rS124{1600} ](Var),
T8_l[ iS113{16}, iS114{128}, rS115{1600} ](Count)
 = Welford ( T5_l[ iS110{16}, iS111{128}, iS112{1600} ](Avg), 
  allreduce = false )
T12_l[ iS119{16}, iS120{128}, bS30{1} ]
   = broadcast( T6_l[ iS116{16}, iS117{128}, rS118{1600} ] )
T13_l[ iS31{16}, iS32{128}, bS33{1} ]
   = Set( T12_l[ iS119{16}, iS120{128}, bS30{1} ], cache_op=Streaming )
T16_l[ iS40{16}, iS41{128}, bS42{1} ]
   = Set( T13_l[ iS31{16}, iS32{128}, bS33{1} ], cache_op=Streaming )
T17_l[ iS43{16}, iS44{128}, bS45{1 ex 1600} ] = expand( T16_l[ iS40{16}, iS41{128}, bS42{1} ], {16, 128, 1600} )
T18_l[ iS46{16}, iS47{128}, iS121{1600} ]
   = T5_l[ iS110{16}, iS111{128}, iS112{1600} ]
   - T17_l[ iS43{16}, iS44{128}, bS45{1 ex 1600} ];
d17 = (double)(1600);
d19 = double(1) * d17;
d23 = (double)(0);
d25 = d19 - d23;
d27 = (double)(0);
b29 = d25 >= d27;
d31 = (double)(0);
d33 = where(b29, d25, d31);
d39 = reciprocal(d33);
T9_l[ iS125{16}, iS126{128} ]
   = T7_l[ iS122{16}, iS123{128}, rS124{1600} ]
   * d39;
T10_l[ iS127{16}, iS128{128}, bS24{1} ]
   = broadcast( T9_l[ iS125{16}, iS126{128} ] )
T11_l[ iS25{16}, iS26{128}, bS27{1} ]
   = Set( T10_l[ iS127{16}, iS128{128}, bS24{1} ], cache_op=Streaming )
T14_l[ iS34{16}, iS35{128}, bS36{1} ]
   = T11_l[ iS25{16}, iS26{128}, bS27{1} ]
   + double(1.0000000000000001e-05);
T15_l[ iS37{16}, iS38{128}, bS39{1} ]
   = rsqrtf(T14_l[ iS34{16}, iS35{128}, bS36{1} ]);
T19_l[ iS49{16}, iS50{128}, bS51{1} ]
   = Set( T15_l[ iS37{16}, iS38{128}, bS39{1} ], cache_op=Streaming )
T20_l[ iS52{16}, iS53{128}, bS54{1 ex 1600} ] = expand( T19_l[ iS49{16}, iS50{128}, bS51{1} ], {16, 128, 1600} )
T21_l[ iS55{16}, iS56{128}, iS129{1600} ]
   = T18_l[ iS46{16}, iS47{128}, iS121{1600} ]
   * T20_l[ iS52{16}, iS53{128}, bS54{1 ex 1600} ];
T22_l[ bS58{1}, bS59{1}, iS131{1600} ]
   = broadcast( T3_g[ iS130{1600} ] )
T23_l[ bS61{1 ex 16}, bS62{1 ex 128}, iS63{1600} ] = expand( T22_l[ bS58{1}, bS59{1}, iS131{1600} ], {16, 128, 1600} )
T24_l[ bS64{1 ex 16}, bS65{1 ex 128}, iS66{1600} ]
   = __bfloat2float(T23_l[ bS61{1 ex 16}, bS62{1 ex 128}, iS63{1600} ]);
T25_l[ iS67{16}, iS68{128}, iS69{1600} ]
   = T21_l[ iS55{16}, iS56{128}, iS129{1600} ]
   * T24_l[ bS64{1 ex 16}, bS65{1 ex 128}, iS66{1600} ];
T26_l[ bS70{1}, bS71{1}, iS133{1600} ]
   = broadcast( T2_g[ iS132{1600} ] )
T27_l[ bS73{1 ex 16}, bS74{1 ex 128}, iS75{1600} ] = expand( T26_l[ bS70{1}, bS71{1}, iS133{1600} ], {16, 128, 1600} )
T28_l[ bS76{1 ex 16}, bS77{1 ex 128}, iS78{1600} ]
   = __bfloat2float(T27_l[ bS73{1 ex 16}, bS74{1 ex 128}, iS75{1600} ]);
T29_l[ iS79{16}, iS80{128}, iS81{1600} ]
   = T25_l[ iS67{16}, iS68{128}, iS69{1600} ]
   + T28_l[ bS76{1 ex 16}, bS77{1 ex 128}, iS78{1600} ];
T30_l[ iS82{16}, iS83{128}, iS84{1600} ]
   = __float2bfloat(T29_l[ iS79{16}, iS80{128}, iS81{1600} ]);
T31_l[ iS90{( 16 * 128 )}rf, iS87{1600} ] = view( T30_l[ iS82{16}, iS83{128}, iS84{1600} ] )
T32_l[ iS91{( 16 * 128 )}, bS92{1}, iS93{1600} ]
   = broadcast( T31_l[ iS90{( 16 * 128 )}rf, iS87{1600} ] )
T33_l[ bS94{1}, iS136{i0}, iS137{1600} ]
   = broadcast( T1_g[ iS134{i0}, iS135{1600} ] )
T34_l[ iS97{( 16 * 128 )}, iS138{i0}, rS99{1600} ]
   = mma(T32_l[ iS91{( 16 * 128 )}, bS92{1}, iS93{1600} ],
         T33_l[ bS94{1}, iS136{i0}, iS137{1600} ])
T35_l[ iS100{i0} ]
   = __bfloat2float(T0_g[ iS0{i0} ]);
T36_l[ bS101{1}, iS102{i0} ]
   = broadcast( T35_l[ iS100{i0} ] )
T37_l[ iS103{2048}, iS139{i0} ]
   = T34_l[ iS97{( 16 * 128 )}, iS138{i0}, rS99{1600} ]
   + T36_l[ bS101{1}, iS102{i0} ];
T38_g[ iS105{2048}, iS140{i0} ]
   = __float2bfloat(T37_l[ iS103{2048}, iS139{i0} ]);
}
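For reference, a plain-Python sketch of what this fusion appears to compute: a LayerNorm over the last dimension (Welford mean and var_sum, biased variance, eps = 1e-05, rsqrt, scale by T3, shift by T2) followed by a linear layer expressed as broadcast + broadcast + mma + bias add. The mapping of T-names to steps is a reading of the dump above, not something nvFuser reports; the bf16 casts and the [16, 128, 1600] -> [(16 * 128), 1600] view are left out, so the caller is assumed to pass already-flattened rows. All function names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def layer_norm(x: Matrix, weight: List[float], bias: List[float], eps: float = 1e-5) -> Matrix:
    """LayerNorm over the last dimension, one row at a time (float math, no bf16 casts)."""
    out = []
    for row in x:
        n = len(row)
        mean = sum(row) / n                            # Welford Avg (T6)
        var_sum = sum((v - mean) ** 2 for v in row)    # Welford Var output, i.e. var_sum (T7)
        var = var_sum * (1.0 / n)                      # T9: var_sum * 1/(n - 0), correction = 0
        rstd = 1.0 / math.sqrt(var + eps)              # T14 (+ 1e-05) and T15 (rsqrt)
        out.append([(v - mean) * rstd * w + b          # T18, T21, then * T3 (T25) and + T2 (T29)
                    for v, w, b in zip(row, weight, bias)])
    return out


def linear_via_broadcast_mma(a: Matrix, w: Matrix, bias: List[float]) -> Matrix:
    """Linear layer written as the broadcast + broadcast + mma + add pattern.

    a: [M, K] (T31), w: [N, K] (T1), bias: [N] (T0).
    Conceptually a becomes [M, 1, K] (T32) and w becomes [1, N, K] (T33); the mma
    multiplies them elementwise and reduces over K (T34); then the bias, broadcast
    to [1, N] (T36), is added (T37).
    """
    m_size, n_size, k_size = len(a), len(w), len(a[0])
    return [[sum(a[m][k] * w[n][k] for k in range(k_size)) + bias[n]
             for n in range(n_size)]
            for m in range(m_size)]


def fusion_reference(x_flat: Matrix, ln_weight: List[float], ln_bias: List[float],
                     w: Matrix, b: List[float]) -> Matrix:
    # x_flat is assumed to be the [16 * 128, 1600] view of the [16, 128, 1600] input (T4).
    return linear_via_broadcast_mma(layer_norm(x_flat, ln_weight, ln_bias), w, b)


print(linear_via_broadcast_mma([[1.0, 2.0]], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.5, 0.0, -1.0]))
# [[1.5, 2.0, 2.0]]
```

Writing the mma as "broadcast both operands, multiply, reduce over K" is what makes it look like an ordinary reduction to the segmenter, which is why the broadcasts and the mma can end up split apart.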


Priya2698 commented on August 21, 2024

This issue looks related to #2127.
The failure stems from assuming that the inputs are created by a BroadcastOp.

@wujingyue What do you get after #2221?

While the ATen evaluation for matmul/linear will drop these assumptions once the new IR nodes are merged, at present, we assume the same in pattern matching as well (BroadcastOp -> MmaOp-> CastOp)

from fuser.

Priya2698 commented on August 21, 2024

at present, we assume the same in pattern matching as well (BroadcastOp -> MmaOp-> CastOp)

That's right. I already merged #2221, so you can reproduce this by running the reproducer in #2199 (comment). Anyhow, I'll run my experiments with matmul_expr_eval disabled, so #2221 is sufficient to unblock me for now.

That being said, there's a variation of the same problem for the ATen evaluation: the segmenter doesn't guarantee that the cast lands in the same segment as the mma, just as it didn't guarantee that for the broadcasts. I think the new IR nodes will help with that, but I'm not sure, so I'll leave that to you.

Yes, the new IR nodes will fix this issue since we won't evaluate a decomposed IR. The pattern matching will become redundant and will be removed once the API is modified to use the new IR nodes.
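To make the segmentation concern concrete, here is a toy matcher for the BroadcastOp -> MmaOp -> CastOp pattern mentioned above. It is a hypothetical sketch over a made-up `Op` record, not nvFuser's matcher, and it ignores the bias add that sits between the mma and the cast in the fusion from this thread. What it shows: a matcher that only looks inside one segment finds nothing once the segmenter places the cast (or the broadcasts) in a different segment.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass(frozen=True)
class Op:
    kind: str                 # "broadcast", "mma", "cast", ...
    inputs: Tuple[str, ...]   # names of input tensors
    output: str               # name of the output tensor


def _defined_by(producers: Dict[str, Op], name: str, kind: str) -> bool:
    op = producers.get(name)
    return op is not None and op.kind == kind


def match_broadcast_mma_cast(segment: List[Op]) -> Optional[str]:
    """Return the cast output of a broadcast -> mma -> cast chain lying entirely
    inside `segment`, or None if any piece of the chain is in another segment."""
    producers = {op.output: op for op in segment}
    for op in segment:
        if op.kind != "cast" or not _defined_by(producers, op.inputs[0], "mma"):
            continue
        mma = producers[op.inputs[0]]
        if all(_defined_by(producers, t, "broadcast") for t in mma.inputs):
            return op.output
    return None


chain = [
    Op("broadcast", ("A",), "A_b"),
    Op("broadcast", ("B",), "B_b"),
    Op("mma", ("A_b", "B_b"), "C"),
    Op("cast", ("C",), "D"),
]
print(match_broadcast_mma_cast(chain))       # D
print(match_broadcast_mma_cast(chain[:3]))   # None: the cast landed in another segment
print(match_broadcast_mma_cast(chain[2:]))   # None: the broadcasts landed in another segment
```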


wujingyue commented on August 21, 2024

@Priya2698 Wdyt about the drafted WAR for nvFuser-not-support-3D-linear? Submit that WAR in Thunder or wait for your PRs?


Priya2698 commented on August 21, 2024

@Priya2698 Wdyt about the drafted WAR for nvFuser-not-support-3D-linear? Submit that WAR in Thunder or wait for your PRs?

It looks like the WAR will still run into the segmentation issue due to the reshapes.

If you don't necessarily need that change in Thunder to proceed, adding the new nodes will lift that restriction anyway. I estimate the new PRs will be out in a couple of days, early next week.

We can go ahead with it if it unblocks you in the interim.
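The drafted WAR itself isn't shown in this thread, so this is only a hypothetical sketch of the reshape approach implied above: flatten a 3D input [B, S, K] to 2D [B * S, K], run a 2D linear, and reshape the result back to [B, S, N]. The view in the preseg IR ([16, 128, 1600] -> [(16 * 128), 1600]) has the same shape. Function names are made up for illustration.

```python
from typing import List, Optional

Matrix = List[List[float]]


def linear_2d(a: Matrix, w: Matrix, bias: Optional[List[float]] = None) -> Matrix:
    """a: [M, K], w: [N, K] -> [M, N]; same convention as torch.nn.functional.linear."""
    return [[sum(x * y for x, y in zip(row, w_row)) + (bias[n] if bias is not None else 0.0)
             for n, w_row in enumerate(w)]
            for row in a]


def linear_3d_via_reshape(x: List[Matrix], w: Matrix,
                          bias: Optional[List[float]] = None) -> List[Matrix]:
    """x: [B, S, K] -> [B, S, N] by flattening to 2D, calling the 2D linear, and reshaping back."""
    seq_len = len(x[0])
    flat = [row for batch in x for row in batch]                    # reshape [B, S, K] -> [B * S, K]
    y = linear_2d(flat, w, bias)                                    # [B * S, N]
    return [y[i:i + seq_len] for i in range(0, len(y), seq_len)]    # reshape back to [B, S, N]
```

The flatten and un-flatten steps are the reshapes that, per the reply above, may still trigger the segmentation issue.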


wujingyue commented on August 21, 2024

Cool -- I closed Lightning-AI/lightning-thunder#391.


cowanmeg commented on August 21, 2024

Thanks @wujingyue! This is super helpful, I'll start working on the sharding soon!


cowanmeg commented on August 21, 2024

I annotated the sharding of the MLP layer of the example: https://gist.github.com/cowanmeg/75b4144a3627df74efcfc12dda01a2a3

Some comments:
(1) The two linear layers and GeLU have sharded computation. The dropout, layernorm, and residual add are computed redundantly (replicated) on each device. (BTW, I don't think it would be too hard to represent SP.)
Sharding propagation is relatively straightforward if we annotate only the Linear layer inputs and outputs. I think the current naive one will suffice for at least now.
(2) Now that LinearOp and MatmulOp are part of the compute definition, we should reconsider how we insert resharding expressions and handle DID leaf parallelization. (cc @Priya2698 @jacobhinkle)
(3) The pointwise scheduler needs to be modified to ignore DID axes. This should be as straightforward as reordering the DID axes to the front and ignoring them.
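A toy sketch of item (3), assuming loop axes are modeled as hypothetical (name, extent, parallel type) tuples; this is not nvFuser's TensorView API, and the real pointwise heuristics are far more involved. DID axes are moved to the front and left alone, and only the remaining axes are merged and split into blocks and threads.

```python
import math
from typing import List, Tuple

Axis = Tuple[str, int, str]  # (name, extent, parallel type): hypothetical toy representation


def reorder_did_to_front(loop_domain: List[Axis]) -> Tuple[List[Axis], List[Axis]]:
    """Stable partition: DIDx axes first; the other axes keep their relative order."""
    did = [ax for ax in loop_domain if ax[2] == "DIDx"]
    rest = [ax for ax in loop_domain if ax[2] != "DIDx"]
    return did, rest


def toy_pointwise_schedule(loop_domain: List[Axis], threads_per_block: int = 128) -> List[Axis]:
    """Leave DID axes untouched at the front; merge the rest and split it into blocks x threads."""
    did, rest = reorder_did_to_front(loop_domain)
    numel = math.prod(extent for _, extent, _ in rest)   # merge all non-DID axes
    blocks = -(-numel // threads_per_block)               # ceiling division
    return did + [("blocks", blocks, "BIDx"), ("threads", threads_per_block, "TIDx")]


# A device-sharded tensor whose DID axis sits in the middle before reordering.
domain = [("iS0", 16, "serial"), ("iDID", 2, "DIDx"), ("iS1", 800, "serial")]
print(toy_pointwise_schedule(domain))
# [('iDID', 2, 'DIDx'), ('blocks', 100, 'BIDx'), ('threads', 128, 'TIDx')]
```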

While we discuss our design for (2), I will manually translate these programs and decompose the LinearOp myself. Regardless, this is necessary, since we need to logically split sharded axes in the compute definition because of our RFactor restriction. For the MLP this isn't too hard, and it would let us get a small example working.
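A plain-Python sketch of the sharding described in (1), assuming a Megatron-style layout: the first linear is sharded along its output (hidden) features, GeLU runs locally, the second linear is sharded along its input features, and the partial results are summed across devices (the allreduce). The hidden axis is logically split into [num_devices, hidden / num_devices], the kind of split mentioned just above. Dropout, layernorm, and the residual add would then run replicated on the summed result and are omitted. Whether the gist uses exactly this layout is an assumption, and all names are hypothetical.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def gelu(v: float) -> float:
    # Exact (erf-based) GeLU; any elementwise activation shards the same way.
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def linear(x: Matrix, w: Matrix, b: Optional[List[float]] = None) -> Matrix:
    # x: [M, K], w: [N, K] -> [M, N]
    return [[sum(xi * wi for xi, wi in zip(row, w_row)) + (b[n] if b is not None else 0.0)
             for n, w_row in enumerate(w)]
            for row in x]


def mlp_reference(x, w1, b1, w2, b2) -> Matrix:
    hidden = [[gelu(v) for v in row] for row in linear(x, w1, b1)]
    return linear(hidden, w2, b2)


def mlp_sharded(x, w1, b1, w2, b2, num_devices: int) -> Matrix:
    hidden_size = len(w1)
    assert hidden_size % num_devices == 0
    chunk = hidden_size // num_devices    # logical split: hidden -> [num_devices, chunk]
    partials = []
    for d in range(num_devices):          # each iteration plays the role of one device
        shard = slice(d * chunk, (d + 1) * chunk)
        w1_d, b1_d = w1[shard], b1[shard]                  # first linear: shard output features
        w2_d = [w_row[shard] for w_row in w2]              # second linear: shard input features
        h_d = [[gelu(v) for v in row] for row in linear(x, w1_d, b1_d)]  # local, no communication
        partials.append(linear(h_d, w2_d))                 # partial sum held by device d
    # Allreduce (sum) of the partial results, then the replicated bias.
    return [[sum(p[m][n] for p in partials) + b2[n] for n in range(len(w2))]
            for m in range(len(x))]


x = [[0.5, -1.0]]
w1 = [[0.1, 0.2], [0.3, -0.4], [-0.5, 0.6], [0.7, 0.8]]   # [hidden=4, embed=2]
b1 = [0.0, 0.1, -0.1, 0.2]
w2 = [[1.0, -1.0, 0.5, 0.25], [0.0, 2.0, -0.5, 1.0]]      # [embed=2, hidden=4]
b2 = [0.01, -0.02]
ref = mlp_reference(x, w1, b1, w2, b2)
for devices in (1, 2, 4):
    out = mlp_sharded(x, w1, b1, w2, b2, devices)
    print(devices, max(abs(a - r) for a, r in zip(out[0], ref[0])))
```

Only the sum after the second linear needs communication, which is why annotating just the Linear inputs and outputs is enough for a simple propagation pass.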

