Could the connections be added via Requires and then upstreamed?
from abstractdifferentiation.jl.
We want AD package developers to depend on AbstractDifferentiation and define one primitive operation (and, if needed, `primal_value`). Currently, no lower-level AD package implements this yet, so for now you also need to add the primitive operation yourself. That is what is missing in your attempt above.
```julia
using AbstractDifferentiation
import Zygote
import ForwardDiff

# test function
foo(x) = sin(x[1]) + prod(x[2:end].^2)
x = rand(4)

# direct usage works
Zygote.gradient(foo, x)[1]
ForwardDiff.gradient(foo, x)

# correct way to use a backend: forward mode via a pushforward primitive
struct ForwardDiffBackend1 <: AD.AbstractForwardMode end
const forwarddiff_backend1 = ForwardDiffBackend1()
AD.@primitive function pushforward_function(ab::ForwardDiffBackend1, f, xs...)
    # jvp = f'(x)*v, i.e., differentiate f(x + h*v) wrt h at 0
    return function (vs)
        if xs isa Tuple
            @assert length(xs) <= 2
            if length(xs) == 1
                (ForwardDiff.derivative(h -> f(xs[1] + h*vs[1]), 0),)
            else
                ForwardDiff.derivative(h -> f(xs[1] + h*vs[1], xs[2] + h*vs[2]), 0)
            end
        else
            ForwardDiff.derivative(h -> f(xs + h*vs), 0)
        end
    end
end
# Optional: how this backend recovers the plain (primal) value of f(xs...)
AD.primal_value(::ForwardDiffBackend1, ::Any, f, xs) = ForwardDiff.value.(f(xs...))

# reverse mode via a pullback primitive
struct ZygoteBackend1 <: AD.AbstractReverseMode end
const zygote_backend1 = ZygoteBackend1()
AD.@primitive function pullback_function(ab::ZygoteBackend1, f, xs...)
    return function (vs)
        # Supports only single output
        _, back = Zygote.pullback(f, xs...)
        if vs isa AbstractVector
            back(vs)
        else
            @assert length(vs) == 1
            back(vs[1])
        end
    end
end

# gradients computed through the two custom backends
AD.gradient(zygote_backend1, foo, x)
AD.gradient(forwarddiff_backend1, foo, x)
```
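To see why a single primitive is enough, here is a package-free sketch of how derived operations can be assembled from a pushforward alone. It is not how AbstractDifferentiation is implemented internally: `fd_pushforward`, `basis`, `derived_gradient` and `derived_jacobian` are hypothetical names, and the primitive uses central finite differences instead of ForwardDiff purely to keep the example free of dependencies.

```julia
# Hypothetical primitive: pushforward (Jacobian-vector product) of f at x
# along direction v. Central finite differences stand in for real AD here.
fd_pushforward(f, x, v; h = 1e-6) = (f(x .+ h .* v) .- f(x .- h .* v)) ./ (2h)

# j-th standard basis vector, with the same length and element type as x.
function basis(x::AbstractVector, j::Integer)
    e = zero(x)
    e[j] = one(eltype(x))
    return e
end

# Derived operation: gradient of a scalar-valued f, one pushforward per input.
derived_gradient(pf, f, x) = [pf(f, x, basis(x, j)) for j in eachindex(x)]

# Derived operation: Jacobian of a vector-valued f; column j is the
# pushforward of the j-th basis vector.
derived_jacobian(pf, f, x) = hcat([pf(f, x, basis(x, j)) for j in eachindex(x)]...)

# Same test function as in the thread
foo(x) = sin(x[1]) + prod(x[2:end] .^ 2)
x0 = [0.5, 1.0, 2.0, 3.0]
g = derived_gradient(fd_pushforward, foo, x0)   # ≈ [cos(0.5), 72, 36, 24]
```

A forward-mode primitive needs one call per input dimension to assemble a gradient, whereas a reverse-mode primitive like the Zygote pullback above can return the whole gradient of a scalar function from a single call seeded with `1`.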
Also, it would be nice if I could list all available backends somehow.
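Nothing in this thread provides a registry of backends. One rough workaround, assuming backends are declared as subtypes of `AD.AbstractBackend` (as the `AD.AbstractForwardMode` and `AD.AbstractReverseMode` supertypes used in this thread suggest), is to walk the type hierarchy with `subtypes` from the `InteractiveUtils` standard library. It only sees types defined in modules that are already loaded, and `all_subtypes` / `print_backend_tree` are hypothetical helper names.

```julia
# Assumes AbstractDifferentiation declares its backends as subtypes of
# AD.AbstractBackend; only types from already-loaded modules are found.
import AbstractDifferentiation as AD
using InteractiveUtils: subtypes

# Hypothetical helper: recursively collect every subtype (abstract or
# concrete) of T that is visible in the loaded modules.
function all_subtypes(T::Type)
    found = Type[]
    for S in subtypes(T)
        push!(found, S)
        append!(found, all_subtypes(S))
    end
    return found
end

# Hypothetical helper: print the backend hierarchy as an indented tree.
function print_backend_tree(T::Type = AD.AbstractBackend, indent::Int = 0)
    println(" "^indent, T)
    for S in subtypes(T)
        print_backend_tree(S, indent + 2)
    end
end

print_backend_tree()
```

User-defined backends, such as `ForwardDiffBackend1` above, should appear in the tree as well once their definitions have been evaluated.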
Absolutely! The next steps include adding primitives for more AD packages as test cases, to check that all the macros etc. in AbstractDifferentiation work as intended. Then AD package developers could go ahead and add these primitives to their own packages (as is done above). Afterwards, end-users would only have to call `const zygote_backend = ZygoteBackend()`, with the primitives defined in Zygote.jl itself, to get all the listed "derivative operations", such as the Jacobian, from the API of AbstractDifferentiation.
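On a newer release of AbstractDifferentiation than the one discussed here, that end-user workflow looks roughly like the sketch below. It assumes a version that provides a built-in `AD.ZygoteBackend` once Zygote is loaded; `g` is a made-up test function.

```julia
# Assumes a newer AbstractDifferentiation release that ships AD.ZygoteBackend;
# Zygote has to be loaded for that backend to work.
import AbstractDifferentiation as AD
import Zygote

const zygote_backend = AD.ZygoteBackend()

# Made-up vector-valued test function R^2 -> R^2
g(x) = [sin(x[1]), x[1] * x[2]]
x0 = [1.0, 2.0]

# Each operation returns a tuple with one result per input argument,
# hence the trailing [1].
J = AD.jacobian(zygote_backend, g, x0)[1]                       # 2×2 matrix
grad = AD.gradient(zygote_backend, x -> sum(abs2, x), x0)[1]   # equals 2 .* x0
```

Switching to another package would only change the backend constructor; the `AD.jacobian` and `AD.gradient` calls stay the same.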