Comments (8)
That would actually be amazing. At the moment though, I have very little knowledge on how to auto-diff C-like code. Are you aware of any project that does something similar that I could take a look at out of curiosity?
Edit: I have found clad: https://github.com/vgvassilev/clad/. I think that such an autodiff for Triton would be outside the scope of Triton -- much like clad is outside of clang. Ultimately, Triton only intends to be a lightweight replacement for CUDA when writing DNN ops.
from triton.
As far as I know, this project can do autodiff on LLVM IR: https://github.com/wsmoses/Enzyme.
@ptillet Does it help?
Yes, this helps. However, I feel like this is a bit out of the scope of Triton, which aims just to be a replacement for CUDA. My intuition is that automatic differentiation of a CUDA kernel could be hard in the general case, i.e., when atomics are used or, more generally, when different kernel instances touch the same memory location. For example, it is not always the case that a single forward-propagation kernel can be auto-differentiated into a single kernel.
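The asymmetry can be seen in a tiny example outside any GPU framework. Below is a plain-Python sketch (a hypothetical illustration, not Triton or CUDA code) of a scatter-add forward pass whose adjoint is a gather: the backward has a different memory-access pattern than the forward, which is why one forward kernel does not in general map onto one backward kernel.

```python
# Plain-Python sketch (hypothetical example, not Triton/CUDA code).
# forward: y[idx[i]] += x[i]  -- on a GPU this write needs atomics,
#          since several program instances may hit the same y slot.
# adjoint: dx[i] = dy[idx[i]] -- a plain gather, no atomics needed.

def scatter_add_forward(x, idx, out_size):
    y = [0.0] * out_size
    for i in range(len(x)):        # each iteration ~ one kernel instance
        y[idx[i]] += x[i]          # racing writes -> atomic add on a GPU
    return y

def scatter_add_backward(dy, idx):
    # each x[i] contributes with coefficient 1 to y[idx[i]],
    # so gradients flow back as a gather:
    return [dy[idx[i]] for i in range(len(idx))]

x = [1.0, 2.0, 3.0, 4.0]
idx = [0, 1, 0, 1]                 # overlapping writes
y = scatter_add_forward(x, idx, 2)
assert y == [4.0, 6.0]

# check dx against finite differences of loss = sum(w * y)
w = [0.5, -1.0]
dx = scatter_add_backward(w, idx)  # dloss/dy = w

eps = 1e-6
loss = sum(wj * yj for wj, yj in zip(w, y))
for i in range(len(x)):
    xp = x[:]
    xp[i] += eps
    yp = scatter_add_forward(xp, idx, 2)
    num = (sum(wj * yj for wj, yj in zip(w, yp)) - loss) / eps
    assert abs(num - dx[i]) < 1e-4
```

The forward loop parallelizes over inputs and races on outputs; the backward loop parallelizes over inputs with read-only access to `dy`. A single kernel generally cannot serve both roles.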
https://dspace.mit.edu/handle/1721.1/122623 is a paper which does this for GPUs.
@ptillet Enzyme maintainer here. We actually can do GPU these days, and I have been intrigued to see whether Enzyme can handle Triton without adding new custom primitives for the adjoint generation. From having read the paper and sifted through the code, Triton essentially generates a reduced subset of LLVM IR?
Is there a switch inside of Triton to inspect the generated IR? Would be happy to give a prototype a go, but agree that it wouldn't necessarily make sense to really pull automatic differentiation inside of Triton itself. Would just be curious to see if Enzyme works on Triton-generated kernels.
@ludgerpaehler Thanks for reaching out :) I think my main worry isn't so much whether it can be done, but how efficiently it can be done. I've thought about the problem a bit, and I don't think something like matmul could be auto-differentiated into an optimal backward pass -- let alone flash attention. There would likely still be some utility to having a slow-ish backward pass though, as long as it's not 5x slower.
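For concreteness, the hand-derived backward of C = A @ B with respect to a scalar loss is itself two matmuls, dA = dC · Bᵀ and dB = Aᵀ · dC, which is what lets a manual backward pass reuse the forward's tiling and tensor-core strategy; an IR-level autodiff has no guarantee of recovering that structure. A plain-Python sketch (hypothetical, not Triton code) checking the dA formula against finite differences for the loss sum(C):

```python
# Hand-derived backward of C = A @ B for loss L = sum(C):
#   dA = dC @ B^T,  dB = A^T @ dC,  with dC all-ones for this loss.
# These are themselves matmuls, so a hand-written backward kernel can
# use the same tiling strategy as the forward.

def matmul(A, B):
    n, k, m = len(A), len(B), len(B[0])
    return [[sum(A[i][p] * B[p][j] for p in range(k)) for j in range(m)]
            for i in range(n)]

def transpose(M):
    return [list(col) for col in zip(*M)]

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
C = matmul(A, B)
loss = sum(sum(row) for row in C)

dC = [[1.0] * len(C[0]) for _ in C]   # dL/dC = 1 everywhere
dA = matmul(dC, transpose(B))         # analytic gradient of A

# finite-difference check of dA
eps = 1e-6
for i in range(2):
    for j in range(2):
        Ap = [row[:] for row in A]
        Ap[i][j] += eps
        lp = sum(sum(r) for r in matmul(Ap, B))
        assert abs((lp - loss) / eps - dA[i][j]) < 1e-3
```

The open question in the thread is whether a tool differentiating the lowered IR can rediscover that the adjoint is again a matmul, rather than emitting scalar loads, stores, and atomics.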
Triton does eventually generate LLVM IR (and then NVPTX), but I am not sure about the performance of any approach that tries to auto-differentiate that code, as the compiler has no way of knowing whether some shared-memory (or global-memory) values are used by different threads. One thing that could potentially work would be to auto-diff Triton-IR directly, but even there the existence of pointers in an SPMD model makes me think we'd have to at least double the amount of memory I/O (i.e., in the backward pass, pointer loads become atomic adds?), although we may retain optimal shared-memory / tensor-core utilization. That could still have good enough performance to be useful.
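The loads-become-atomic-adds point can be made concrete with a plain-Python sketch (hypothetical, not Triton code): when several program instances load the same address in the forward pass, their gradient contributions must be accumulated in the backward pass, and on a GPU that accumulation is an atomic add. A plain store would silently drop contributions:

```python
# Forward: each program instance i reads y[i] = x[idx[i]].
# Adjoint: dx[idx[i]] must ACCUMULATE dy[i]; if two instances share an
# index and each merely stores its gradient, one contribution is lost.
# On a GPU the accumulation therefore becomes an atomic add.

def gather_forward(x, idx):
    return [x[j] for j in idx]

def gather_backward_correct(dy, idx, in_size):
    dx = [0.0] * in_size
    for i, j in enumerate(idx):
        dx[j] += dy[i]            # atomic_add on a GPU
    return dx

def gather_backward_racy(dy, idx, in_size):
    dx = [0.0] * in_size
    for i, j in enumerate(idx):
        dx[j] = dy[i]             # plain store: last writer wins
    return dx

idx = [0, 0, 1]                   # two instances read x[0]
dy = [1.0, 2.0, 3.0]
assert gather_backward_correct(dy, idx, 2) == [3.0, 3.0]
assert gather_backward_racy(dy, idx, 2) == [2.0, 3.0]   # wrong gradient
```

This is the extra memory traffic mentioned above: every forward load whose address may be shared between instances turns into a read-modify-write in the backward pass.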
I think long term, what would be absolutely bonkers would be something like Enzyme capable of operating on MLIR dialects (provided some interface, of course). Is this an idea that has been thrown around?
@ptillet I don't want to say too much and overpromise, but it is something we are looking at very closely and something which we are very keen to pursue. Are you going to be at the US LLVM Dev meeting in San Jose? Would be happy to discuss more in-depth in person (or via mail) :)
Sorry for the delay. Things have been very busy :) I won't be at the LLVM Dev meeting as it is around the time I'm traveling to France to see my family. Happy to talk more about it via mail or on a call; feel free to shoot me an email at [email protected]