
Comments (8)

zasdfgbnm commented on August 21, 2024

Thanks @jjsjann123 for providing this repro. I do believe there are some opportunities for symbolic simplification here:

For example, the predicate of T3 = pad(T2) looks like:

Static shape:

-8 + (4 * (threadIdx.x % 4)) < -i0

Dynamic shape:

  ((((4 * threadIdx.x) + (512 * blockIdx.x)) + i0) %
   ((8 * T1.logical_size[1LL]) +
    (T1.logical_size[1LL] * T2.logical_size[2LL]))) %
          (8 + T2.logical_size[2LL]) <
      T2.logical_size[2LL]

For the dynamic shape case, let:

a = ((4 * threadIdx.x) + (512 * blockIdx.x)) + i0;
b = T1.logical_size[1LL];
c = 8 + T2.logical_size[2LL];

then the predicate is:

a % (b * c) % c < T2.logical_size[2LL]

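Because b * c is a multiple of c, taking the remainder modulo b * c first does not change the remainder modulo c, so the predicate can be simplified to:

a % c < T2.logical_size[2LL]

that is:

(((4 * threadIdx.x) + (512 * blockIdx.x)) + i0) % (8 + T2.logical_size[2LL]) < T2.logical_size[2LL]

Clearly not as good as the static shape case, but still an improvement.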
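As a sanity check on the simplification, here is a minimal brute-force sketch in plain C++. The function names and the small search bounds are made up for illustration; `size` stands in for T2.logical_size[2LL] and `b` for T1.logical_size[1LL], and non-negative `a` with positive `b` and `c` is assumed.

```cpp
#include <cassert>
#include <cstdint>

// Brute-force check of the identity used above: for a >= 0 and b, c > 0,
// (a % (b * c)) % c == a % c, because b * c is a multiple of c.
// The bounds are arbitrary small values chosen for illustration.
bool nestedModMatchesSingleMod(int64_t max_a, int64_t max_b, int64_t max_c) {
  for (int64_t b = 1; b <= max_b; ++b) {
    for (int64_t c = 1; c <= max_c; ++c) {
      for (int64_t a = 0; a <= max_a; ++a) {
        if ((a % (b * c)) % c != a % c) {
          return false;
        }
      }
    }
  }
  return true;
}

// Original form of the pad predicate: a % (b * c) % c < size, with c = 8 + size.
bool predicateOriginal(int64_t a, int64_t b, int64_t size) {
  return (a % ((8 * b) + (b * size))) % (8 + size) < size;
}

// Simplified form: a % c < size.
bool predicateSimplified(int64_t a, int64_t size) {
  return a % (8 + size) < size;
}
```

This only tests a finite range, so it is evidence rather than a proof; a simplification rule in the expression simplifier would still need to establish that b and c are positive.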

Kernel diff: https://www.diffchecker.com/vK2pS9ak/

from fuser.

jjsjann123 commented on August 21, 2024

cc'ing @zasdfgbnm, I don't think there's any actionable item on your side at the moment. I'll update this after I've checked the performance with the new definition.


jacobhinkle commented on August 21, 2024

Here is a diff of the generated pointwise kernels on my 3090Ti:

--- static.cu   2024-08-19 10:23:21.977784983 -0400
+++ dynamic.cu  2024-08-19 10:23:58.741144923 -0400
@@ -10697,68 +10697,96 @@
 }
 
 } // namespace fused_reduction
-__global__ void nvfuser_pointwise_f0_c1_r0_g6(Tensor<__bfloat, 4, 4> T10, Tensor<__bfloat, 4, 4> T8, Tensor<__bfloat, 2, 2> T22, Tensor<__bfloat, 2, 2> T14, Tensor<__bfloat, 2, 2> T18, Tensor<__bfloat, 4, 4> T12, Tensor<__bfloat, 2, 2> T26, Tensor<__bfloat, 4, 4> T4, Tensor<__bfloat, 4, 4> T6, Tensor<__bfloat, 4, 4> T48) {
+__global__ void nvfuser_pointwise_f0_c1_r0_g6(Tensor<__bfloat, 4, 4> T10, Tensor<__bfloat, 4, 4> T8, Tensor<__bfloat, 2, 2> T22, Tensor<__bfloat, 2, 2> T14, Tensor<__bfloat, 2, 2> T18, Tensor<__bfloat, 4, 4> T12, Tensor<__bfloat, 2, 2> T26, Tensor<__bfloat, 4, 4> T4, Tensor<__bfloat, 4, 4> T6, nvfuser_index_t i0, nvfuser_index_t i1, nvfuser_index_t i2, Tensor<__bfloat, 4, 4> T48) {
   NVFUSER_DEFINE_MAGIC_ZERO;
-  nvfuser_index_t i0;
-  i0 = ((nvfuser_index_t)threadIdx.x) + (((nvfuser_index_t)blockDim.x) * ((nvfuser_index_t)blockIdx.y));
-  nvfuser_index_t i1;
-  i1 = 8 * (i0 % 4);
-  nvfuser_index_t i2;
-  i2 = i0 / 4;
   nvfuser_index_t i3;
-  i3 = -4 + i1;
+  i3 = 8 * ((nvfuser_index_t)threadIdx.x);
   nvfuser_index_t i4;
-  i4 = i3 + (T26.alloc_stride[0LL] * i2);
+  i4 = (8 * ((nvfuser_index_t)blockDim.x)) * ((nvfuser_index_t)blockIdx.y);
   nvfuser_index_t i5;
-  i5 = ((-4 + ((1024 * T6.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x))) + i1) + (T6.alloc_stride[2LL] * i2);
+  i5 = i3 + i4;
   nvfuser_index_t i6;
-  i6 = i1 + (T14.alloc_stride[0LL] * i2);
+  i6 = 28 + T26.logical_size[1LL];
   nvfuser_index_t i7;
-  i7 = (((1024 * T10.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x)) + i1) + (T10.alloc_stride[2LL] * i2);
+  i7 = ((nvfuser_index_t)blockIdx.x) / T10.logical_size[1LL];
   nvfuser_index_t i8;
-  i8 = (((1024 * T8.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x)) + i1) + (T8.alloc_stride[2LL] * i2);
+  i8 = ((nvfuser_index_t)blockIdx.x) % T10.logical_size[1LL];
   nvfuser_index_t i9;
-  i9 = i1 + (T22.alloc_stride[0LL] * i2);
+  i9 = (-4 + (T6.alloc_stride[0LL] * i7)) + (T6.alloc_stride[1LL] * i8);
   nvfuser_index_t i10;
-  i10 = ((-8 + ((1024 * T4.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x))) + i1) + (T4.alloc_stride[2LL] * i2);
+  i10 = 28 + T6.logical_size[3LL];
   nvfuser_index_t i11;
-  i11 = i3 + (T18.alloc_stride[0LL] * i2);
+  i11 = 28 + T14.logical_size[1LL];
   nvfuser_index_t i12;
-  i12 = ((-4 + ((1024 * T12.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x))) + i1) + (T12.alloc_stride[2LL] * i2);
+  i12 = (T10.alloc_stride[0LL] * i7) + (T10.alloc_stride[1LL] * i8);
   nvfuser_index_t i13;
-  i13 = 8 * ((nvfuser_index_t)threadIdx.x);
+  i13 = 28 + T10.logical_size[3LL];
   nvfuser_index_t i14;
-  i14 = (8 * ((nvfuser_index_t)blockDim.x)) * ((nvfuser_index_t)blockIdx.y);
+  i14 = (T8.alloc_stride[0LL] * i7) + (T8.alloc_stride[1LL] * i8);
   nvfuser_index_t i15;
-  i15 = (i13 + (32768 * ((nvfuser_index_t)blockIdx.x))) + i14;
-  bool b16;
-  b16 = (i13 + i14) < 32768;
-  if ((((i13 + 7) + i14) < 32768)) {
-    Array<__bfloat, 8, 8> T50;
+  i15 = 28 + T8.logical_size[3LL];
+  nvfuser_index_t i16;
+  i16 = 28 + T22.logical_size[1LL];
+  nvfuser_index_t i17;
+  i17 = (-8 + (T4.alloc_stride[0LL] * i7)) + (T4.alloc_stride[1LL] * i8);
+  nvfuser_index_t i18;
+  i18 = 8 + T4.logical_size[3LL];
+  nvfuser_index_t i19;
+  i19 = 28 + T18.logical_size[1LL];
+  nvfuser_index_t i20;
+  i20 = (-4 + (T12.alloc_stride[0LL] * i7)) + (T12.alloc_stride[1LL] * i8);
+  nvfuser_index_t i21;
+  i21 = 28 + T12.logical_size[3LL];
+  nvfuser_index_t i22;
+  i22 = (i3 + (32768 * ((nvfuser_index_t)blockIdx.x))) + i4;
+  nvfuser_index_t i23;
+  i23 = 24 * T10.logical_size[2LL];
+  nvfuser_index_t i24;
+  i24 = ((max(4, (min(i0, 8)))) * T10.logical_size[2LL]) + i23;
+  bool b25;
+  b25 = i5 < i24;
+  nvfuser_index_t i26;
+  i26 = 28 * T10.logical_size[2LL];
+  bool b27;
+  b27 = i5 < (i26 + (T10.logical_size[2LL] * T6.logical_size[3LL]));
+  bool b28;
+  b28 = i5 < (i26 + (T10.logical_size[2LL] * T14.logical_size[1LL]));
+  bool b29;
+  b29 = i5 < ((T10.logical_size[2LL] * T10.logical_size[3LL]) + i26);
+  bool b30;
+  b30 = i5 < (((max(4, (min((max(0LL, (min(i1, 8)))), 8)))) * T10.logical_size[2LL]) + i23);
+  bool b31;
+  b31 = i5 < (i26 + (T10.logical_size[2LL] * T22.logical_size[1LL]));
+  bool b32;
+  b32 = i5 < ((max(8, (min(i1, 32)))) * T10.logical_size[2LL]);
+  bool b33;
+  b33 = i5 < (((max(4, (min(i2, 8)))) * T10.logical_size[2LL]) + i23);
+  if ((((i3 + 7) + i4) < i24)) {
+    Array<__bfloat, 8, 8> T54;
     #pragma unroll
-    for(nvfuser_index_t i17 = 0; i17 < 8; ++i17) {
-      nvfuser_index_t i18;
-      i18 = i17 + nvfuser_zero;
+    for(nvfuser_index_t i34 = 0; i34 < 8; ++i34) {
+      nvfuser_index_t i35;
+      i35 = i5 + (i34 + nvfuser_zero);
       __bfloat T27[1];
       T27[0] = 0;
       T27[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T26[(i4 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) < T26.logical_size[1LL])) ? T26[((-4 + (T26.alloc_stride[0LL] * (i35 / i6))) + (i35 % i6))] : 0.0000e+00;
       __bfloat T28[1];
       T28[0]
          = T27[0];
-      __bfloat T29[1];
-      T29[0]
+      __bfloat T52[1];
+      T52[0]
          = T28[0];
       __bfloat T7[1];
       T7[0] = 0;
       T7[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T6[(i5 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) < T6.logical_size[3LL])) ? T6[((i9 + (T6.alloc_stride[2LL] * (i35 / i10))) + (i35 % i10))] : 0.0000e+00;
       float T38[1];
       T38[0]
          = __bfloat2float(T7[0]);
       float T39[1];
       T39[0]
-         = __bfloat2float(T29[0]);
+         = __bfloat2float(T52[0]);
       float T40[1];
       T40[0]
         = T38[0]
@@ -10766,23 +10794,23 @@
       __bfloat T15[1];
       T15[0] = 0;
       T15[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T14[(i6 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) < T14.logical_size[1LL])) ? T14[((T14.alloc_stride[0LL] * (i35 / i11)) + (i35 % i11))] : 0.0000e+00;
       __bfloat T16[1];
       T16[0]
          = T15[0];
-      __bfloat T17[1];
-      T17[0]
+      __bfloat T51[1];
+      T51[0]
          = T16[0];
       __bfloat T11[1];
       T11[0] = 0;
       T11[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T10[(i7 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) < T10.logical_size[3LL])) ? T10[((i12 + (T10.alloc_stride[2LL] * (i35 / i13))) + (i35 % i13))] : 0.0000e+00;
       float T35[1];
       T35[0]
          = __bfloat2float(T11[0]);
       float T34[1];
       T34[0]
-         = __bfloat2float(T17[0]);
+         = __bfloat2float(T51[0]);
       float T36[1];
       T36[0]
         = T34[0]
@@ -10790,7 +10818,7 @@
       __bfloat T9[1];
       T9[0] = 0;
       T9[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T8[(i8 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) < T8.logical_size[3LL])) ? T8[((i14 + (T8.alloc_stride[2LL] * (i35 / i15))) + (i35 % i15))] : 0.0000e+00;
       float T30[1];
       T30[0]
          = __bfloat2float(T9[0]);
@@ -10800,16 +10828,16 @@
       __bfloat T23[1];
       T23[0] = 0;
       T23[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T22[(i9 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) < T22.logical_size[1LL])) ? T22[((T22.alloc_stride[0LL] * (i35 / i16)) + (i35 % i16))] : 0.0000e+00;
       __bfloat T24[1];
       T24[0]
          = T23[0];
-      __bfloat T25[1];
-      T25[0]
+      __bfloat T50[1];
+      T50[0]
          = T24[0];
       float T32[1];
       T32[0]
-         = __bfloat2float(T25[0]);
+         = __bfloat2float(T50[0]);
       float T33[1];
       T33[0]
         = T31[0]
@@ -10817,30 +10845,30 @@
       __bfloat T5[1];
       T5[0] = 0;
       T5[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) < 24)) ? T4[(i10 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) < T4.logical_size[3LL])) ? T4[((i17 + (T4.alloc_stride[2LL] * (i35 / i18))) + (i35 % i18))] : 0.0000e+00;
       float T46[1];
       T46[0]
          = __bfloat2float(T5[0]);
       __bfloat T19[1];
       T19[0] = 0;
       T19[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T18[(i11 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) < T18.logical_size[1LL])) ? T18[((-4 + (T18.alloc_stride[0LL] * (i35 / i19))) + (i35 % i19))] : 0.0000e+00;
       __bfloat T20[1];
       T20[0]
          = T19[0];
-      __bfloat T21[1];
-      T21[0]
+      __bfloat T53[1];
+      T53[0]
          = T20[0];
       __bfloat T13[1];
       T13[0] = 0;
       T13[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T12[(i12 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) < T12.logical_size[3LL])) ? T12[((i20 + (T12.alloc_stride[2LL] * (i35 / i21))) + (i35 % i21))] : 0.0000e+00;
       float T42[1];
       T42[0]
          = __bfloat2float(T13[0]);
       float T41[1];
       T41[0]
-         = __bfloat2float(T21[0]);
+         = __bfloat2float(T53[0]);
       float T43[1];
       T43[0]
         = T41[0]
@@ -10861,78 +10889,78 @@
       T47[0]
         = T45[0]
         + T46[0];
-      T50[i17]
+      T54[i34]
          = __float2bfloat(T47[0]);
     }
     NVFUSER_UPDATE_MAGIC_ZERO;
-    loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i15], &T50[0]);
+    loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i22], &T54[0]);
   } else {
-    Array<__bfloat, 8, 8> T50;
+    Array<__bfloat, 8, 8> T54;
     #pragma unroll
-    for(nvfuser_index_t i17 = 0; i17 < 8; ++i17) {
-      nvfuser_index_t i19;
-      i19 = i17 + nvfuser_zero;
+    for(nvfuser_index_t i34 = 0; i34 < 8; ++i34) {
+      nvfuser_index_t i36;
+      i36 = i5 + (i34 + nvfuser_zero);
       __bfloat T27[1];
       T27[0] = 0;
-      if (b16) {
+      if (b25) {
         T27[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T26[(i4 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) < T26.logical_size[1LL])) ? T26[((-4 + (T26.alloc_stride[0LL] * (i36 / i6))) + (i36 % i6))] : 0.0000e+00;
       }
       __bfloat T28[1];
       T28[0]
          = T27[0];
-      __bfloat T29[1];
-      T29[0]
+      __bfloat T52[1];
+      T52[0]
          = T28[0];
       __bfloat T7[1];
       T7[0] = 0;
-      if (b16) {
+      if (b27) {
         T7[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T6[(i5 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) < T6.logical_size[3LL])) ? T6[((i9 + (T6.alloc_stride[2LL] * (i36 / i10))) + (i36 % i10))] : 0.0000e+00;
       }
       float T38[1];
       T38[0]
          = __bfloat2float(T7[0]);
       float T39[1];
       T39[0]
-         = __bfloat2float(T29[0]);
+         = __bfloat2float(T52[0]);
       float T40[1];
       T40[0]
         = T38[0]
         * T39[0];
       __bfloat T15[1];
       T15[0] = 0;
-      if (b16) {
+      if (b28) {
         T15[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T14[(i6 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) < T14.logical_size[1LL])) ? T14[((T14.alloc_stride[0LL] * (i36 / i11)) + (i36 % i11))] : 0.0000e+00;
       }
       __bfloat T16[1];
       T16[0]
          = T15[0];
-      __bfloat T17[1];
-      T17[0]
+      __bfloat T51[1];
+      T51[0]
          = T16[0];
       __bfloat T11[1];
       T11[0] = 0;
-      if (b16) {
+      if (b29) {
         T11[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T10[(i7 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) < T10.logical_size[3LL])) ? T10[((i12 + (T10.alloc_stride[2LL] * (i36 / i13))) + (i36 % i13))] : 0.0000e+00;
       }
       float T35[1];
       T35[0]
          = __bfloat2float(T11[0]);
       float T34[1];
       T34[0]
-         = __bfloat2float(T17[0]);
+         = __bfloat2float(T51[0]);
       float T36[1];
       T36[0]
         = T34[0]
         * T35[0];
       __bfloat T9[1];
       T9[0] = 0;
-      if (b16) {
+      if (b30) {
         T9[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T8[(i8 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) < T8.logical_size[3LL])) ? T8[((i14 + (T8.alloc_stride[2LL] * (i36 / i15))) + (i36 % i15))] : 0.0000e+00;
       }
       float T30[1];
       T30[0]
@@ -10942,56 +10970,56 @@
          = -T30[0];
       __bfloat T23[1];
       T23[0] = 0;
-      if (b16) {
+      if (b31) {
         T23[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T22[(i9 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) < T22.logical_size[1LL])) ? T22[((T22.alloc_stride[0LL] * (i36 / i16)) + (i36 % i16))] : 0.0000e+00;
       }
       __bfloat T24[1];
       T24[0]
          = T23[0];
-      __bfloat T25[1];
-      T25[0]
+      __bfloat T50[1];
+      T50[0]
          = T24[0];
       float T32[1];
       T32[0]
-         = __bfloat2float(T25[0]);
+         = __bfloat2float(T50[0]);
       float T33[1];
       T33[0]
         = T31[0]
         * T32[0];
       __bfloat T5[1];
       T5[0] = 0;
-      if (b16) {
+      if (b32) {
         T5[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) < 24)) ? T4[(i10 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) < T4.logical_size[3LL])) ? T4[((i17 + (T4.alloc_stride[2LL] * (i36 / i18))) + (i36 % i18))] : 0.0000e+00;
       }
       float T46[1];
       T46[0]
          = __bfloat2float(T5[0]);
       __bfloat T19[1];
       T19[0] = 0;
-      if (b16) {
+      if (b33) {
         T19[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T18[(i11 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) < T18.logical_size[1LL])) ? T18[((-4 + (T18.alloc_stride[0LL] * (i36 / i19))) + (i36 % i19))] : 0.0000e+00;
       }
       __bfloat T20[1];
       T20[0]
          = T19[0];
-      __bfloat T21[1];
-      T21[0]
+      __bfloat T53[1];
+      T53[0]
          = T20[0];
       __bfloat T13[1];
       T13[0] = 0;
-      if (b16) {
+      if (b30) {
         T13[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T12[(i12 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) < T12.logical_size[3LL])) ? T12[((i20 + (T12.alloc_stride[2LL] * (i36 / i21))) + (i36 % i21))] : 0.0000e+00;
       }
       float T42[1];
       T42[0]
          = __bfloat2float(T13[0]);
       float T41[1];
       T41[0]
-         = __bfloat2float(T21[0]);
+         = __bfloat2float(T53[0]);
       float T43[1];
       T43[0]
         = T41[0]
@@ -11012,12 +11040,12 @@
       T47[0]
         = T45[0]
         + T46[0];
-      T50[i17]
+      T54[i34]
          = __float2bfloat(T47[0]);
     }
     NVFUSER_UPDATE_MAGIC_ZERO;
-    if (b16) {
-      loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i15], &T50[0]);
+    if ((i5 < 32768)) {
+      loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i22], &T54[0]);
     }
   }
 }

The static kernel just uses the commented-out lines in the repro posted above. It achieves about 5x higher bandwidth than the dynamic kernel (it runs in 8 us vs. 38 us).

There is more in the preamble for dynamic shapes, but inside the loops the expressions also have slightly more going on. For example (zoomed in and inserted line breaks):

--- static.cu   2024-08-19 10:23:21.977784983 -0400
+++ dynamic.cu  2024-08-19 10:23:58.741144923 -0400
@@ -10697,68 +10697,96 @@
       __bfloat T27[1];
       T27[0] = 0;
       T27[0] = ((((((((((nvfuser_index_t)blockIdx.y) *
                        ((nvfuser_index_t)blockDim.x)) +
                       ((nvfuser_index_t)threadIdx.x)) *
                      8) +
-                    (i17 + nvfuser_zero)) %
-                   32) -
+                    (i34 + nvfuser_zero)) %
+                   ((T26.logical_size[1LL] + 4) + 24)) -
                   4) >= 0) &&
                 (((((((((nvfuser_index_t)blockIdx.y) *
                        ((nvfuser_index_t)blockDim.x)) +
                       ((nvfuser_index_t)threadIdx.x)) *
                      8) +
-                    (i17 + nvfuser_zero)) %
-                   32) -
-                  4) < 4))
-          ? T26[(i4 + i18)]
+                    (i34 + nvfuser_zero)) %
+                   ((T26.logical_size[1LL] + 4) + 24)) -
+                  4) < T26.logical_size[1LL]))
+          ? T26[((-4 + (T26.alloc_stride[0LL] * (i35 / i6))) + (i35 % i6))]
           : 0.0000e+00;
       __bfloat T28[1];
       T28[0] = T27[0];

In this context i35 is derived from the loop index i34, so we might not be able to simplify the last diff line much, but for some reason we are also not hoisting the loop-invariant modulus ((T26.logical_size[1LL] + 4) + 24) out of the loop...
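To make the hoisting point concrete, here is a small host-side sketch in plain C++ rather than generated CUDA. The function names are hypothetical; `base` plays the role of (blockIdx.y * blockDim.x + threadIdx.x) * 8 and `size` the role of T26.logical_size[1LL]. It only shows that computing the padded extent once outside the loop gives the same predicate results.

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the dynamic kernel: the modulus (size + 4) + 24 is spelled out in
// both comparisons inside the unrolled loop.
int countInBoundsUnhoisted(int64_t base, int64_t size) {
  int count = 0;
  for (int64_t i = 0; i < 8; ++i) {
    const bool in_bounds =
        ((((base + i) % ((size + 4) + 24)) - 4) >= 0) &&
        ((((base + i) % ((size + 4) + 24)) - 4) < size);
    count += in_bounds ? 1 : 0;
  }
  return count;
}

// Same predicate with the loop-invariant padded extent hoisted out of the
// loop and the shared remainder computed once per iteration.
int countInBoundsHoisted(int64_t base, int64_t size) {
  const int64_t padded_extent = size + 28; // (size + 4) + 24
  int count = 0;
  for (int64_t i = 0; i < 8; ++i) {
    const int64_t pos = ((base + i) % padded_extent) - 4;
    count += (pos >= 0 && pos < size) ? 1 : 0;
  }
  return count;
}
```

The device compiler may well perform this common-subexpression elimination itself, so the effect on the generated SASS would need to be measured rather than assumed.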


jacobhinkle commented on August 21, 2024

As for the preamble, there are lots of max and mins in the dynamic kernel, which could be avoided using #511 (I'm looking at updating this). As discussed last week, we could temporarily make all sliced input extents and all slice ranges constant at concretization, which I think would give us a kernel similar to static.cu above.


jjsjann123 commented on August 21, 2024

For our own sanity, here's a simplified cpp test. Indexing isn't being simplified even when the slice is passed the correct extent val.

Regarding @jacobhinkle's WAR in #511: as I understand it, we wouldn't need this form of definition, and the performance of the original Python repro shouldn't regress with dynamic shapes.

I'm creating this repro for @zasdfgbnm. I'm assuming the definition here should be enough to tell us that we are doing two non-overlapping slices, so the indexing could maybe be simplified even without concretization...

TEST_F(NVFuserTest, DynamicShapedPad) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  auto zero = fusion->zeroVal();
  auto one = fusion->oneVal();

  std::vector<int64_t> shape{32, 1024, 16};

#if 0
  auto tv0 = makeContigConcreteTensor(shape);
  auto dim0 = IrBuilder::create<Val>(32, DataType::Index);
  auto dim1 = IrBuilder::create<Val>(1024, DataType::Index);
  auto dim2 = IrBuilder::create<Val>(16, DataType::Index);
  auto val_slice = IrBuilder::create<Val>(8, DataType::Index);
  auto val_remain = IrBuilder::create<Val>(8, DataType::Index);
#else
  auto tv0 = makeContigTensor(3);
  auto dim0 = tv0->axis(0)->extent();
  auto dim1 = tv0->axis(1)->extent();
  auto dim2 = tv0->axis(2)->extent();
  auto val_slice = IrBuilder::create<Val>(8, DataType::Index);
  auto val_remain = sub(dim2, val_slice);
#endif
  Slice slice_dim_0{zero, dim0, one};
  Slice slice_dim_1{zero, dim1, one};
  Slice slice_dim_2_l{zero, val_slice, one};
  Slice slice_dim_2_r{val_slice, dim2, one};

  std::vector<Slice> slice_l_ind = {slice_dim_0, slice_dim_1, slice_dim_2_l};
  std::vector<Slice> slice_r_ind = {slice_dim_0, slice_dim_1, slice_dim_2_r};

  auto slice_l = slice(tv0, slice_l_ind);
  auto slice_r = slice(tv0, slice_r_ind);

  fusion->addInput(tv0);

  auto rope_l = pad(slice_r, {zero, val_slice, zero, zero, zero, zero});
  // avoid segmentation.
  // auto rope_r = pad(neg(slice_l), {val_remain, zero, zero, zero, zero, zero});
  auto rope_r = neg(pad(slice_l, {val_remain, zero, zero, zero, zero, zero}));

  auto o = add(rope_l, rope_r);

  fusion->addOutput(o);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn(shape, options);
  std::vector<c10::IValue> aten_inputs({t0});

  FusionExecutorCache executor_cache(std::move(fusion));
  auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);

  testValidate(
      executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__);
}
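For readers following the test, here is a scalar reference for what the fusion computes along one innermost row, written as a sketch in plain C++. The helper name is hypothetical, and it assumes the pad widths are listed innermost dimension first as {left, right}, so rope_l pads slice_r on the right by val_slice and rope_r pads slice_l on the left by val_remain. `split` corresponds to val_slice and must lie in [0, n].

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Reference for one innermost row x of length n:
//   slice_l = x[0:split], slice_r = x[split:n]
//   rope_l  = pad(slice_r, left = 0, right = split)
//   rope_r  = -pad(slice_l, left = n - split, right = 0)
//   out     = rope_l + rope_r
std::vector<float> rotateHalfReference(const std::vector<float>& x, int64_t split) {
  const int64_t n = static_cast<int64_t>(x.size());
  const int64_t remain = n - split; // corresponds to val_remain
  std::vector<float> out(x.size(), 0.0f);
  for (int64_t j = 0; j < n; ++j) {
    // The two padded slices cover disjoint output ranges: [0, remain) and [remain, n).
    const float rope_l = (j < remain) ? x[j + split] : 0.0f;
    const float rope_r = (j >= remain) ? -x[j - remain] : 0.0f;
    out[j] = rope_l + rope_r;
  }
  return out;
}
```

Since each output element reads exactly one input element, the two pad predicates are complementary, which is the property the indexing simplification would ideally exploit.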


jjsjann123 commented on August 21, 2024

Yeah, if there's no low-hanging fruit, I don't think it matters at this point, since we are going with @jacobhinkle's plan to use static shapes during concretization.

We can revisit this if we decide to push it further afterwards.

from fuser.

jacobhinkle avatar jacobhinkle commented on August 21, 2024

> static shapes during concretization.

BTW, while implementing this I noticed that many of the resizes are dynamic but, for the provided inputs, are actually trivial:

    ?S9{( fmax(0, ( fmin(i0, 2) )) )}rf (index=0) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S11{( fmax(0, ( fmin(i1, 16) )) )}rf (index=1) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S13{( fmax(0, ( fmin(i2, 1024) )) )}rf (index=2) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S15{( fmax(0, ( fmin(i3, 8) )) )}rf (index=3) is a resize of input extent 32 with left_pad=0 and right_pad=-24
    ?S43{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}rf (index=4) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S45{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 16) )) ), 16) )) )}rf (index=5) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S47{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 1024) )) ), 1024) )) )}rf (index=6) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S49{( ( fmax(4, ( fmin(( fmax(0, ( fmin(i3, 8) )) ), 8) )) ) - 4 )}rf (index=7) is a resize of input extent 8 with left_pad=-4 and right_pad=0
    ?S112{( fmax(0, ( fmin(i7, 1024) )) )}rf (index=8) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S114{( fmax(0, ( fmin(i8, 4) )) )}rf (index=9) is a resize of input extent 8 with left_pad=0 and right_pad=-4
    ?S82{( fmax(0, ( fmin(i5, 1024) )) )}rf (index=10) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S84{( fmax(0, ( fmin(i6, 4) )) )}rf (index=11) is a resize of input extent 8 with left_pad=0 and right_pad=-4
    ?S56{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}rf (index=12) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S58{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 16) )) ), 16) )) )}rf (index=13) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S60{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 1024) )) ), 1024) )) )}rf (index=14) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S62{( fmax(0, ( fmin(( fmax(0, ( fmin(i3, 8) )) ), 4) )) )}rf (index=15) is a resize of input extent 8 with left_pad=0 and right_pad=-4
    ?S30{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}rf (index=16) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S32{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 16) )) ), 16) )) )}rf (index=17) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S34{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 1024) )) ), 1024) )) )}rf (index=18) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S36{( fmax(0, ( fmin(( fmax(0, ( fmin(i3, 8) )) ), 4) )) )}rf (index=19) is a resize of input extent 8 with left_pad=0 and right_pad=-4
    ?S127{( fmax(0, ( fmin(i7, 1024) )) )}rf (index=20) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S129{( ( fmax(4, ( fmin(i8, 8) )) ) - 4 )}rf (index=21) is a resize of input extent 8 with left_pad=-4 and right_pad=0
    ?S97{( fmax(0, ( fmin(i5, 1024) )) )}rf (index=22) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S99{( ( fmax(4, ( fmin(i6, 8) )) ) - 4 )}rf (index=23) is a resize of input extent 8 with left_pad=-4 and right_pad=0
    ?S69{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}rf (index=24) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S71{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 16) )) ), 16) )) )}rf (index=25) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S73{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 1024) )) ), 1024) )) )}rf (index=26) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S75{( ( fmax(4, ( fmin(( fmax(0, ( fmin(i3, 8) )) ), 8) )) ) - 4 )}rf (index=27) is a resize of input extent 8 with left_pad=-4 and right_pad=0
    ?S17{( fmax(0, ( fmin(i0, 2) )) )}rf (index=28) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S19{( fmax(0, ( fmin(i1, 16) )) )}rf (index=29) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S21{( fmax(0, ( fmin(i2, 1024) )) )}rf (index=30) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S23{( ( fmax(8, ( fmin(i3, 32) )) ) - 8 )}rf (index=31) is a resize of input extent 32 with left_pad=-8 and right_pad=0

By my count 21 out of these 32 resized axes are not actually resized at all. Using static shapes, not only will the expressions be simpler, but we will catch every one of these trivial resizes and we will not predicate that access. I'll have that as part of concretization in a PR soon.
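
The left_pad and right_pad values in the listing can be reproduced once the input extent is known, which is what makes the trivial cases detectable with static shapes. Below is a minimal sketch, assuming a unit-step slice [start, stop) whose output extent is max(start, min(stop, extent)) - start, which matches the shape of the printed expressions. ResizePads, slicePads, and isTrivial are hypothetical names, not nvFuser APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper type; not part of nvFuser.
struct ResizePads {
  std::int64_t left;
  std::int64_t right;
};

// Pads implied by a unit-step slice [start, stop) of an axis with the given
// extent. Assumes 0 <= start <= extent; stop is clamped to the extent.
ResizePads slicePads(
    std::int64_t extent,
    std::int64_t start,
    std::int64_t stop) {
  const std::int64_t clamped_stop = std::max(start, std::min(stop, extent));
  // Slicing only removes elements, so both pads are non-positive.
  return {-start, clamped_stop - extent};
}

// A resize is trivial when it neither trims nor extends the axis.
bool isTrivial(const ResizePads& pads) {
  return pads.left == 0 && pads.right == 0;
}
```

With the extents from the listing, slicePads(32, 0, 8) yields right_pad=-24 as for index=3, while slicePads(2, 0, 2) is trivial as for index=0, so no predicate would be needed for that access.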

from fuser.

jjsjann123 avatar jjsjann123 commented on August 21, 2024

> By my count 21 out of these 32 resized axes are not actually resized at all. Using static shapes, not only will the expressions be simpler, but we will catch every one of these trivial resizes and we will not predicate that access. I'll have that as part of concretization in a PR soon.

Thanks for pointing that out. Yes, that's expected; this is one of the mismatches in thunder's static program. Slices with [..., ] are baked in as constants as well... We need to rewrite that logic later.

from fuser.
