extending-jax's Issues

How would I include cuDNN?

First of all, thank you very much for this. You have saved me tons of work and I am very grateful for the great documentation.

I would like to extend JAX with custom calls that internally make use of cuDNN. For this I added an include at the top of "kernels.cc.cu". I tried both of the following:

#include <cudnn.h>
#include "/usr/include/cudnn.h"

The compiler finds the header and does not complain when I add the following host code:

  cudnnHandle_t handle_;
  cudnnCreate(&handle_);

However, as soon as I try to run the code from JAX, I get an error saying that cudnnCreate is an undefined symbol. If I remove the includes, the compiler complains.

Do you have any idea how I could potentially fix this?

nit on the need for transpose rules

I love the tutorial, thanks for putting together this great resource!

You write:

We're lucky in this case, and we don't need to add a "transpose rule", since JAX can actually work that out by itself (our JVP is linear in the tangents).

Every well-behaved JVP is linear in the tangents, by definition -- the tangents are the "vector" part of "Jacobian-vector product."

What's special about this primitive (and what means that we don't need the transpose rule) is that it is non-linear. That means that it can't appear in the tangent calculation (because again, output tangents are always a linear function of input tangents, per the chain rule), and only things that appear in the tangent calculation need to be transposed to calculate cotangents (VJPs).

So the key question is actually whether a primitive is a linear function of one or more of its arguments. If so, then yes, you need a transpose rule to support reverse mode autodiff.
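
The argument above can be written out as a short worked example. This is only a sketch: the functions sin and x -> A x are illustrative stand-ins chosen here, not the primitive from the tutorial.

```latex
% For a differentiable f : R^n -> R^m with Jacobian J_f(x):
\begin{align*}
  \text{JVP:}\quad & (x, v) \mapsto J_f(x)\, v
    && \text{linear in the tangent } v \text{ for every } f \\
  \text{VJP:}\quad & (x, w) \mapsto J_f(x)^{\top} w
    && \text{the transpose of the linear map } v \mapsto J_f(x)\, v
\end{align*}

% Non-linear example: f(x) = \sin x
\begin{align*}
  \dot{y} &= \cos(x)\, v
    && \text{only a multiplication acts on } v;\ \sin \text{ and } \cos \text{ see only the primal } x
\end{align*}

% Linear example: g(x) = A x
\begin{align*}
  \dot{y} &= A\, v = g(v)
    && g \text{ itself acts on the tangent, so transposing needs } w \mapsto A^{\top} w
\end{align*}
```

In the first example the only operation applied to the tangent is a multiplication, which already has a transpose rule. In the second, the primitive itself is applied to the tangent, so reverse mode needs its transpose.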

Updates to support jax>0.4.14

@dfm Firstly, thank you for the write-up which has been very helpful to read through!

I've noticed that the code will no longer run on jax versions > 0.4.14 (currently 0.4.19). Having worked through the issues (which are all fairly minor) I would be happy to submit a PR to update the repo.

Briefly, these consist of:

  • deprecated register_cpu_custom_call_target function
  • changes to the API and return values of the jax helper custom_call function
  • ShapedArray has moved to jax.core for import
  • I would also like to tag the jax version installed as the colab notebook does not currently run (and the jax api seems to be continuing to evolve)

XLA register translation rule fail

Hey, I am trying to add a custom call and define the XLA translation rule by following this doc: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html#xla-compilation-rules

However, it is missing the custom call part, so I tried to implement that part by following your example code.

You use functools.partial(translation, platform="cpu") here; however, I get the error: func_wrapper() got an unexpected keyword argument 'platform'
https://github.com/dfm/extending-jax/blob/main/src/kepler_jax/kepler_jax.py#L190-L193

Could you please give some suggestions?

P.S. I use Cython to implement the C++ XLA custom call function, and the only remaining part is registering the XLA translation rule.

Interested in help updating these instructions to the new style of XLA translation rules?

Hey @dfm :-) I had to do a bit of hacking this week, read a lot of XLA source code, and played with the new MLIR approach for specifying translation rules for custom ops in JAX. My understanding is that since April, all builtin LAX primitives have been transferred to MLIR equivalents, and that the old style CustomCallWithLayout just remains for backward compatibility. Here is an example of what the custom calls look like in current jax:
https://github.com/google/jax/blob/f697b8e0876f8e1144a53ace02ee6d7eaa43fa14/jaxlib/gpu_solver.py#L66

Before the knowledge of how to make these things work leaves my short term memory, would you be interested in something like a PR to this post? If you prefer these posts to stay static, no worries, I can write down that info elsewhere, linking to your post for extended context ;-)

How to reinterpret_cast a matrix?

Dear @dfm, your tutorial is excellent, but I am not familiar with C++, so I have a naive question.

You showed how to receive input values with:

#include <cstdint> // int64_t

template <typename T>
void cpu_kepler(void *out, const void **in) {
  const std::int64_t size = *reinterpret_cast<const std::int64_t *>(in[0]);
  const T *mean_anom = reinterpret_cast<const T *>(in[1]);
  const T *ecc = reinterpret_cast<const T *>(in[2]);
}

However, if one of my input values is a matrix, how can I reinterpret_cast it?

Thanks.
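
One possible approach, sketched under stated assumptions: a matrix operand arrives as the same kind of flat buffer as a vector, so the cast is unchanged and only the indexing differs. The kernel name cpu_row_sum and the choice to pass rows and cols as two extra scalar operands are hypothetical, and the i * cols + j indexing assumes the operand was given a row-major layout on the Python side.

```cpp
#include <cassert>
#include <cstdint>  // int64_t

// Hypothetical kernel: writes the sum of each row of a (rows x cols) matrix.
// Assumption: in[0] and in[1] hold the dimensions as int64 scalars, and
// in[2] is the matrix stored contiguously in row-major order.
template <typename T>
void cpu_row_sum(void *out, const void **in) {
  const std::int64_t rows = *reinterpret_cast<const std::int64_t *>(in[0]);
  const std::int64_t cols = *reinterpret_cast<const std::int64_t *>(in[1]);
  assert(rows >= 0 && cols >= 0);

  // Same cast as for a 1-D input: the buffer is just rows * cols values.
  const T *matrix = reinterpret_cast<const T *>(in[2]);
  T *result = reinterpret_cast<T *>(out);

  for (std::int64_t i = 0; i < rows; ++i) {
    T total = 0;
    for (std::int64_t j = 0; j < cols; ++j) {
      total += matrix[i * cols + j];  // element (i, j) in row-major order
    }
    result[i] = total;
  }
}
```

If the operand were registered with a different layout, the index expression would need to change to match it.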
