Zygote also works. But yeah, ReverseDiff is one I should get working too, because it’s still used in a few places. I’m not sure whether ReverseDiff uses ChainRules, but I’ve also been meaning to add ChainRules support so I can cover more bases. I’ll try to get to this tonight.
from componentarrays.jl.
This is looking a little more difficult than I expected. Supposedly there are plans to migrate ReverseDiff to ChainRules, so I think it would be best to wait for that. In the meantime, a custom gradient with Zygote, like `gradsZ = Zygote.gradient(F, ca)[1].x`, is probably the way to go for things like Optim.
What's the status of this? Does it work now? Last time I checked, ReverseDiff seemed to support ChainRules, though admittedly I'm still a bit hazy on how that all works.
Well, the good news is I updated my autodiff support to use ChainRules instead of Zygote directly. The bad news is ReverseDiff still doesn't use ChainRules. I'm looking into what it will take to make this work, but the ReverseDiff docs don't give me much to go on for writing pullback rules. The only approach I can think of is overloading each derivative function (`gradient`, `gradient!`, `jacobian`, `jacobian!`, ...) for `getproperty` and `getindex` of a `ReverseDiff.TrackedArray{V,D,N,<:ComponentArray,DA} where {V,D,N,DA}`. If that's the case, I may need some help in the form of a PR, or it will be a while before I can get to it.
See #67
Yeah so I kind of need this, possibly soon. What can I do to help?
Although I don't actually need my parameter array to be a `ComponentArray`. I guess I should check whether non-tracked arrays work with ComponentArrays.
It's looking like we might be able to define `getindex` for a tracked `ComponentArray` and a `Symbol` or `Val`, taking some inspiration from here:
```julia
using ComponentArrays, ReverseDiff
using ReverseDiff: TrackedArray, value, deriv, tape, SpecialInstruction, record!

# A TrackedArray whose value is a ComponentArray (each type variable listed once)
const TrackedComponentArray{V, D, N, DA, A, Ax} = TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA}

# Index a tracked ComponentArray by component name (Symbol or Val) and record the op on the tape
function Base.getindex(t::TrackedComponentArray, ind::Union{Symbol, Val})
    tp = tape(t)
    out = TrackedArray(value(t)[ind], deriv(t)[ind], tp)
    record!(tp, SpecialInstruction, (getindex, Val(:generic)), (t, ind), out)
    return out
end
```
This is enough to get the tracking to work through the `getindex` operation (`getproperty` is probably not a good idea, because `TrackedArray` needs it to access its own fields), but actually taking derivatives, gradients, Jacobians, etc. requires reaching into the `special_reverse_exec!` code here. Or maybe just making a separate dispatch for something other than `Val(:generic)` and patching in from there. I haven't had any luck getting that working yet, though.
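The `TrackedComponentArray` alias in the snippet above works by pinning the wrapper's inner array type, so that overloaded methods only apply to tracked `ComponentArray`s and everything else falls back to the generic methods. Here is a minimal, package-free sketch of that idea. `Labeled`, `Tracked`, `TrackedLabeled`, and `describe` are hypothetical stand-ins for `ComponentArray`, `TrackedArray`, the alias, and an overloaded function. None of them are real library names.

```julia
# Hypothetical stand-ins, not real library types:
#   `Labeled` plays the role of ComponentArray (array data plus component names);
#   `Tracked` plays the role of ReverseDiff.TrackedArray (a wrapper around some value).
struct Labeled{A<:AbstractArray}
    data::A
    names::Vector{Symbol}
end

struct Tracked{VA}
    value::VA
end

# Parametric alias: a `Tracked` whose wrapped value is a `Labeled` of any array type.
# Each type variable appears exactly once in the alias's parameter list.
const TrackedLabeled{A} = Tracked{Labeled{A}}

# Fallback method for any tracked value.
describe(t::Tracked) = "generic tracked value"

# More specific method, chosen only when the wrapped value is a `Labeled`.
describe(t::TrackedLabeled) = "tracked labeled array with names $(t.value.names)"

describe(Tracked([1.0, 2.0]))                      # "generic tracked value"
describe(Tracked(Labeled([1.0, 2.0], [:a, :b])))   # "tracked labeled array with names [:a, :b]"
```

Because `TrackedLabeled` is strictly more specific than `Tracked`, Julia picks it whenever the wrapped value is a `Labeled`. No checks inside the generic method are needed.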
So without `getproperty` we would only be able to do `carray[:var]`, right?
I guess I don't really know what `special_reverse_exec!` is doing. What is `SpecialInstruction`?
@bgroenks96 So I guess I didn't need the special instruction thing after all. I tried this out and it seems to work for me, but I've only tested it on pretty simple stuff so far. Try adding this and running your code:
```julia
using ComponentArrays
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal, value, deriv, tape

# A TrackedArray whose value is a ComponentArray (each type variable listed once)
const TrackedComponentArray{V, D, N, DA, A, Ax} = TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA}

# Array-valued components stay TrackedArrays; scalar components become TrackedReals
maybe_tracked_array(val::AbstractArray, der, t) = TrackedArray(val, der, t)
maybe_tracked_array(val, der, t) = TrackedReal(val, der, t)

# ca[:x] (or ca[Val(:x)]) on a tracked ComponentArray
function Base.getindex(t::TrackedComponentArray, inds::Union{Symbol, Val}...)
    return maybe_tracked_array(value(t)[inds...], deriv(t)[inds...], tape(t))
end

# ca.x on a tracked ComponentArray; TrackedArray's own fields are passed through untouched
function Base.getproperty(t::TrackedComponentArray, s::Symbol)
    if s in (:value, :deriv, :tape)
        return getfield(t, s)
    else
        return maybe_tracked_array(getproperty(value(t), s), getproperty(deriv(t), s), tape(t))
    end
end
```
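The snippet above rests on two choices. `getproperty` forwards only names that are not the wrapper's own fields, and dispatch on whatever the component lookup returns picks the result type. Below is a minimal sketch of that pattern without ReverseDiff. It assumes hypothetical `MockTracked`/`MockTrackedReal` types that keep their components in `NamedTuple`s. These are not ReverseDiff types, and they record nothing on a tape.

```julia
# Hypothetical mocks, not ReverseDiff types; nothing is recorded on a tape.
# `MockTracked` stores named components in its `value` and `deriv` fields.
struct MockTracked{V,D}
    value::V
    deriv::D
end

# Scalar counterpart, playing the role of TrackedReal.
struct MockTrackedReal{V,D}
    value::V
    deriv::D
end

# Pick the wrapper from what the component lookup returned.
maybe_wrap(val::AbstractArray, der) = MockTracked(val, der)
maybe_wrap(val, der) = MockTrackedReal(val, der)

function Base.getproperty(t::MockTracked, s::Symbol)
    if s in (:value, :deriv)
        # The wrapper's own fields go straight to getfield; routing them through
        # the else branch would make `t.value` call this method again forever.
        return getfield(t, s)
    else
        # Any other name is treated as a component of the wrapped data.
        return maybe_wrap(getproperty(t.value, s), getproperty(t.deriv, s))
    end
end

t = MockTracked((x = 1.0, y = [2.0, 3.0]), (x = 0.0, y = [0.0, 0.0]))
t.x   # a MockTrackedReal holding 1.0 and its derivative 0.0
t.y   # a MockTracked holding [2.0, 3.0] and [0.0, 0.0]
```

Without the `getfield` branch, `t.value` would re-enter the same method and recurse until a `StackOverflowError`. That is why `TrackedArray`'s own fields (`value`, `deriv`, `tape`) must be special-cased.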
I'll do a little more testing and push an update soon.
Hmm, I'm still getting the `Type TrackedArray has no field ...` error. Am I supposed to change how I use it?
Wait, never mind, that was an unrelated problem. Here's the real error:
```
MethodError: getindex(::TrackedArray{Float64,Float64,0,Base.ReshapedArray{Float64,0,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{}},Base.ReshapedArray{Float64,0,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{}}}) is ambiguous. Candidates:
  getindex(t::TrackedArray, i::AbstractRange...) in ReverseDiff at /home/brian/.julia/packages/ReverseDiff/iHmB4/src/tracked.jl:299
  getindex(t::TrackedArray, i::Colon...) in ReverseDiff at /home/brian/.julia/packages/ReverseDiff/iHmB4/src/tracked.jl:299
  getindex(t::TrackedArray, i::Int64...) in ReverseDiff at /home/brian/.julia/packages/ReverseDiff/iHmB4/src/tracked.jl:338
  getindex(t::TrackedArray, i::Union{Colon, AbstractRange}...) in ReverseDiff at /home/brian/.julia/packages/ReverseDiff/iHmB4/src/tracked.jl:299
Possible fix, define
  getindex(::TrackedArray)
isassigned(::TrackedArray{Float64,Float64,0,Base.ReshapedArray{Float64,0,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{}},Base.ReshapedArray{Float64,0,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{}}}) at abstractarray.jl:408
...
```
And it seems to occur because somehow my sub-array in the `ComponentVector` becomes a `TrackedReal`?
```
TrackedReal{Float64,Float64,TrackedArray{Float64,Float64,1,ComponentVector{Float64,SubArray...},ComponentVector{Float64,SubArray...}}}[TrackedReal<4L7>(0.0, 0.0, 4rg, 1, LLy)]
```
I don't know, it could be an issue in my code as well. Maybe I should just simplify down to a more basic example (e.g. one that doesn't involve sensitivities).
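The `MethodError` above is a vararg ambiguity. The reported type is a `TrackedArray` with `N = 0`, apparently a scalar component returned as a 0-dimensional `ReshapedArray` view. Judging from the `isassigned` frame, that code path calls `getindex` with no indices, which is how a 0-dimensional array is indexed (`A[]`). With zero index arguments, the `i::Int64...`, `i::Colon...`, and `i::AbstractRange...` methods all match and none is more specific. The sketch below reproduces this with a hypothetical `Wrap0` type and `lookup` function. Both are stand-ins, not ReverseDiff code. It then applies the kind of fix Julia suggests: a method with no index arguments.

```julia
# Hypothetical stand-in for a 0-dimensional TrackedArray (not a ReverseDiff type).
struct Wrap0{T}
    x::T
end

# Vararg methods mirroring the candidates listed in the error message.
lookup(w::Wrap0, i::Int...) = "int indices"
lookup(w::Wrap0, i::Colon...) = "colon indices"
lookup(w::Wrap0, i::AbstractRange...) = "range indices"

# A call with no indices matches all three, so Julia throws an ambiguity MethodError.
was_ambiguous = try
    lookup(Wrap0(1.0))
    false
catch err
    err isa MethodError
end

# Fix: a zero-index method, which is more specific than every vararg candidate.
lookup(w::Wrap0) = w.x

lookup(Wrap0(1.0))      # 1.0
lookup(Wrap0(1.0), 1)   # "int indices"
```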
This problem apparently also exists for `DEDataArray`, because `getproperty` fails when it gets turned into a `TrackedArray`. @ChrisRackauckas
@jonniedie I reduced it to a simpler example (just one evaluation of the `ODEFunction`) and it works! Same result as `ForwardDiff`. From what I can tell right now (I'll need to dig in more later), the remaining issues aren't necessarily `ComponentArray`'s responsibility, so I think you've done your part :)
Thanks for working on this! Sorry I couldn't be of more help...