dfm / extending-jax
Extending JAX with custom C++ and CUDA code
License: MIT License
First of all, thank you very much for this. You have saved me tons of work and I am very grateful for the great documentation.
I would like to extend JAX with custom calls that internally make use of cudnn. For this I added an include at the top of "kernels.cc.cu". I tried both of the following:
#include <cudnn.h>
#include "/usr/include/cudnn.h"
The compiler finds the header and does not complain when I add the following host code:
cudnnHandle_t handle_;
cudnnCreate(&handle_);
However, as soon as I try to run the code from JAX, I get an error saying that cudnnCreate is an undefined symbol. If I remove the includes, then the compiler complains.
Do you have any idea how I could potentially fix this?
I love the tutorial, thanks for putting together this great resource!
You write:
We're lucky in this case, and we don't need to add a "transpose rule", since JAX can actually work that out by itself (our JVP is linear in the tangents).
Every well behaved JVP is linear in the tangents, by definition -- the tangents are the "vector" part of "Jacobian-vector product."
What's special about this primitive (and what means that we don't need the transpose rule) is that it is non-linear. That means that it can't appear in the tangent calculation (because again, output tangents are always a linear function of input tangents, per the chain rule), and only things that appear in the tangent calculation need to be transposed to calculate cotangents (VJPs).
So the key question is actually whether a primitive is a linear function of one or more of its arguments. If so, then yes, you need a transpose rule to support reverse mode autodiff.
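The argument above can be checked numerically with a small plain-Python sketch. This is not JAX code: `primitive`, `primitive_jvp`, and `primitive_vjp` are hypothetical names, and sine stands in for a generic non-linear primitive. The point is that the non-linear function is only ever applied to the primal input, while the tangent is touched only by a multiplication, which is the sole operation that would need a transpose.

```python
import math

# Hypothetical stand-in for a non-linear primitive (here: sine).
def primitive(x):
    return math.sin(x)

# JVP rule: the primitive is evaluated on the primal input only.
# The tangent t is touched solely by a multiplication, which is linear in t.
def primitive_jvp(x, t):
    return primitive(x), math.cos(x) * t

# Transposing the tangent map t -> cos(x) * t only requires the transpose
# of "multiply by a constant", i.e. ct -> cos(x) * ct.
def primitive_vjp(x, ct):
    return math.cos(x) * ct

x, t1, t2, ct = 0.3, 1.5, -2.0, 0.7

# The output tangent is a linear function of the input tangent.
_, lhs = primitive_jvp(x, 2.0 * t1 + t2)
rhs = 2.0 * primitive_jvp(x, t1)[1] + primitive_jvp(x, t2)[1]
jvp_is_linear = math.isclose(lhs, rhs)

# Defining property of a transpose: <ct, J t> == <J^T ct, t>.
transpose_matches = math.isclose(
    ct * primitive_jvp(x, t1)[1],
    primitive_vjp(x, ct) * t1,
)
```

In this sketch the sine never needs a transpose of its own, because it never acts on a tangent. A primitive that is itself linear in an argument could appear in the tangent calculation, which is the case where a transpose rule is needed.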
@dfm Firstly, thank you for the write-up which has been very helpful to read through!
I've noticed that the code will no longer run on jax versions > 0.4.14 (currently 0.4.19). Having worked through the issues (which are all fairly minor) I would be happy to submit a PR to update the repo.
Briefly, these consist of:
register_cpu_custom_call_target function
custom_call function
ShapedArray has moved to jax.core for import
Hey, I am trying to add a custom call and define the XLA translation rule following this doc: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html#xla-compilation-rules
However, it leaves out the custom call part, so I tried to implement that part following your example code.
You use functools.partial(translation, platform="cpu") here, but I get the error: func_wrapper() got an unexpected keyword argument 'platform'
https://github.com/dfm/extending-jax/blob/main/src/kepler_jax/kepler_jax.py#L190-L193
Could you please give me some suggestions?
P.S. I use Cython to implement the C++ XLA custom call function, so the only remaining part is registering the XLA translation rule.
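The error message itself can be reproduced without JAX. The sketch below is only an illustration of how functools.partial behaves, not a diagnosis of the code in question: `translation` and `func_wrapper` are hypothetical stand-ins, and the assumption is that the callable being wrapped does not declare a `platform` parameter.

```python
import functools

# Hypothetical translation rule that declares a `platform` keyword.
def translation(ctx, x, *, platform="cpu"):
    return f"{platform}_kepler({x})"

# Hypothetical wrapper that, like the `func_wrapper` named in the error
# message, does not declare a `platform` parameter.
def func_wrapper(ctx, x):
    return f"kepler({x})"

# partial() only stores the keyword; it is forwarded at call time.
ok = functools.partial(translation, platform="cpu")
bad = functools.partial(func_wrapper, platform="cpu")

result = ok(None, 1.0)

try:
    bad(None, 1.0)
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

Under that assumption, one thing worth checking is whether the function passed to functools.partial is the translation rule itself or something that has already been wrapped by a decorator with a narrower signature.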
Hey @dfm :-) I had to do a bit of hacking this week, read a lot of XLA source code, and played with the new MLIR approach for specifying translation rules for custom ops in JAX. My understanding is that since April, all builtin LAX primitives have been transferred to MLIR equivalents, and that the old-style CustomCallWithLayout just remains for backward compatibility. Here is an example of what the custom calls look like in current jax:
https://github.com/google/jax/blob/f697b8e0876f8e1144a53ace02ee6d7eaa43fa14/jaxlib/gpu_solver.py#L66
Before the knowledge of how to make these things work leaves my short term memory, would you be interested in something like a PR to this post? If you prefer these posts to stay static, no worries, I can write down that info elsewhere, linking to your post for extended context ;-)
Dear @dfm, your tutorial is excellent, but I am not familiar with C++, so I have a naive question.
You show how to receive the input values with:
#include <cstdint>  // int64_t
template <typename T>
void cpu_kepler(void *out, const void **in) {
  const std::int64_t size = *reinterpret_cast<const std::int64_t *>(in[0]);
  const T *mean_anom = reinterpret_cast<const T *>(in[1]);
  const T *ecc = reinterpret_cast<const T *>(in[2]);
}
However, if one of my input values is a matrix, how can I reinterpret_cast it?
Thanks.