____________________________________________________________________________ test_benchmark_masked[triton-bias-mask-causal-bf16-8x128] ____________________________________________________________________________
heads = 48, seq_length = 128
Q = tensor([[[[3.9844e-01, 5.1562e-01, 2.4902e-02, ..., 6.7969e-01,
4.4141e-01, 5.0000e-01],
[8.0859...6.7188e-01, 1.5918e-01, ..., 3.0859e-01,
7.9688e-01, 5.4297e-01]]]], device='cuda:0', dtype=torch.bfloat16)
K = tensor([[[[3.2422e-01, 9.5703e-01, 3.1055e-01, ..., 5.3516e-01,
4.9023e-01, 1.9531e-01],
[7.7344...5.2734e-01, 3.8477e-01, ..., 1.2891e-01,
7.2266e-01, 2.4292e-02]]]], device='cuda:0', dtype=torch.bfloat16)
V = tensor([[[[2.0801e-01, 6.9531e-01, 7.9297e-01, ..., 5.6641e-01,
6.9922e-01, 4.4678e-02],
[8.7891...4.0625e-01, 2.8906e-01, ..., 1.2158e-01,
8.0078e-01, 2.4805e-01]]]], device='cuda:0', dtype=torch.bfloat16)
sm_scale = 0.3
attention_mask = tensor([[[[4.0234e-01, 7.2266e-01, 9.6094e-01, ..., 9.6094e-01,
1.9141e-01, 5.4297e-01],
[3.8672...5.3125e-01, 2.9297e-01, ..., 6.0156e-01,
1.4062e-01, 8.4766e-01]]]], device='cuda:0', dtype=torch.bfloat16)
TMP = tensor([[2.3694e-38, 2.3694e-38, 2.3694e-38, ..., 2.3694e-38, 2.3694e-38,
2.3694e-38],
[0.0000e+00, ...00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,
0.0000e+00]], device='cuda:0')
output = tensor([[[[0.2080, 0.6953, 0.7930, ..., 0.5664, 0.6992, 0.0447],
[0.6250, 0.3262, 0.8320, ..., 0.4258, 0.3...47],
[0.5391, 0.4727, 0.5273, ..., 0.5352, 0.4629, 0.5781]]]],
device='cuda:0', dtype=torch.bfloat16)
q_batch_stride = 393216, q_head_stride = 8192, q_m_stride = 64, q_k_stride = 1, k_batch_stride = 393216, k_head_stride = 8192, k_n_stride = 64, k_k_stride = 1, v_batch_stride = 393216, v_head_stride = 8192
v_k_stride = 64, v_n_stride = 1, o_batch_stride = 393216, o_head_stride = 8192, o_m_stride = 64, o_n_stride = 1, attention_mask_batch_stride = 786432, attention_mask_head_stride = 16384
attention_mask_m_stride = 128, attention_mask_k_stride = 1, min_clamp_value = -3.3895313892515355e+38, MASK_BATCH_SIZE = 8, MASK_HEAD_SIZE = 48, MASK_M_SIZE = 128, MASK_K_SIZE = 128, HAS_MASK = True
IS_CAUSAL = True, BLOCK_M = 128, BLOCK_DHEAD = 64, BLOCK_N = 128, grid = (1, 384), num_warps = 4, num_stages = 1, extern_libs = None, stream = 0, warmup = False
> ???
E KeyError: ('2-.-0-.-0-1e8410f206c822547fb50e2ea86e45a6-cfed90d463fc30ccffb0eb2fd26372d3-a357695982511d203a134df772c7b4a1-2121719c12e3ab66746f4a57f276d42e-0f76008a374e725ca29ccb33f1ba668f-dc48432b6b79843e2f9c7ad2e7355f59-f40d73592c2578180d3d8e3f64e3957d-0dd03b0bd512a184b3512b278d9dfa59-d7c2e52f8151bec157e9a17a1ec37dd3', ('i32', 'i32', torch.bfloat16, torch.bfloat16, torch.bfloat16, 'fp32', torch.bfloat16, torch.float32, torch.bfloat16, 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'fp32'), (8, 48, 128, 128, True, True, 128, 64, 128), ((True, False), (True, False), True, True, True, (False,), True, True, True, (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (False,)))
<string>:21: KeyError
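The KeyError here is Triton's normal cache-miss path, not the root cause. The generated launcher (the `<string>` frame above) first looks up an already-compiled binary under a key built from the kernel hashes, the argument signature (note the torch.bfloat16 entries), the constexpr values (8, 48, 128, 128, True, True, 128, 64, 128), which line up with (MASK_BATCH_SIZE, MASK_HEAD_SIZE, MASK_M_SIZE, MASK_K_SIZE, HAS_MASK, IS_CAUSAL, BLOCK_M, BLOCK_DHEAD, BLOCK_N) in the locals above, and the argument specialization; a miss raises KeyError, which the launcher catches to trigger JIT compilation. A minimal sketch of that pattern, with illustrative names (bin_cache, compile_kernel, launch) rather than Triton's actual internals:

    # Hypothetical, simplified sketch of the launcher's cache-miss path;
    # bin_cache / compile_kernel / launch stand in for Triton internals.
    bin_cache = {}

    def launch_cached(key, compile_kernel, launch):
        try:
            binary = bin_cache[key]    # raises KeyError on first call
        except KeyError:               # <- the KeyError in this traceback
            binary = compile_kernel()  # JIT-compile the kernel
            bin_cache[key] = binary    # cache for subsequent launches
        return launch(binary)          # the launch step is what fails next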
During handling of the above exception, another exception occurred:
benchmark = <kernl.benchmark.benchmark_fixture.BenchmarkFixture object at 0x7eff5c2b8e20>, shape = (8, 128), implementation = 'triton', mask_fn = <function generate_bias_mask at 0x7eff5c39d700>
dtype = torch.bfloat16, is_causal = True
@set_seed()
@pytest.mark.parametrize(
"shape",
[(bs, seq_l) for bs in [8, 32, 64] for seq_l in [128, 256, 384, 512]] + [(32, 32)],
ids=lambda x: f"{x[0]}x{x[1]}",
)
# fp32 not yet possible because of a bug in triton
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
@pytest.mark.parametrize("is_causal", [True, False], ids=["causal", "non-causal"])
@pytest.mark.parametrize(
"mask_fn",
[generate_bias_mask, generate_broadcast_mask, generate_none_mask],
ids=["bias-mask", "broadcast-mask", "no-mask"],
)
@pytest.mark.parametrize("implementation", implementations.keys())
def test_benchmark_masked(
benchmark, shape: (int, int), implementation: Callable, mask_fn: Callable, dtype: torch.dtype, is_causal: bool
):
batch, seq_length = shape
if implementation == "original" and (dtype == torch.bfloat16 or seq_length != 512):
pytest.skip("Original Triton implementation only supports fp16 and seq_length=512")
elif implementation == "original" and mask_fn != generate_none_mask:
pytest.skip("Original Triton implementation doesn't support masks")
# batch, heads, seq_length, dhead
mat_shape = (batch, 48, seq_length, 64)
args = {
"q": torch.rand(mat_shape, device="cuda"),
"k": torch.rand(mat_shape, device="cuda"),
"v": torch.rand(mat_shape, device="cuda"),
"output": torch.empty(mat_shape, device="cuda"),
"sm_scale": 0.3, # Scaling applied before softmax (sqrt(dhead) in Vaswani et al.)
"is_causal": is_causal,
"attention_mask": mask_fn(batch, seq_length),
}
expected = attention_reference(**args)
cast_args = {k: v.to(dtype).clone() if isinstance(v, torch.Tensor) else v for k, v in args.items()}
func = implementations[implementation]
> value = benchmark(func, **cast_args)
test/test_attention.py:107:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/kernl/benchmark/benchmark_fixture.py:53: in __call__
function_to_benchmark(*args, **kwargs)
test/test_attention.py:39: in <lambda>
"triton": lambda q, k, v, output, sm_scale, is_causal, attention_mask: attention_forward(
src/kernl/implementations/attention.py:436: in attention_forward
return Attention.apply(q, k, v, output, sm_scale, is_causal, attention_mask)
/usr/local/lib/python3.9/dist-packages/torch/cuda/amp/autocast_mode.py:118: in decorate_fwd
return fwd(*args, **kwargs)
src/kernl/implementations/attention.py:380: in forward
_fwd_kernel[grid](
/usr/local/lib/python3.9/dist-packages/triton/runtime/jit.py:106: in launcher
return self.run(*args, grid=grid, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> ???
E RuntimeError: Triton Error [CUDA]: invalid argument
<string>:44: RuntimeError
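The launch that follows compilation is what actually fails: "invalid argument" is the CUDA driver's message for CUDA_ERROR_INVALID_VALUE, typically from the kernel-launch call. The grid itself is consistent with the locals (grid = (1, 384) matches ceil(seq_length / BLOCK_M) x (batch * heads) = 1 x 8*48), so a plausible suspect for a bf16 kernel is hardware without native bfloat16 support (compute capability below 8.0); an oversized shared-memory footprint for the chosen BLOCK_M/BLOCK_N/num_stages can produce the same error. A quick sanity check, using only standard PyTorch calls (the interpretation of this particular failure is a guess, not a confirmed diagnosis):

    # Check whether the GPU can run bf16 kernels at all; bf16 tensor-core
    # support starts with Ampere (sm_80). Standard PyTorch API calls only.
    import torch

    props = torch.cuda.get_device_properties(0)
    print(f"{props.name}: sm_{props.major}{props.minor}")
    if (props.major, props.minor) < (8, 0):
        print("bfloat16 Triton kernels need sm_80 or newer; on this GPU "
              "the launch can fail with 'invalid argument'.")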