Comments (16)

jonniedie avatar jonniedie commented on May 20, 2024 1

Zygote also works. But yeah, ReverseDiff is one I should also get working because it’s still used in a few places. I’m not sure if ReverseDiff uses ChainRules, but I’ve also been meaning to add that so I can cover more bases. I’ll try to get to this tonight.

from componentarrays.jl.

jonniedie avatar jonniedie commented on May 20, 2024

This is looking a little more difficult than I was expecting it to be. Supposedly there are plans to migrate ReverseDiff to use ChainRules, so I think it would probably be best to wait for that. In the meantime, doing a custom gradient with Zygote like gradsZ = Zygote.gradient(F, ca)[1].x is probably the way to go for things like Optim.
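A minimal sketch of that Zygote workaround, for readers who want something concrete. The objective F and the components x and y are hypothetical stand-ins (the thread does not show the real ones), and the sketch assumes a ComponentArrays version where Zygote returns the gradient as a ComponentArray, which is what the .x access relies on:

```julia
using ComponentArrays, Zygote

# Hypothetical objective; the real F is not shown in the thread.
F(p) = sum(abs2, p.x) + sum(p.y)

ca = ComponentArray(x = [1.0, 2.0], y = [3.0])

# Zygote.gradient returns a tuple with one entry per argument of F.
grads = Zygote.gradient(F, ca)[1]
gradsZ = grads.x   # gradient with respect to the x component only

# In-place form g!(G, p), the gradient signature Optim.jl accepts.
g!(G, p) = (G .= Zygote.gradient(F, p)[1]; G)
```

Passing g! alongside F to an optimizer avoids ReverseDiff entirely, at the cost of writing the gradient wrapper by hand.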

bgroenks96 avatar bgroenks96 commented on May 20, 2024

What's the status of this? Does it work now? Last time I checked, ReverseDiff seemed to support ChainRules, though admittedly I'm still a bit hazy on how that all works.

jonniedie avatar jonniedie commented on May 20, 2024

Well the good news is I updated my autodiff stuff to use ChainRules instead of Zygote directly. The bad news is ReverseDiff still doesn't use ChainRules. I'm looking into what it's going to take to make this work, but the ReverseDiff docs don't give me a ton to work with for figuring out how to make pullback rules. The only thing I can think of is overloading each derivative function (gradient, gradient!, jacobian, jacobian!, ...) for getproperty and getindex of a ReverseDiff.TrackedArray{V,D,N,<:ComponentArray,DA} where {V,D,N,DA}. If that's the case, I may need some help in the form of a PR or it will be a while until I can get to it.

bgroenks96 avatar bgroenks96 commented on May 20, 2024

See #67

bgroenks96 avatar bgroenks96 commented on May 20, 2024

Yeah so I kind of need this, possibly soon. What can I do to help?

bgroenks96 avatar bgroenks96 commented on May 20, 2024

Although I actually don't necessarily need my parameter array to be a ComponentArray. I guess I should check if non-tracked arrays work with ComponentArrays.

jonniedie avatar jonniedie commented on May 20, 2024

It's looking like we might be able to define getindex for a tracked ComponentArray and a Symbol or Val taking some inspiration from here

using ComponentArrays, ReverseDiff
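using ReverseDiff: TrackedArray, value, deriv, tape, SpecialInstruction, record!

# Alias for a TrackedArray whose value is a ComponentArray (N is listed once; the original repeated it).
const TrackedComponentArray{V, D, N, DA, A, Ax} = TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA}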

function Base.getindex(t::TrackedComponentArray, ind::Union{Symbol, Val})
    tp = tape(t)
    out = TrackedArray(value(t)[ind], deriv(t)[ind], tp)
    record!(tp, SpecialInstruction, (getindex, Val(:generic)), (t, ind), out)
    return out
end

This is enough to get the tracking to work through the getindex operation (getproperty is probably not a good idea because TrackedArray needs to be able to access its fields with that), but actually taking derivatives, gradients, jacobians, etc. requires reaching into the special_reverse_exec! code here. Or maybe just making a separate dispatch for something other than Val(:generic) and patch in from there. I haven't had luck getting that working yet, though.

bgroenks96 avatar bgroenks96 commented on May 20, 2024

So without getproperty we would only be able to do carray[:var] right?
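For reference, a small sketch of the access forms in question, shown on a plain (untracked) ComponentArray. The component names var and other are made up for illustration. With only the getindex overload above, a tracked array would presumably support the first two forms but not the dot syntax:

```julia
using ComponentArrays

# Hypothetical array; names are for illustration only.
carray = ComponentArray(var = [1.0, 2.0], other = 3.0)

by_symbol   = carray[:var]        # getindex with a Symbol
by_val      = carray[Val(:var)]   # getindex with a Val
by_property = carray.var          # getproperty (dot syntax)

# All three forms select the same component.
@assert by_symbol == by_val == by_property
```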

bgroenks96 avatar bgroenks96 commented on May 20, 2024

I guess I don't really know what special_reverse_exec is doing. What is SpecialInstruction?

jonniedie avatar jonniedie commented on May 20, 2024

@bgroenks96 So I guess I didn't need the special instruction thing after all. I tried this out and it seems to work for me, but I've just tested it on pretty simple stuff so far. Try adding this and running your code:

using ComponentArrays
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal, value, deriv, tape

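# Alias for a TrackedArray whose value is a ComponentArray (N is listed once; the original repeated it).
const TrackedComponentArray{V, D, N, DA, A, Ax} = TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA}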

maybe_tracked_array(val::AbstractArray, der, t) = TrackedArray(val, der, t)
maybe_tracked_array(val, der, t) = TrackedReal(val, der, t)

function Base.getindex(t::TrackedComponentArray, inds::Union{Symbol, Val}...)
    return maybe_tracked_array(value(t)[inds...], deriv(t)[inds...], tape(t))
end

function Base.getproperty(t::TrackedComponentArray, s::Symbol)
    if s in (:value, :deriv, :tape)
        return getfield(t, s)
    else
        return maybe_tracked_array(getproperty(value(t), s), getproperty(deriv(t), s), tape(t))
    end
end

jonniedie avatar jonniedie commented on May 20, 2024

I'll do a little more testing and push an update soon.

bgroenks96 avatar bgroenks96 commented on May 20, 2024

Hmm, I'm still getting the Type TrackedArray has no field ... error. Am I supposed to change how I use it?

bgroenks96 avatar bgroenks96 commented on May 20, 2024

Wait, never mind, that was an unrelated problem. Here's the real error:

MethodError: getindex(::TrackedArray{Float64,Float64,0,Base.ReshapedArray{Float64,0,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{}},Base.ReshapedArray{Float64,0,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{}}}) is ambiguous. Candidates:
  getindex(t::TrackedArray, i::AbstractRange...) in ReverseDiff at /home/brian/.julia/packages/ReverseDiff/iHmB4/src/tracked.jl:299
  getindex(t::TrackedArray, i::Colon...) in ReverseDiff at /home/brian/.julia/packages/ReverseDiff/iHmB4/src/tracked.jl:299
  getindex(t::TrackedArray, i::Int64...) in ReverseDiff at /home/brian/.julia/packages/ReverseDiff/iHmB4/src/tracked.jl:338
  getindex(t::TrackedArray, i::Union{Colon, AbstractRange}...) in ReverseDiff at /home/brian/.julia/packages/ReverseDiff/iHmB4/src/tracked.jl:299
Possible fix, define
  getindex(::TrackedArray)
isassigned(::TrackedArray{Float64,Float64,0,Base.ReshapedArray{Float64,0,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{}},Base.ReshapedArray{Float64,0,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{}}}) at abstractarray.jl:408
...

And it seems to occur because somehow my sub-array in ComponentVector becomes a TrackedReal?

TrackedReal{Float64,Float64,TrackedArray{Float64,Float64,1,ComponentVector{Float64,SubArray...},ComponentVector{Float64,SubArray...}}}[TrackedReal<4L7>(0.0, 0.0, 4rg, 1, LLy)]

I don't know, could be an issue in my code as well. Maybe I should just simplify down to a more basic example (like not involving sensitivities).

bgroenks96 avatar bgroenks96 commented on May 20, 2024

This problem apparently also exists for DEDataArray, because getproperty fails once it gets turned into a TrackedArray. @ChrisRackauckas

bgroenks96 avatar bgroenks96 commented on May 20, 2024

@jonniedie I reduced to a simpler example (just one evaluation of the ODEFunction) and it works! Same result as ForwardDiff. From what I can tell right now (I'll need to dig in more later), the further issues aren't necessarily ComponentArray's responsibility, so I think you've done your part :)
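A rough sketch of that kind of cross-check, for anyone wanting to reproduce it. The loss function and component names a and b are hypothetical stand-ins for a single ODEFunction evaluation, and the sketch assumes either a ComponentArrays version that ships ReverseDiff support or the getindex/getproperty overloads posted above being loaded first:

```julia
using ComponentArrays, ForwardDiff, ReverseDiff

# Hypothetical scalar function of the parameters; not the real ODEFunction.
loss(p) = sum(abs2, p.a) + sum(p.b)

p = ComponentArray(a = [1.0, 2.0], b = [3.0, 4.0])

g_forward = ForwardDiff.gradient(loss, p)
g_reverse = ReverseDiff.gradient(loss, p)

# Both backends should agree on the gradient.
@assert g_forward ≈ g_reverse
```

If the two gradients disagree or the ReverseDiff call throws, the problem is more likely in the tracked-array overloads than in the model code.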

Thanks for working on this! Sorry I couldn't be of more help...
