Comments (4)
Hey! Thanks for your interest :)
Using Triton for TF ops as you did should be totally appropriate. Does your code work with dense inputs by any chance? I would also use a head dimension of 1 for A and B while debugging (i.e., tf.random.uniform([1, 1, 128, 64], dtype=tf.float16)). The look-up table and super-blocking logic can be quite complicated for SDD, so my first intuition is that there is some kind of issue there. Did you copy it from one generated with torch for the same layout?
Thank you so much for the quick reply!!! 😊
I've been referencing how DeepSpeed uses this kernel, and have matched all inputs and layouts exactly.
Following your advice, I used the following configuration:
SEQ_LENGTH=128
NUM_ATTENTION_HEADS=1
BATCH_SIZE=1
HIDDEN_SIZE=256
BLOCK_SIZE=16
along with the sparse attention "dense" mode, which uses the [matmul kernel].
However, I still have the same problem when calling stream->read.
The following is the code I used and the error message that appears:
void compute(drv::context* context,
             drv::stream* stream,
             bool trans_a,    // false
             bool trans_b,    // true
             bool trans_c,    // false
             int block,
             int num_blocks)
{
  size_t dt_nbytes = sizeof(half_float::half);
  drv::device* device = context->device();
  std::vector<half_float::half> ha(1 * 1 * 128 * 256);
  std::vector<half_float::half> hb(1 * 1 * 128 * 256);
  for(size_t i = 0; i < ha.size(); i++) {
    ha[i] = (half_float::half)rand()/RAND_MAX;
  }
  for(size_t i = 0; i < hb.size(); i++) {
    hb[i] = (half_float::half)rand()/RAND_MAX;
  }
  // COPIED FROM PYTORCH TENSOR DATA
  std::vector<int> lut = {
    0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 2, 2, 0, 0, 3, 3, 0, 1,
    0, 8, 0, 1, 1, 9, 0, 1, 2, 10, 0, 1, 3, 11, 0, 2, 0, 16,
    0, 2, 1, 17, 0, 2, 2, 18, 0, 2, 3, 19, 0, 3, 0, 24, 0, 3,
    1, 25, 0, 3, 2, 26, 0, 3, 3, 27, 0, 0, 4, 4, 0, 0, 5, 5,
    0, 0, 6, 6, 0, 0, 7, 7, 0, 1, 4, 12, 0, 1, 5, 13, 0, 1,
    6, 14, 0, 1, 7, 15, 0, 2, 4, 20, 0, 2, 5, 21, 0, 2, 6, 22,
    0, 2, 7, 23, 0, 3, 4, 28, 0, 3, 5, 29, 0, 3, 6, 30, 0, 3,
    7, 31, 0, 4, 0, 32, 0, 4, 1, 33, 0, 4, 2, 34, 0, 4, 3, 35,
    0, 5, 0, 40, 0, 5, 1, 41, 0, 5, 2, 42, 0, 5, 3, 43, 0, 6,
    0, 48, 0, 6, 1, 49, 0, 6, 2, 50, 0, 6, 3, 51, 0, 7, 0, 56,
    0, 7, 1, 57, 0, 7, 2, 58, 0, 7, 3, 59, 0, 4, 4, 36, 0, 4,
    5, 37, 0, 4, 6, 38, 0, 4, 7, 39, 0, 5, 4, 44, 0, 5, 5, 45,
    0, 5, 6, 46, 0, 5, 7, 47, 0, 6, 4, 52, 0, 6, 5, 53, 0, 6,
    6, 54, 0, 6, 7, 55, 0, 7, 4, 60, 0, 7, 5, 61, 0, 7, 6, 62,
    0, 7, 7, 63
  };
  std::vector<float> widths = {4};
  std::vector<float> spdims = {4, 8, 8};
  std::vector<float> packs = {4};
  int AS0 = 1; int AS1 = 1; int AS2 = 128; int AS3 = 256;
  int BS0 = 1; int BS1 = 1; int BS2 = 256; int BS3 = 128;
  bool is_16_multiple = (AS3 % 16 == 0);
  bool is_32_multiple = (AS3 % 32 == 0);
  bool is_64_multiple = (AS3 % 64 == 0);
  int total_width = 0;
  int num_elements = widths.size();
  for(int i = 0; i < num_elements; i++) {
    total_width += widths[i] * packs[i] * packs[i];
  }
  std::vector<half_float::half> hc(AS0 * total_width * block * block);
  auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, AS0*AS1*AS2*AS3*dt_nbytes));
  auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, BS0*BS1*BS2*BS3*dt_nbytes));
  auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, AS0*total_width*block*block*sizeof(int)));
  stream->write(&*da, true, 0, ha);
  stream->write(&*db, true, 0, hb);
  int pack = packs[0];
  rt::options_t opt;
  opt.defines["TM"] = std::to_string(block * pack);
  opt.defines["TN"] = std::to_string(block * pack);
  opt.defines["TMN"] = std::to_string(block * block * pack * pack);
  opt.defines["BLOCK"] = std::to_string(block);
  opt.defines["TK"] = "32";
  opt.defines["TYPE"] = "half";
  opt.defines["STRIDE_AM"] = "lda"; // trans_a ? "1" : "lda"; // TODO:
  opt.defines["STRIDE_AK"] = "1";   // trans_a ? "lda" : "1";
  opt.defines["STRIDE_BN"] = "ldb"; // trans_b ? "ldb" : "1";
  opt.defines["STRIDE_BK"] = "1";   // trans_b ? "1" : "ldb";
  opt.defines["STRIDE_CM"] = "ldc";
  opt.defines["STRIDE_CN"] = "1";
  opt.defines["SDD"] = "True";
  opt.defines["TZ"] = "1";
  opt.defines["NAME"] = "sdd_kernel";
  opt.num_warps = 4;
  int num_lock = 1;
  int width = widths[0];
  std::vector<int> locks(2 * width * AS0 * num_lock, 0);
  auto locks_buf = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 2*width*AS0*num_lock*sizeof(int)));
  stream->write(&*locks_buf, true, 0, locks);
  auto lut_buf = std::shared_ptr<drv::buffer>(drv::buffer::create(context, lut.size()*sizeof(int)));
  stream->write(&*lut_buf, true, 0, lut);
  std::stringstream oss;
  rt::add_arg(oss, *da->cu());
  rt::add_arg(oss, *db->cu());
  rt::add_arg(oss, *dc->cu());
  rt::add_arg(oss, 64);     // a.stride(2)
  rt::add_arg(oss, 64);     // b.stride(2)
  rt::add_arg(oss, block);
  rt::add_arg(oss, 98304);  // a.stride(0)
  rt::add_arg(oss, 98304);  // b.stride(0)
  rt::add_arg(oss, 16384);  // c.stride(0)
  rt::add_arg(oss, 256);    // a.stride(1)
  rt::add_arg(oss, 256);    // b.stride(1)
  rt::add_arg(oss, 16384);  // c.stride(0)
  rt::add_arg(oss, AS2);
  rt::add_arg(oss, AS2);
  rt::add_arg(oss, AS3);
  rt::add_arg(oss, 0);
  rt::add_arg(oss, *lut_buf->cu());
  rt::add_arg(oss, *locks_buf->cu());
  rt::add_arg(oss, num_lock);
  rt::function function(src::matmul, opt, device);
  int max_width = 49152;
  auto grid = [max_width, width, AS0](const rt::options_t& x) {
    return rt::grid_t{(size_t)x.D<int>("TZ"),
                      (size_t)std::min(max_width, width),
                      (size_t)AS0};
  };
  function((void**)oss.str().data(), oss.str().size(), grid, stream); // prints function::operator() debugging outputs
  stream->read(&*dc, true, 0, hc); // ERROR HAPPENS HERE
}

int main() {
  auto context = triton::driver::backend::contexts::get_default();
  triton::driver::stream* stream = triton::driver::stream::create(context->backend());
  compute(context, stream, false, true, false, 16, 0);
}
The printed lines that come before the error were added to the Triton code (e.g. function.cc) to check the flow :)
function constructor - called
init_kernels - called
kernel constructor - called
kernel::init_ir - getting macros: [NAME] sdd_kernel
kernel::init_ir - getting macros: [TZ] 1
kernel::init_ir - getting macros: [STRIDE_CM] ldc
kernel::init_ir - getting macros: [SDD] True
kernel::init_ir - getting macros: [STRIDE_CN] 1
kernel::init_ir - getting macros: [STRIDE_BN] ldb
kernel::init_ir - getting macros: [STRIDE_BK] 1
kernel::init_ir - getting macros: [STRIDE_AK] 1
kernel::init_ir - getting macros: [STRIDE_AM] lda
kernel::init_ir - getting macros: [TMN] 4096
kernel::init_ir - getting macros: [BLOCK] 16
kernel::init_ir - getting macros: [TN] 64
kernel::init_ir - getting macros: [TYPE] half
kernel::init_ir - getting macros: [TM] 64
kernel::init_ir - getting macros: [TK] 32
init_ker - called
allocated size: 16384
max_shared_memory: 49152
init_ker - returning
init_sig - called
init_sig - returning
kernel constructor - returning
init_kernels - returning
function constructor - returning
function::operator() - called
function::operator() - autotune finished
function::operator() - kernel is successfully called. now returning...
terminate called after throwing an instance of 'triton::driver::exception::cuda::illegal_address'
what(): CUDA: Error- illegal address
[1] 19572 abort (core dumped) ./sdd_matmul
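A side note on where the error is reported: since CUDA surfaces illegal-address errors asynchronously, the fault could actually originate in the kernel itself and only show up at the blocking read. A minimal check I plan to try, assuming drv::stream exposes a synchronize() that wraps cuStreamSynchronize:

function((void**)oss.str().data(), oss.str().size(), grid, stream);
stream->synchronize();           // if this throws, the kernel access itself is out of bounds
stream->read(&*dc, true, 0, hc); // otherwise the fault is in the read-back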
Thank you! :)
After writing the issue above, I further examined the code of the Triton-PyTorch binding. I am guessing the illegal memory access happens because of the stride arguments (stride_za, stride_zb, stride_zc, stride_ha, stride_hb, stride_hc).
PyTorch allocates memory for a new tensor with its own logic, so presumably the allocated memory layout differs from that of tensors declared through the TF framework. Thus, by merely reusing the stride values taken from the PyTorch run, I believe the kernel may end up accessing illegal memory.
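For example, if the TF op hands the kernel a plain contiguous [Z, H, M, K] buffer, I would expect the stride arguments to come straight from the shape. This is only my working assumption (the helper and names below are mine, not from Triton):

// Sketch (my assumption): element strides of a contiguous row-major [Z, H, M, K] tensor.
// For A = [1, 1, 128, 256] this gives lda = 256 and stride_ha = stride_za = 32768,
// which is quite different from the values I copied from the PyTorch run above.
struct strides_t { int ld, stride_h, stride_z; };
strides_t contiguous_strides(int Z, int H, int M, int K) {
  strides_t s;
  s.ld       = K;          // elements between consecutive rows (leading dimension)
  s.stride_h = M * K;      // elements between consecutive heads
  s.stride_z = H * M * K;  // elements between consecutive batch items
  return s;
}

If that reading is right, passing shape-derived values instead of the copied ones should at least rule out a stride mismatch.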
Therefore, I tried to understand what the parameters (lda, ldb, ldc, stride_za, stride_zb, stride_zc, stride_ha, stride_hb, stride_hc, ...) passed to the [matmul kernel] actually do inside the kernel code, but I am having trouble 😭
It would be so great if you could give me a hand in understanding the high-level roles of the aforementioned parameters.
Thank you in advance 😄
I'm sorry I forgot to follow up on this. Triton has switched to a Python-only API that is now definitely incompatible with TensorFlow. Support for TensorFlow will be added as soon as it supports defining custom ops in Python -- if that ever happens.
Thanks!