
Comments (3)

yunjiangster commented on August 24, 2024

Maybe you can achieve something similar by loading a 2D tensor with overlapping elements (via unsqueeze/`[:, None]` indexing and broadcasting). Here is an example:

import triton
import triton.language as tl
import torch

@triton.jit
def test_stencil(x_ptr, o_ptr):
    pid = tl.program_id(axis=0)
    rng = tl.arange(0, 4)
    x = tl.load(x_ptr + rng[:, None] + rng[None, :])
    tl.store(o_ptr + rng, tl.sum(x, axis=1))

x = torch.arange(8).cuda()
y = torch.zeros_like(x)
test_stencil[(1,)](x, y)
x, y

The output looks like this:

(tensor([0, 1, 2, 3, 4, 5, 6, 7], device='cuda:0'),
 tensor([ 6, 10, 14, 18,  0,  0,  0,  0], device='cuda:0'))
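For reference, the same broadcasting trick can be reproduced on the host with NumPy: adding a column vector of offsets to a row vector of offsets yields a 4×4 index matrix whose row i covers elements i..i+3, so each output element is the sum of an overlapping window.

```python
import numpy as np

# rng[:, None] + rng[None, :] broadcasts to a 4x4 matrix of offsets;
# row i indexes the overlapping window x[i], x[i+1], x[i+2], x[i+3].
rng = np.arange(4)
offsets = rng[:, None] + rng[None, :]

x = np.arange(8)
row_sums = x[offsets].sum(axis=1)  # one windowed sum per row
# row_sums -> [6, 10, 14, 18], matching the kernel output above
```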

from triton.

thumbe3 commented on August 24, 2024

Hi @yunjiangster, this is definitely a good idea for making the code concise. However, it has the same problem of multiple loads of the same elements. Here is the implementation I wrote by borrowing your idea. NUM_OFFSETS below is the smallest power of 2 greater than or equal to (2 * RADIUS + 1), since tl.arange requires a power-of-2 length.
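As a side note, that power-of-2 rounding can be computed with a small host-side helper (the name `num_offsets` here is hypothetical; Triton also ships `triton.next_power_of_2`, which does the same):

```python
def num_offsets(radius: int) -> int:
    """Smallest power of 2 that can hold the 2*radius + 1 stencil taps."""
    width = 2 * radius + 1
    # bit_length of (width - 1) gives the exponent of the next power of 2
    return 1 << (width - 1).bit_length()

# num_offsets(1) == 4, num_offsets(3) == 8, num_offsets(4) == 16
```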

@triton.jit
def stencil_kernel_v2(
    inputs: tl.tensor,
    outputs: tl.tensor,
    shape: tl.int32,
    BLOCK_SIZE: tl.constexpr,
    NUM_OFFSETS: tl.constexpr,
    RADIUS: tl.constexpr):

    pid = tl.program_id(0)
    # Starting Offsets
    base_offs = tl.arange(0, BLOCK_SIZE) 
    inc_offs = tl.arange(0, NUM_OFFSETS) - RADIUS
    output_offsets = pid * BLOCK_SIZE + base_offs
    input_offsets = pid * BLOCK_SIZE + inc_offs[None, :] + base_offs[:, None]

    # Python's `and` does not broadcast over tensors; combine the masks
    # elementwise with `&` and parenthesized comparisons. The last term
    # disables the padding lanes added by rounding NUM_OFFSETS up to a
    # power of 2.
    input_tensor = tl.load(inputs + input_offsets,
        mask=(input_offsets >= 0) & (input_offsets < shape)
            & (inc_offs[None, :] <= RADIUS),
        other=0.0)

    tl.store(outputs + output_offsets,
        tl.sum(input_tensor, axis=1),
        mask = output_offsets < shape)

Even this kernel performs about the same as the previous Triton kernel, and it starts to lag behind the hand-tuned CUDA implementation at high values of RADIUS.
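For checking correctness on the host, a NumPy reference of what the kernel computes could look like the sketch below (assuming out-of-range neighbors contribute zero, which matches the masked load with other=0.0):

```python
import numpy as np

def stencil_reference(x: np.ndarray, radius: int) -> np.ndarray:
    """Sum each element with its `radius` neighbors on either side,
    treating out-of-range positions as 0 (zero padding)."""
    padded = np.pad(x, radius)
    width = 2 * radius + 1
    return np.array([padded[i:i + width].sum() for i in range(len(x))])

# e.g. stencil_reference(np.arange(8), 1)
# -> [1, 3, 6, 9, 12, 15, 18, 13]
```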


yunjiangster commented on August 24, 2024

@thumbe3 ah, that makes sense. Loading the same repeated element probably still costs multiple load operations. I am curious how Triton handles convolution kernels then, since they do something similar.

