
Comments (4)

ptillet commented on July 22, 2024

Hey! Thanks for your interest :)

Using Triton for TF ops as you did should be totally appropriate. Does your code work with dense inputs by any chance? I would also use a head dimension of 1 for A and B while debugging (i.e., tf.random.uniform([1, 1, 128, 64], dtype=tf.float16)). The look-up table and super-blocking logic can be quite complicated for SDD so my first intuition is that there is some kind of issue there. Did you copy it from one generated with torch for the same layout?


jinsolp commented on July 22, 2024

Thank you so much for the quick reply!!! 😊
I've been referencing how DeepSpeed uses this kernel, and I pass inputs and layouts that match it exactly.
Following your advice, I used the following configuration:

  • SEQ_LENGTH=128
  • NUM_ATTENTION_HEADS=1
  • BATCH_SIZE=1
  • HIDDEN_SIZE=256
  • BLOCK_SIZE=16
    along with the sparse attention "dense" mode, which uses the [matmul kernel] (quick size check below).
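
As a quick sanity check on the sizes (my own arithmetic, not something taken from the DeepSpeed code): with SEQ_LENGTH=128 and BLOCK_SIZE=16 the dense layout is an 8x8 grid of blocks, so the SDD output should contain 64 blocks. That is consistent with widths = {4} and packs = {4} in the code below (4 * 4 * 4 = 64) and with the 16384 passed as c.stride(0).

// Size check (my own arithmetic, not from the binding):
int seq_len = 128, block = 16;
int blocks_per_dim = seq_len / block;                  // 8 block-rows and block-columns
int num_blocks     = blocks_per_dim * blocks_per_dim;  // 64 blocks in the dense layout
int total_width    = 4 * 4 * 4;                        // widths[0] * packs[0]^2 = 64, matches
int c_elements     = num_blocks * block * block;       // 16384 = the c.stride(0) value passed below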

However, I still hit the same problem when calling stream->read.
Below are the code I used and the error message that appears:

void compute(drv::context* context,
             drv::stream* stream,
             bool trans_a,  // false
             bool trans_b,  // true
             bool trans_c,  // false
             int block,
             int num_blocks)
{
  size_t dt_nbytes = sizeof(half_float::half);
  drv::device* device = context->device();

  std::vector<half_float::half> ha(1 * 1 * 128 * 256);
  std::vector<half_float::half> hb(1 * 1 * 128 * 256);

  // fill with random values in [0, 1]; cast after the division so the half
  // conversion is applied to a small value
  for(size_t i = 0; i < ha.size(); i++) {
    ha[i] = static_cast<half_float::half>((double)rand() / RAND_MAX);
  }
  for(size_t i = 0; i < hb.size(); i++) {
    hb[i] = static_cast<half_float::half>((double)rand() / RAND_MAX);
  }

  // COPIED FROM PYTORCH TENSOR DATA
  std::vector<int> lut = {
    0,  0,  0,  0,  0,  0,  1,  1,  0,  0,  2,  2,  0,  0,  3,  3,  0,  1,
    0,  8,  0,  1,  1,  9,  0,  1,  2, 10,  0,  1,  3, 11,  0,  2,  0, 16,
    0,  2,  1, 17,  0,  2,  2, 18,  0,  2,  3, 19,  0,  3,  0, 24,  0,  3,
    1, 25,  0,  3,  2, 26,  0,  3,  3, 27,  0,  0,  4,  4,  0,  0,  5,  5,
    0,  0,  6,  6,  0,  0,  7,  7,  0,  1,  4, 12,  0,  1,  5, 13,  0,  1,
    6, 14,  0,  1,  7, 15,  0,  2,  4, 20,  0,  2,  5, 21,  0,  2,  6, 22,
    0,  2,  7, 23,  0,  3,  4, 28,  0,  3,  5, 29,  0,  3,  6, 30,  0,  3,
    7, 31,  0,  4,  0, 32,  0,  4,  1, 33,  0,  4,  2, 34,  0,  4,  3, 35,
    0,  5,  0, 40,  0,  5,  1, 41,  0,  5,  2, 42,  0,  5,  3, 43,  0,  6,
    0, 48,  0,  6,  1, 49,  0,  6,  2, 50,  0,  6,  3, 51,  0,  7,  0, 56,
    0,  7,  1, 57,  0,  7,  2, 58,  0,  7,  3, 59,  0,  4,  4, 36,  0,  4,
    5, 37,  0,  4,  6, 38,  0,  4,  7, 39,  0,  5,  4, 44,  0,  5,  5, 45,
    0,  5,  6, 46,  0,  5,  7, 47,  0,  6,  4, 52,  0,  6,  5, 53,  0,  6,
    6, 54,  0,  6,  7, 55,  0,  7,  4, 60,  0,  7,  5, 61,  0,  7,  6, 62,
    0,  7,  7, 63
  };
  std::vector<float> widths = {4};
  std::vector<float> spdims = {4, 8, 8};
  std::vector<float> packs = {4};
  
  int AS0 = 1; int AS1 = 1 ; int AS2 = 128 ; int AS3 = 256;
  int BS0 = 1; int BS1 = 1 ; int BS2 = 256 ; int BS3 = 128;

  bool is_16_multiple = (AS3 % 16 == 0);
  bool is_32_multiple = (AS3 % 32 == 0);
  bool is_64_multiple = (AS3 % 64 == 0);

  int total_width = 0;
  int num_elements =  widths.size();

  for(int i = 0 ; i < num_elements ; i++) {
    total_width += widths[i] * packs[i] * packs[i];
  }

  std::vector<half_float::half> hc(AS0 * total_width * block * block);

  auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, AS0*AS1*AS2*AS3*dt_nbytes));
  auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, BS0*BS1*BS2*BS3*dt_nbytes));
  auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, AS0*total_width*block*block*dt_nbytes));  // C holds half elements, like hc

  stream->write(&*da, true, 0, ha);
  stream->write(&*db, true, 0, hb);
  
  int pack = packs[0];
  
  rt::options_t opt;
  opt.defines["TM"] = std::to_string(block * pack);
  opt.defines["TN"] = std::to_string(block * pack);
  opt.defines["TMN"] = std::to_string(block * block * pack * pack);
  opt.defines["BLOCK"] = std::to_string(block);
  opt.defines["TK"] = "32";
  opt.defines["TYPE"] = "half"; 
  opt.defines["STRIDE_AM"] = "lda";   //trans_a ? "1" : "lda";  //TODO:
  opt.defines["STRIDE_AK"] = "1";     //trans_a ? "lda" : "1";
  opt.defines["STRIDE_BN"] = "ldb";   //trans_b ? "ldb" : "1";
  opt.defines["STRIDE_BK"] = "1";     //trans_b ? "1" : "ldb";
  opt.defines["STRIDE_CM"] = "ldc";
  opt.defines["STRIDE_CN"] = "1";
  opt.defines["SDD"] = "True"; 
  opt.defines["TZ"] = "1";
  opt.defines["NAME"] = "sdd_kernel";
  opt.num_warps = 4;

  int num_lock = 1;
  int width = widths[0];

  std::vector<int> locks(2 * width * AS0 * num_lock, 0);
  auto locks_buf = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 2*width*AS0*num_lock*sizeof(int)));
  stream->write(&*locks_buf, true, 0, locks);

  auto lut_buf = std::shared_ptr<drv::buffer>(drv::buffer::create(context, lut.size()*sizeof(int)));
  stream->write(&*lut_buf, true, 0, lut);

  std::stringstream oss;
  rt::add_arg(oss, *da->cu());
  rt::add_arg(oss, *db->cu());
  rt::add_arg(oss, *dc->cu());
  rt::add_arg(oss, 64);     // a.stride(2)
  rt::add_arg(oss, 64);     // b.stride(2)
  rt::add_arg(oss, block);
  rt::add_arg(oss, 98304);  // a.stride(0)
  rt::add_arg(oss, 98304);  // b.stride(0)
  rt::add_arg(oss, 16384);  // c.stride(0)
  rt::add_arg(oss, 256);   // a.stride(1)
  rt::add_arg(oss, 256);   // b.stride(1)  
  rt::add_arg(oss, 16384);  // c.stride(0)
  rt::add_arg(oss, AS2);
  rt::add_arg(oss, AS2);
  rt::add_arg(oss, AS3);
  rt::add_arg(oss, 0);
  rt::add_arg(oss, *lut_buf->cu());
  rt::add_arg(oss, *locks_buf->cu());
  rt::add_arg(oss, num_lock);

  rt::function function(src::matmul, opt, device);

  int max_width = 49152;

  auto grid = [max_width, width, AS0](const rt::options_t& x) {
    return rt::grid_t{(size_t)x.D<int>("TZ"),
                (size_t)std::min(max_width, width),
                (size_t)AS0};
  };

  function((void**)oss.str().data(), oss.str().size(), grid, stream);    // prints function::operator() debugging outputs 
  stream->read(&*dc, true, 0, hc);  // ERROR HAPPENS HERE
}

int main() {
  auto context = triton::driver::backend::contexts::get_default();
  triton::driver::stream* stream = triton::driver::stream::create(context->backend());

  compute(context, stream, false, true, false, 16, 0);
}

The printed lines that come before the error are from debug statements I added to the Triton code (e.g. function.cc) to check the flow :)

function constructor - called
  init_kernels - called
    kernel constructor - called
      kernel::init_ir - getting macros: [NAME] sdd_kernel
      kernel::init_ir - getting macros: [TZ] 1
      kernel::init_ir - getting macros: [STRIDE_CM] ldc
      kernel::init_ir - getting macros: [SDD] True
      kernel::init_ir - getting macros: [STRIDE_CN] 1
      kernel::init_ir - getting macros: [STRIDE_BN] ldb
      kernel::init_ir - getting macros: [STRIDE_BK] 1
      kernel::init_ir - getting macros: [STRIDE_AK] 1
      kernel::init_ir - getting macros: [STRIDE_AM] lda
      kernel::init_ir - getting macros: [TMN] 4096
      kernel::init_ir - getting macros: [BLOCK] 16
      kernel::init_ir - getting macros: [TN] 64
      kernel::init_ir - getting macros: [TYPE] half
      kernel::init_ir - getting macros: [TM] 64
      kernel::init_ir - getting macros: [TK] 32
      init_ker - called
         allocated size: 16384 
         max_shared_memory: 49152
      init_ker - returning
      init_sig - called
      init_sig - returning
    kernel constructor - returning
  init_kernels - returning
function constructor - returning
function::operator() - called
function::operator() - autotune finished
function::operator() - kernel is successfully called. now returning...
terminate called after throwing an instance of 'triton::driver::exception::cuda::illegal_address'
  what():  CUDA: Error- illegal address
[1]    19572 abort (core dumped)  ./sdd_matmul

Thank you! :)


jinsolp commented on July 22, 2024

After writing the issue above, I further examined the code of the Triton-PyTorch binding. I am guessing the illegal memory access happens because of the stride arguments (stride_za, stride_zb, stride_zc, stride_ha, stride_hb, stride_hc).
PyTorch allocates memory with its own logic when creating a new tensor, so presumably the memory layout differs for tensors declared with the TF framework. Thus, by merely reusing the stride values taken from PyTorch, I believe the kernel may end up accessing illegal memory.

Therefore, I tried to understand what the parameters (lda, ldb, ldc, stride_za, stride_zb, stride_zc, stride_ha, stride_hb, stride_hc, ...) passed to the [matmul kernel] actually do inside the kernel code, but I am having trouble 😭
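
So far my best guess (a rough, unverified sketch pieced together from the values the PyTorch binding forwards via a.stride(n) / b.stride(n) / c.stride(n)) is that they are all element strides:

// My guess (unverified): the parameters are element strides.
//   lda       = a.stride(2): elements between consecutive rows of A
//   stride_ha = a.stride(1): elements between consecutive heads of A
//   stride_za = a.stride(0): elements between consecutive batch items of A
//   ldb / stride_hb / stride_zb play the same roles for B; the C-side arguments
//   seem to be handled differently in the SDD path (ldc is passed as block above).
// For a plain contiguous [Z, H, M, K] tensor these would be:
int H = 1, M = 128, K = 256;   // shape [Z, H, M, K] = [1, 1, 128, 256]
int lda       = K;             // row stride
int stride_ha = M * K;         // head stride
int stride_za = H * M * K;     // batch stride
// but for the permuted/sliced views used in attention the values differ, which is
// why reusing PyTorch's numbers with a differently laid-out TF tensor could make
// the kernel read outside the buffer.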

It would be so great if you could give me a hand in understanding the high-level roles of the aforementioned parameters.
Thank you in advance 😄


ptillet commented on July 22, 2024

I'm sorry, I forgot to follow up on this. Triton switched to a Python-only API that is now definitely incompatible with TensorFlow. Support for TensorFlow will be added as soon as it supports defining custom ops in Python -- if that ever happens.

Thanks!

