sciml / diffeqflux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods

Home Page: https://docs.sciml.ai/DiffEqFlux/stable

License: MIT License

Julia 100.00%
neural-ode neural-sde neural-pde neural-dde neural-differential-equations stiff-ode ordinary-differential-equations stochastic-differential-equations delay-differential-equations partial-differential-equations

diffeqflux.jl's Introduction

DiffEqFlux.jl

Join the chat at https://julialang.zulipchat.com #sciml-bridged Global Docs

codecov Build Status Build status

ColPrac: Contributor's Guide on Collaborative Practices for Community Packages SciML Code Style

DiffEq(For)Lux.jl (aka DiffEqFlux.jl) fuses the world of differential equations with machine learning by helping users put diffeq solvers into neural networks. This package uses DifferentialEquations.jl and Lux.jl as its building blocks to support research in Scientific Machine Learning, specifically neural differential equations, to add physical information into traditional machine learning.

Note

We maintain backwards compatibility with Flux.jl via FromFluxAdaptor().
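A minimal sketch of that conversion, assuming a recent Lux version that provides FromFluxAdaptor (loaded via its Flux extension); the layer sizes here are arbitrary:

using Flux, Lux, DiffEqFlux

# A plain Flux model (sizes chosen only for illustration):
flux_model = Flux.Chain(Flux.Dense(2, 16, tanh), Flux.Dense(16, 2))

# Convert it into a Lux layer so it can be used inside DiffEqFlux's layers:
lux_model = FromFluxAdaptor()(flux_model)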

Tutorials and Documentation

For information on using the package, see the stable documentation. Use the in-development documentation to see features that have not yet been released.

Problem Domain

DiffEqFlux.jl is for implicit layer machine learning. DiffEqFlux.jl provides architectures which match the interfaces of machine learning libraries such as Flux.jl and Lux.jl to make it easy to build continuous-time machine learning layers into larger machine learning applications.

The following layer functions exist:

  • Neural Ordinary Differential Equations (Neural ODEs)
  • Collocation-Based Neural ODEs (Neural ODEs without a solver, by far the fastest way!)
  • Multiple Shooting Neural Ordinary Differential Equations
  • Neural Stochastic Differential Equations (Neural SDEs)
  • Neural Differential-Algebraic Equations (Neural DAEs)
  • Neural Delay Differential Equations (Neural DDEs)
  • Augmented Neural ODEs
  • Hamiltonian Neural Networks (with specialized second order and symplectic integrators)
  • Continuous Normalizing Flows (CNF) and FFJORD

with high-order, adaptive, implicit, GPU-accelerated, Newton-Krylov, and other methods. For examples, please refer to the release blog post. Additional demonstrations, such as neural PDEs and neural jump SDEs, can be found in this blog post (among many others!).
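As a rough sketch of how one of these layers is used in practice (assuming the current Lux-based NeuralODE API; exact keyword arguments may differ from release to release):

using DiffEqFlux, Lux, OrdinaryDiffEq, Random

rng   = Random.default_rng()
model = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 2))    # right-hand side: du/dt = model(u)
node  = NeuralODE(model, (0.0f0, 1.0f0), Tsit5(); saveat = 0.1f0)

ps, st = Lux.setup(rng, node)       # initialize parameters and state of the wrapped network
u0 = Float32[2.0, 0.0]
sol, st_ = node(u0, ps, st)         # forward pass: solve the ODE starting from u0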

Do not limit yourself to the layer functions listed above. With this package, you can explore many other ways of combining the two methodologies:

  • Neural networks can be defined where the “activations” are nonlinear functions described by differential equations
  • Neural networks can be defined where some layers are ODE solves
  • ODEs can be defined where some terms are neural networks
  • Cost functions on ODEs can define neural networks

Flux ODE Training Animation

Breaking Changes in v3

  • The Flux dependency has been dropped. If a non-Lux AbstractExplicitLayer is passed, we try to automatically convert it to a Lux model with FromFluxAdaptor()(model).
  • Flux is no longer re-exported from DiffEqFlux. Instead, we re-export Lux.
  • NeuralDAE now allows an optional du0 as input.
  • TensorLayer is now a Lux Neural Network.
  • APIs for quite a few layer constructions have changed. Please refer to the updated documentation for more details.

diffeqflux.jl's People

Contributors

abhigupta768, abhishek-1bhatt, adrhill, arnostrouwen, avik-pal, baggepinnen, chrisrackauckas, christopher-dg, collinwarner, d-netto, dependabot[bot], devmotion, dhairyalgandhi, emmanuel-r8, erikqqy, frankschae, github-actions[bot], jeremyfongsp, jessebett, kanav99, metanoid, mikeinnes, mkg33, piotrsokol, prbzrg, ranjanan, sathvikbhagavan, thazhemadam, vaibhavdixit02, yingboma


diffeqflux.jl's Issues

How to mix "normal" ODEs and Neural ODEs?

Thanks for this beautiful package!

From the Readme and the blog I understood (i) how to use a "normal" ODE as a Flux layer (diffeq_rd), and (ii) how to use a Flux layer to define an ODE (neural_ode).

I was wondering whether an API already exists to mix both approaches, for example to define something like this:

ann = Chain(Dense(2,10,tanh), Dense(10,1))   # neural part, defined outside so its weights persist

function dudt(du, u, p, t)
    x, y = u
    du[1] = ann(u)[1]              # neural term
    du[2] = -2.0*y + 1.1*x*y       # mechanistic term
end
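One pattern that seems close to this (a sketch, assuming a Flux version that provides Flux.destructure; names are illustrative) is to flatten the network's weights into a parameter vector and rebuild the network inside the ODE function, so the weights can be trained alongside the other parameters:

using Flux, DifferentialEquations

ann = Chain(Dense(2,10,tanh), Dense(10,1))
p_ann, re = Flux.destructure(ann)      # flat parameter vector + reconstructor

function dudt!(du, u, p, t)
    x, y = u
    du[1] = re(p)(u)[1]                # neural term rebuilt from the current parameters
    du[2] = -2.0*y + 1.1*x*y           # mechanistic term
end

prob = ODEProblem(dudt!, [1.0, 1.0], (0.0, 10.0), p_ann)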

pkg issue

Hi Chris,
great work! Excellent package and tutorial.
I am a newbie in Julia, but I appreciated your new take on neural ODEs so much that I had to try your code.
While attempting to clone, precompile, and test the pkg, which I did with “add https://github.com/JuliaDiffEq/DiffEqFlux.jl/” & “test DiffEqFlux”, I got the following error "ERROR: LoadError: LoadError: UndefVarError: adjoint_sensitivities_u0 not defined." Not sure if I got anything wrong or if this is a real issue.
Cheers!

Adjoint + Batching + GPU

using OrdinaryDiffEq, StochasticDiffEq, Flux, DiffEqFlux
using Test
using CuArrays

xs = Float32.(hcat([0.; 0.], [1.; 0.], [2.; 0.])) |> gpu
tspan = Float32.((0.0f0,25.0f0))
dudt = Chain(Dense(2,50,tanh),Dense(50,2)) |> gpu

CuArrays.allowscalar(false)
neural_ode(dudt,xs,tspan,Vern9(lazy=false),save_everystep=false,save_start=false)
neural_ode(dudt,xs,tspan,Vern9(lazy=false),saveat=0.1)
neural_ode_rd(dudt,xs,tspan,BS3(),saveat=0.1)

Tracker.zero_grad!(dudt[1].W.grad)
Flux.back!(sum(neural_ode(dudt,xs,tspan,BS3(),save_everystep=false,save_start=false)))
@test ! iszero(Tracker.grad(dudt[1].W))

Continuous normalizing flows

using DifferentialEquations
using Distributions
using Flux, DiffEqFlux, ForwardDiff
using Flux.Tracker


function f(z, p)
  α, β = p
  tanh.(α.*z .+ β)
end

u0 = [0.0, 0.0]
tspan = (0.0, 10.0)
function cnf(du,u,p,t)
  z, logpz = u
  α, β = p
  du[1] = f(z, p)
  du[2] = -sum(ForwardDiff.jacobian((z)->f(z, p), [z]))
end
prob = ODEProblem(cnf,u0,tspan,nothing)

p = param([0.0, 0.0]) # Initial Parameter Vector
params = Params([p])

function predict_adjoint(x)
    diffeq_adjoint(p,prob,Tsit5(),u0=[x,0.0],
                   saveat=0.0:0.1:10.0,
                   sensealg=DiffEqFlux.SensitivityAlg(quad=false,
                                backsolve=true,autojacvec=false))
end

function loss_adjoint(xs)
    pz = Normal(0.0, 1.0)
    preds = [predict_adjoint(x)[:,end] for x in xs]
    z = [pred[1] for pred in preds] # TODO better slicing
    delta_logp = [pred[2] for pred in preds]

    logpz = logpdf.(pz, z)
    logpx = logpz - delta_logp
    loss = -mean(logpx)
end


opt = ADAM(0.1)

raw_data = [[rand(Normal(2.0, 0.1)) for i in 1:100]]
data = Iterators.repeated(raw_data, 100);

Flux.train!(loss_adjoint, params, data, opt)



# check whether it looks standard normal
using Plots

preds = [predict_adjoint(r)[:,end] for r in raw_data[1]];

histogram([p[1].data for p in preds])

Stack Overflow error

I am trying to optimize the parameters of an ODE system. It looks something like this:

function renshaw_traub(du, u , pm, t)
   #Membrane voltage functions
   du[1] = ((1.0 / pm[7]) * (  u[11]
                            - ((pm[10] * u[1])
                            + (pm[1] * u[4]^3 * u[5] * (u[1] - 115.0))
                            + (pm[3] * u[6]^4 * (u[1] + 5.0))
                            + (pm[5] * u[7]^2 * (u[1] + 5.0))
                            + (7.0 * pm[13] * (u[1] - u[2]))
                            + (1.2 * pm[14] * (u[1] - u[3])))));

   du[2] = ((1.0 / pm[8]) *  - ((pm[11] * u[2])
                                + (1.167 * pm[13] * (u[2] - u[1]))));

   du[3] = ((1.0 / pm[9]) *  - ((pm[12] * u[3])
                                + (pm[2] * u[8]^3 * u[9] * (u[3] - 115.0))
                                + (pm[4] * u[10]^4 * (u[3] + 5.0))
                                + (5.73 * (u[3] - u[1]))));

   #Membrane channel state variable functions for soma
   du[4] = (((pm[15] * (pm[16] - u[1]) / (exp((pm[17] - u[1]) / pm[18]) - 1.0)) * (1.0 - u[4]))
           - ((pm[19] * (u[1] - pm[20]) / (exp((u[1] - pm[21]) / pm[22])- 1.0)) * u[4]));

   du[5] = (((pm[23] * exp((pm[24] - u[1]) / pm[25])) * (1.0 - u[5]))
           - ((pm[26] / (exp((pm[27] - u[1]) / pm[28]) + 1.0)) * u[5]));

   du[6] = (((pm[29] * (pm[30] - u[1]) / (exp((pm[31] - u[1]) / pm[32]) - 1.0)) * (1.0 - u[6]))
           - ((pm[33] * exp((pm[34] - u[1]) / pm[35])) * u[6]));

   du[7] = (((pm[36] / (exp((pm[37] - u[1]) / pm[38]) + 1.0)) * (1.0 - u[7]))
           - (pm[6] * u[7] * 1e-5));

   #Membrane channel state variable functions for IS
   du[8] = (((pm[39] * (pm[40] - u[3]) / (exp((pm[41] - u[3]) / pm[42]) - 1.0)) * (1.0 - u[8]))
           - ((pm[43] * (u[3] - pm[44]) / (exp((u[3] - pm[45]) / pm[46])- 1.0)) * u[8]));

   du[9] = (((pm[47] * exp((pm[48] - u[3]) / pm[49])) * (1.0 - u[9]))
           - ((pm[50] / (exp((pm[51] - u[3]) / pm[52]) + 1.0)) * u[9]));

   du[10] = (((pm[53] * (pm[54] - u[3]) / (exp((pm[55] - u[3]) / pm[56]) - 1.0)) * (1.0 - u[10]))
            - ((pm[57] * exp((pm[58] - u[3]) / pm[59])) * u[10]));

   #Input current
   du[11] = -(((83.33 * 200.0 * (0.025 * (t - 40.0))^200) / ((t - 40.0) * ((0.025 * (t - 40.0))^200 + 1.0) ^ 2))
            + ((128.2 * 200.0 * (0.025 * (t - 220.0))^200) / ((t - 220.0) * ((0.025 * (t - 220.0))^200 + 1.0) ^ 2))
            + ((250.0 * 200.0 * (0.025 * (t - 400.0))^200) / ((t - 400.0) * ((0.025 * (t - 400.0))^200 + 1.0) ^ 2))
            + ((314.1 * 200.0 * (0.025 * (t - 580.0))^200) / ((t - 580.0) * ((0.025 * (t - 580.0))^200 + 1.0) ^ 2))
            + ((448.72 * 200.0 * (0.025 * (t - 760.0))^200) / ((t - 760.0) * ((0.025 * (t - 760.0))^200 + 1.0) ^ 2))
            + ((769.23 * 200.0 * (0.025 * (t - 940.0))^200) / ((t - 940.0) * ((0.025 * (t - 940.0))^200 + 1.0) ^ 2)));
end

The optimization code looks something like this:

px = renshaw_params();
ppx = param(px);
params = Flux.Params([ppx]);
u0 = renshaw_intializer(0.0, 0.0, 0.0, px[6]);
# u0 = renshaw_intializer_RP(-70.0);
prob = ODEProblem(renshaw_traub, [u0;1.3e-3], (0.0, 979.99), px);
alpha = 1e-5;

function objective()
    diffeq_rd(ppx, prob, OwrenZen5(), reltol = 1e-3, abstol = 1e-6,
                     maxiters = Int(1e8), saveat = 0.01)[1,:]
end

loss_func() = norm(Data - objective());

reps = Iterators.repeated((), 1000)
opt = NADAM(alpha);
cb = function ()
    display(loss_func());
    display(plot(solve(remake(prob, p = Flux.data(ppx)), OwrenZen5(),
            reltol = 1e-4, abstol = 1e-6, maxiters = Int(1e8), saveat = 0.01)[1,:],
            ylim=(0.0,100.0)));
end
cb()
Flux.train!(loss_func, params, reps, opt, cb = cb);

However, the code runs fine and as expected for short tspans, but generates StackOverflow errors if I run it over the full tspan. I need the full tspan for proper optimization. Is there any workaround other than increasing the system stack size or moving to a more powerful system? I saw in a similar post that using tracked arrays can reduce the computational cost (https://github.com/FluxML/Flux.jl/issues/626). Does this apply here?

Stacktrace:
ERROR: LoadError: StackOverflowError: Stacktrace: [1] scan(::Tracker.Tracked{Float64}) at /home/nirvik/.julia/packages/Tracker/JhqMQ/src/back.jl:23 [2] foreach at ./abstractarray.jl:1920 [inlined] [3] scan(::Tracker.Call{getfield(Tracker, Symbol("##259#262")),Tuple{Tracker.Tracked{Float64},Tracker.Tracked{Float64}}}) at /home/nirvik/.julia/packages/Tracker/JhqMQ/src/back.jl:17 ... (the last 3 lines are repeated 21 more times) [67] scan(::Tracker.Tracked{Float64}) at /home/nirvik/.julia/packages/Tracker/JhqMQ/src/back.jl:23 [68] foreach at ./abstractarray.jl:1920 [inlined] [69] scan(::Tracker.Call{getfield(Tracker, Symbol("##259#262")),Tuple{Tracker.Tracked{Float64},Tracker.Tracked{Float64}}}) at /home/nirvik/.julia/packages/Tracker/JhqMQ/src/back.jl:17 (repeats 2 times) [70] scan(::Tracker.Tracked{Float64}) at /home/nirvik/.julia/packages/Tracker/JhqMQ/src/back.jl:23 [71] foreach at ./abstractarray.jl:1920 [inlined] ... (the last 3 lines are repeated 4 more times) [84] scan(::Tracker.Call{getfield(Tracker, Symbol("##149#150")){Tracker.TrackedReal{Float64}},Tuple{Tracker.Tracked{Float64}}}) at /home/nirvik/.julia/packages/Tracker/JhqMQ/src/back. jl:17 [85] scan(::Tracker.Tracked{Float64}) at /home/nirvik/.julia/packages/Tracker/JhqMQ/src/back.jl:23 [86] scan at /home/nirvik/.julia/packages/Tracker/JhqMQ/src/back.jl:30 [inlined] [87] #back!#15 at /home/nirvik/.julia/packages/Tracker/JhqMQ/src/back.jl:76 [inlined] [88] #back! at ./none:0 [inlined] [89] #back!#32 at /home/nirvik/.julia/packages/Tracker/JhqMQ/src/lib/real.jl:16 [inlined] [90] back!(::Tracker.TrackedReal{Float64}) at /home/nirvik/.julia/packages/Tracker/JhqMQ/src/lib/real.jl:14 [91] gradient_(::getfield(Flux.Optimise, Symbol("##14#20")){typeof(loss_func),Tuple{}}, ::Tracker.Params) at /home/nirvik/.julia/packages/Tracker/JhqMQ/src/back.jl:4 [92] #gradient#24(::Bool, ::typeof(Tracker.gradient), ::Function, ::Tracker.Params) at /home/nirvik/.julia/packages/Tracker/JhqMQ/src/back.jl:164 [93] gradient at /home/nirvik/.julia/packages/Tracker/JhqMQ/src/back.jl:164 [inlined] [94] macro expansion at /home/nirvik/.julia/packages/Flux/dkJUV/src/optimise/train.jl:71 [inlined] [95] macro expansion at /home/nirvik/.julia/packages/Juno/oLB1d/src/progress.jl:119 [inlined] [96] #train!#12(::getfield(Main, Symbol("##11#12")), ::typeof(Flux.Optimise.train!), ::Function, ::Tracker.Params, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, ::NADAM) at /home/nirvik/.julia/packages/Flux/dkJUV/src/optimise/train.jl:69 [97] (::getfield(Flux.Optimise, Symbol("#kw##train!")))(::NamedTuple{(:cb,),Tuple{getfield(Main, Symbol("##11#12"))}}, ::typeof(Flux.Optimise.train!), ::Function, ::Tracker.Params, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, ::NADAM) at ./none:0 in expression starting at /home/nirvik/Documents/neuronal_model_julia/rotation/renshaw_optimization/my_loss.jl:

Allocations using diffeq_rd()

I have noticed that a considerable number of allocations appear when using the function diffeq_rd to evaluate the ODE with the tracked parameters. This example shows it very clearly:

using Flux, DiffEqFlux, DifferentialEquations, BenchmarkTools

const μ     = 2.0451149878572323e6
const wt    = 0.0003766355744484957

function coordSpher(X::Vector)
    x, y, z, vx, vy, vz = X
    r   = sqrt(x^2+y^2+z^2)
    ϕ   = asin(z/r)
    λ   = atan(y,x)
    return r,ϕ,λ
end
function ast(f,x,p,t,μ=μ,wt=wt) # x: state in Cartesian coordinates
    r, ϕ, λ = coordSpher(x)
    J2,J22  = p
    Rm      = 1.0
    Ur      = -μ/r^4*( r^2 + J2*3/2*Rm^2*(1-3*sin(ϕ)^2)-9*J22*Rm^2*cos(ϕ)^2*cos(2λ) )
    Up      = -μ/r^3*6*Rm^2*cos(ϕ)*sin(ϕ)*(J2/2-J22*cos(2λ))
    Ul      = μ/r^3*6*J22*Rm^2*cos(ϕ)^2*sin(2*λ)
    Gx      = Ur*cos(ϕ)*cos(λ) + Ul/r*(-sin(ϕ)*cos(λ)) +Up/r/cos(ϕ)*-sin(λ)
    Gy      = Ur*cos(ϕ)*sin(λ) + Ul/r*(-sin(ϕ)*sin(λ)) +Up/r/cos(ϕ)*-cos(λ)
    Gz      = Ur*sin(ϕ) + Ul/r*cos(ϕ)
    f[1] = x[4]
    f[2] = x[5]
    f[3] = x[6]
    f[4] = Gx + 2*x[5]*wt    + x[1]*wt^2
    f[5] = Gy - 2*x[4]*wt   + x[2]*wt^2
    f[6] = Gz
end

u0 = [-18105.392, -88924.165, 51340.390, -28.859, 6.012, 0.465]
tspan = (0.0,2e4)
p = [2.218190645002624e8, -1.4619343719968584e8] # true parameters

prob = ODEProblem(ast,u0,tspan,p)

data_sol = solve(prob,Tsit5(),saveat=1000.0,abstol=1e-7,reltol=1e-6)

pfd = param(p)  # using the same parameters just to check the allocations under the same conditions
function predict_rd() # Our 1-layer neural network
  diffeq_rd(pfd,prob,Tsit5(),saveat=1000.0,abstol=1e-7,reltol=1e-6)
end
loss_rd() = sum(abs2,predict_rd()-data_sol) # loss function

@benchmark solve(prob,Tsit5(),saveat=1000.0,abstol=1e-7,reltol=1e-6)
@benchmark predict_rd()

Using the solve of DifferentialEquations yields:

BenchmarkTools.Trial: 
  memory estimate:  10.50 KiB
  allocs estimate:  119
  --------------
  minimum time:     79.711 μs (0.00% GC)
  median time:      85.053 μs (0.00% GC)
  mean time:        91.088 μs (1.39% GC)
  maximum time:     3.673 ms (96.39% GC)
  --------------
  samples:          10000
  evals/sample:     1

While diffeq_rd, on the other hand, uses almost 10 MiB of memory and is much slower:

BenchmarkTools.Trial: 
  memory estimate:  9.18 MiB
  allocs estimate:  291317
  --------------
  minimum time:     1.914 ms (0.00% GC)
  median time:      2.385 ms (0.00% GC)
  mean time:        5.988 ms (61.87% GC)
  maximum time:     22.093 ms (84.58% GC)
  --------------
  samples:          834
  evals/sample:     1

I thought it was a problem of the function predict_rd() using prob from the global scope, but even passing it to the function as an argument gives the same result, so I don't know what's happening. From what I understand of the source code of diffeq_rd, it basically remakes the ODE problem with the tracked parameters and applies solve, so maybe the inefficiency comes from those tracked parameters?
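For reference, my rough mental model of what diffeq_rd does (a sketch only, not the actual source):

using DifferentialEquations

# Sketch: re-create the problem with the tracked parameters and call solve,
# so every scalar operation inside the solver gets recorded by Tracker.
function diffeq_rd_sketch(p, prob, alg; kwargs...)
    _prob = remake(prob, p = p)
    solve(_prob, alg; kwargs...)
end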

Is there a solution for this or is it a limitation of how Flux works?

BoundsError when taking gradient

using Flux, DiffEqFlux, OrdinaryDiffEq,  StatsBase, RecursiveArrayTools
using Flux: onehotbatch
using MLDatasets:MNIST
using Base.Iterators: repeated,partition

batch_size=10

train_x, train_y = MNIST.traindata();
test_x, test_y = MNIST.testdata();
train_y_hot = onehotbatch(train_y,0:9);

train_data = [(reshape(train_x[:,:,i],(28,28,1,batch_size)),train_y_hot[:,i]) for i in partition(1:60_000,batch_size)];

downsample = Chain(
                   Conv((3,3),1=>32,stride=1),
                   BatchNorm(32,relu),
                   Conv((4,4),32=>32,stride=2,pad=1),
                   BatchNorm(32,relu),
                   Conv((4,4),32=>32,stride=2,pad=1)
                  )
dudt = Chain(
          BatchNorm(32,relu),
          Conv((3,3),32=>32,stride =1, pad=1),
          BatchNorm(32),
          Conv((3,3),32=>32,stride =1, pad=1),
          BatchNorm(32)
          )
classify = Chain(
                 BatchNorm(32,relu), 
                 MeanPool((6,6)),
                 x->view(x,1,1,:,:), #Need More General Flatten Here
                 Dense(32,10)
                )

ps = Flux.params(dudt)

function n_ode(batch_u0, batch_t)
    neural_ode(dudt,batch_u0,(0.,25.),Tsit5(),
               save_start=false,
               saveat=batch_t,      # Ugly way to only get sol[end] 
               reltol=1e-3,abstol=1e-3)
end

model = Chain(downsample,u->n_ode(u,[25.])[:,:,:,:,end],classify) # Further ugliness getting sol[end]

loss(x,y) = Flux.mse(model(x),y)

loss(train_data[1]...) # Works and is tracked
# This fails with neural_ode_rd with T undefined error

opt = ADAM(0.1)
Ps = params(model,ps)

xt, yt = train_data[1]
Tracker.gradient(()->loss(xt,yt),Ps) # Fails with BoundsError
#ERROR: BoundsError: attempt to access 6×6×32×10×1 Array{Float32,5} at index [Base.Slice(Base.OneTo(6)), 1]

Flux.train!(loss, Ps, train_data, opt, cb = ()) # Fails with BoundsError

The loss produces a tracked scalar, but taking gradients causes a BoundsError.

Further, if I use neural_ode_rd instead, it doesn't even evaluate the forward pass correctly, as there's a "T undefined" error.
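As an aside, regarding the "Need More General Flatten Here" comment in the classify chain above, a shape-agnostic flatten could be written as (a sketch):

# Collapse all but the batch (last) dimension; a 1×1×32×N array becomes 32×N.
general_flatten(x) = reshape(x, :, size(x, ndims(x)))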

DiffEqFlux 0.6.0, 0.5.1 and 0.5.0 do not compile on Julia 1.0

I have found that DiffEqFlux versions 0.6.0, 0.5.1, and 0.5.0 do not compile under Julia 1.0.
The combinations Julia 1.2.0 + DiffEqFlux 0.6.0 and Julia 1.0 + DiffEqFlux 0.4.0 work just fine.

Tested on Ubuntu 14.04.6 LTS and Ubuntu 18.04.2 LTS. Both systems give the same results.

Tested on a clean environment with only DiffEqFlux as installed package.

Error dump:

Precompiling DiffEqFlux
[ Info: Precompiling DiffEqFlux [aae7a2af-3d4f-5e19-a356-7da93b79d9d0]
ERROR: LoadError: LoadError: syntax: invalid function name "Morris <: GSAMethod"
Stacktrace:
 [1] include at ./boot.jl:317 [inlined]
 [2] include_relative(::Module, ::String) at ./loading.jl:1038
 [3] include at ./sysimg.jl:29 [inlined]
 [4] include(::String) at /home/jos/.julia/packages/DiffEqSensitivity/xlGSs/src/DiffEqSensitivity.jl:3
 [5] top-level scope at none:0
 [6] include at ./boot.jl:317 [inlined]
 [7] include_relative(::Module, ::String) at ./loading.jl:1038
 [8] include(::Module, ::String) at ./sysimg.jl:29
 [9] top-level scope at none:2
 [10] eval at ./boot.jl:319 [inlined]
 [11] eval(::Expr) at ./client.jl:389
 [12] top-level scope at ./none:3
in expression starting at /home/jos/.julia/packages/DiffEqSensitivity/xlGSs/src/morris_sensitivity.jl:1
in expression starting at /home/jos/.julia/packages/DiffEqSensitivity/xlGSs/src/DiffEqSensitivity.jl:17
ERROR: LoadError: Failed to precompile DiffEqSensitivity [41bf760c-e81c-5289-8e54-58b1f1f8abe2] to /home/jos/.julia/compiled/v1.0/DiffEqSensitivity/02xYn.ji.
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] macro expansion at ./logging.jl:313 [inlined]
 [3] compilecache(::Base.PkgId, ::String) at ./loading.jl:1184
 [4] _require(::Base.PkgId) at ./logging.jl:311
 [5] require(::Base.PkgId) at ./loading.jl:852
 [6] macro expansion at ./logging.jl:311 [inlined]
 [7] require(::Module, ::Symbol) at ./loading.jl:834
 [8] include at ./boot.jl:317 [inlined]
 [9] include_relative(::Module, ::String) at ./loading.jl:1038
 [10] include(::Module, ::String) at ./sysimg.jl:29
 [11] top-level scope at none:2
 [12] eval at ./boot.jl:319 [inlined]
 [13] eval(::Expr) at ./client.jl:389
 [14] top-level scope at ./none:3
in expression starting at /home/jos/.julia/packages/DiffEqFlux/Cufen/src/DiffEqFlux.jl:3
ERROR: Failed to precompile DiffEqFlux [aae7a2af-3d4f-5e19-a356-7da93b79d9d0] to /home/jos/.julia/compiled/v1.0/DiffEqFlux/BdO4p.ji.
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] macro expansion at ./logging.jl:313 [inlined]
 [3] compilecache(::Base.PkgId, ::String) at ./loading.jl:1184
 [4] precompile(::Pkg.Types.Context) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/Pkg/src/API.jl:489
 [5] do_precompile!(::Dict{Symbol,Any}, ::Array{String,1}, ::Dict{Symbol,Any}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/Pkg/src/REPLMode.jl:586
 [6] #invokelatest#1(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::Any, ::Any, ::Vararg{Any,N} where N) at ./essentials.jl:686
 [7] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N) at ./essentials.jl:685
 [8] do_cmd!(::Pkg.REPLMode.PkgCommand, ::REPL.LineEditREPL) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/Pkg/src/REPLMode.jl:542
 [9] #do_cmd#30(::Bool, ::Function, ::REPL.LineEditREPL, ::String) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/Pkg/src/REPLMode.jl:507
 [10] do_cmd at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/Pkg/src/REPLMode.jl:503 [inlined]
 [11] (::getfield(Pkg.REPLMode, Symbol("##41#44")){REPL.LineEditREPL,REPL.LineEdit.Prompt})(::REPL.LineEdit.MIState, ::Base.GenericIOBuffer{Array{UInt8,1}}, ::Bool) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/Pkg/src/REPLMode.jl:842
 [12] #invokelatest#1 at ./essentials.jl:686 [inlined]
 [13] invokelatest at ./essentials.jl:685 [inlined]
 [14] run_interface(::REPL.Terminals.TextTerminal, ::REPL.LineEdit.ModalInterface, ::REPL.LineEdit.MIState) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/REPL/src/LineEdit.jl:2261
 [15] run_frontend(::REPL.LineEditREPL, ::REPL.REPLBackendRef) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/REPL/src/REPL.jl:1029
 [16] run_repl(::REPL.AbstractREPL, ::Any) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/REPL/src/REPL.jl:191
 [17] (::getfield(Base, Symbol("##720#722")){Bool,Bool,Bool,Bool})(::Module) at ./logging.jl:311
 [18] #invokelatest#1 at ./essentials.jl:686 [inlined]
 [19] invokelatest at ./essentials.jl:685 [inlined]
 [20] macro expansion at ./logging.jl:308 [inlined]
 [21] run_main_repl(::Bool, ::Bool, ::Bool, ::Bool, ::Bool) at ./client.jl:330
 [22] exec_options(::Base.JLOptions) at ./client.jl:242
 [23] _start() at ./client.jl:421

DiffEqFlux could handle infinite loss better.

I've been (mis)using DiffEqFlux for a project with a differential equation whose parameter space has solutions with finite escape time.

Loss is infinite
 in top-level scope at base/none
 in at base/none
 in #train!#12 at Flux/qXNjB/src/optimise/train.jl:69
 in macro expansion at Juno/TfNYn/src/progress.jl:124
 in macro expansion at Flux/qXNjB/src/optimise/train.jl:71
 in gradient at Tracker/RRYy6/src/back.jl:164
 in #gradient#24 at Tracker/RRYy6/src/back.jl:164
 in gradient_ at Tracker/RRYy6/src/back.jl:98
 in losscheck at Tracker/RRYy6/src/back.jl:154

It currently throws an error under these conditions. The issue came up in this thread of mine: https://discourse.julialang.org/t/tracking-initial-condition-to-optimize-starting-value-too-diffeqflux/26596

A solution in this case was proposed:
Since ADAM is inherently stochastic, it could have functionality to discard Infs and try different randomized points so as not to hit the bad area. Machine learning problems generally don't have this bifurcation behavior, so it doesn't show up as much; I think it was just overlooked.

Could you look into solving this in general?

Problem combining diffeq_rd with conv layers

Hello,

First of all, thank you for this cool project!

I've been having some issues combining the DE layer with convolution layers. With the output from a DE layer, the Conv layer seems to force a conversion that is not possible.

Example following the tutorial/blog (combining DE layers and Conv layers this way is of course nonsense, but it is a minimal example):

using DifferentialEquations
using Flux, DiffEqFlux

function lotka_volterra(du,u,p,t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end

u0 = Float32.([1.0,1.0])
tspan = (Float32(0.0),Float32(0.1))
p = Float32.([1.5,1.0,3.0,1.0])
prob = ODEProblem(lotka_volterra,u0,tspan,p)

struct ODEBlock{S,T,F,V}
    prob::S
    p::T
    dt::F
    u0::V
end

Flux.@treelike ODEBlock

(m::ODEBlock)(x) = Array(diffeq_rd(m.p, m.prob, Tsit5(), saveat=m.dt, u0=x))

# first ode block, then reshape to WHCN form, then 2 conv layers
model = Chain(ODEBlock(prob, param(p), Float32(0.1), u0), x -> reshape(x,(:,1,1,1)), Conv((2,1),1=>2), Conv((2,1),2=>1))

model(Float32.([0.3, 0.5]))

I made sure that everything is Float32 because the Conv weights are Float32, too.

This invokes the error:

ERROR: LoadError: MethodError: no method matching Float32(::Flux.Tracker.TrackedReal{Float32})
Closest candidates are:
  Float32(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:185
  Float32(::T<:Number) where T<:Number at boot.jl:725
  Float32(::Int8) at float.jl:60
  ...
Stacktrace:
 [1] _broadcast_getindex_evalf at ./broadcast.jl:574 [inlined]
 [2] _broadcast_getindex at ./broadcast.jl:547 [inlined]
 [3] getindex at ./broadcast.jl:507 [inlined]
 [4] copy at ./broadcast.jl:782 [inlined]
 [5] materialize at ./broadcast.jl:748 [inlined]
 [6] (::Conv{2,typeof(identity),TrackedArray{Float32,4,Array{Float32,4}},TrackedArray{Float32,1,Array{Float32,1}}})(::Array{Flux.Tracker.TrackedReal{Float32},2}) at /home/max/.julia/packages/Flux/8XpDt/src/layers/conv.jl:57
 [7] applychain(::Tuple{Conv{2,typeof(identity),TrackedArray{Float32,4,Array{Float32,4}},TrackedArray{Float32,1,Array{Float32,1}}},Conv{2,typeof(identity),TrackedArray{Float32,4,Array{Float32,4}},TrackedArray{Float32,1,Array{Float32,1}}}}, ::Array{Flux.Tracker.TrackedReal{Float32},2}) at /home/max/.julia/packages/Flux/8XpDt/src/layers/basic.jl:31 (repeats 2 times)
 [8] (::Chain{Tuple{ODEBlock{ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},TrackedArray{Float32,1,Array{Float32,1}},Float32,Array{Float32,1}},Conv{2,typeof(identity),TrackedArray{Float32,4,Array{Float32,4}},TrackedArray{Float32,1,Array{Float32,1}}},Conv{2,typeof(identity),TrackedArray{Float32,4,Array{Float32,4}},TrackedArray{Float32,1,Array{Float32,1}}}}})(::Array{Float32,1}) at /home/max/.julia/packages/Flux/8XpDt/src/layers/basic.jl:33
 [9] top-level scope at none:0

This could also very well be an issue with the way Conv() is written, but with Flux alone (without DiffEqFlux) I did not encounter this issue before.
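One workaround I would try (an unverified sketch): convert the Array of TrackedReals returned by the solver into a proper TrackedArray before it reaches the Conv layers, e.g. by redefining the block's forward pass as:

# Unverified sketch: Flux.Tracker.collect turns an Array of TrackedReals into a
# TrackedArray, which the Conv layers should be able to handle directly.
(m::ODEBlock)(x) = Flux.Tracker.collect(Array(diffeq_rd(m.p, m.prob, Tsit5(), saveat=m.dt, u0=x)))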

Preprocessing Layer in Neural ODE causes Error (maybe Tracker related).

So, I tried to integrate a custom pre-processing layer into a neural ODE and encountered an error that does not occur if the same model is trained as a regular ANN. Inserting very simple functions like x->x.^2 works, but I fail with more complicated ones. For simplicity, I wrote the layer here just as a function and not as a struct, since the error already occurs in this case as well.

I tried inserting a Tracker.collect(...) here and there, but it does not change anything.

using Flux, DiffEqFlux, DifferentialEquations

function test_fun(x)
    out = similar(x, 5)
    out[1:3] .= x[1:3] .^ 2
    out[4] = x[1] ^ 5
    out[5] = x[2]*1.5
    return out
end

model = Chain(x -> test_fun(x), Dense(5, 3, tanh))
data = [[rand(3), rand(3)]]
loss(x,y) = Flux.mse(model(x),y)
Flux.@epochs 2 Flux.train!(loss, Flux.params(model), data, ADAM())

The code above works.

n_ode = x->neural_ode(model,x,(0.,0.2),Tsit5(),saveat=[0,0.1,0.2])
predict_ode() = n_ode(rand(3))
loss_n_ode() = Flux.mse(predict_ode(), rand(3))
Flux.@epochs 2 Flux.train!(loss_n_ode, Flux.params(model), Iterators.repeated((),10), ADAM())

This does not work. It yields:

ERROR: LoadError: MethodError: no method matching Float64(::Tracker.TrackedReal{Float64})
Closest candidates are:
  Float64(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:194
  Float64(::T<:Number) where T<:Number at boot.jl:718
  Float64(::Int8) at float.jl:60
  ...
Stacktrace:
 [1] convert(::Type{Float64}, ::Tracker.TrackedReal{Float64}) at ./number.jl:7
 [2] setindex!(::Array{Float64,1}, ::Tracker.TrackedReal{Float64}, ::Int64) at ./array.jl:766
 [3] macro expansion at ./subarray.jl:295 [inlined]
 [4] macro expansion at ./simdloop.jl:77 [inlined]
 [5] copyto! at ./broadcast.jl:887 [inlined]
 [6] copyto! at ./broadcast.jl:842 [inlined]
 [7] materialize! at ./broadcast.jl:801 [inlined]
 [8] test_fun(::TrackedArray{Float64,1,Array{Float64,1}}) at test_stackoverflow.jl:5
...
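For what it's worth, a non-mutating version of the preprocessing function (a sketch) avoids writing tracked values into a plain Float64 buffer, which is what the setindex! in the stack trace points at:

# Build the output with vcat instead of mutating similar(x, 5), so Tracker never
# has to convert a TrackedReal into a Float64 slot.
test_fun_nomut(x) = vcat(x[1:3] .^ 2, x[1]^5, x[2] * 1.5)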

Zero'd gradients in example

From @jessebett

using Flux, DiffEqFlux, OrdinaryDiffEq, Distributions
const tspan = (0.0f0,1.f0)
const RANGE = (-3.,3.)
const BS = 200

target(u) = u.^3

function gen_data(batchsize,target_fun)
    x = Float32.(rand(Uniform(RANGE...),batchsize)')
    return x,target(x)
end

dudt = Chain(Dense(1,20,relu),Dense(20,1))
function n_ode(u0)
    neural_ode(dudt,u0,tspan,Tsit5(),dense=false,save_everystep=false,save_start=false,save_end=true,reltol=1e-5,abstol=1e-12)
end
ps = Flux.params(dudt)

loss(x,y) = mean(abs.(n_ode(x)-y))

tx,ty = gen_data(BS,target)
grads=Tracker.gradient(()->loss(tx,ty),ps)
for g in grads
    @show g
end # This returns 0.0s, no gradients getting through

Type issues when training weights are matrices

First of all, thanks for publishing this very interesting package!

I am trying to use this to do some backprop through an ODE where my parameters are matrices rather than a collection of scalars, as in the example. I am having trouble wrangling the type system to get this to work.

If I try the approach given by the code below

p = [a, b]
prob = ODEProblem(deriv, u0, (0.0, T), p)
# sol  = solve(prob, SSPRK83(), dt=dt, progress=true, abstol=1e-6, reltol=1e-6, progress_steps=100);
params = Flux.Params([param(a), param(b)]);

function predict_rd()
  Array(diffeq_rd(param([a, b]), prob, SSPRK83(), dt=dt, progress=true, abstol=1e-6, reltol=1e-6, progress_steps=100))
end

loss_rd() = 1/sum(abs2, x for x in predict_rd()[400, 70, 1, :])

cb =  display(loss_rd())

I get the error

ERROR: MethodError: no method matching zero(::Type{Array{Float64,2}})
Closest candidates are:
  zero(::Type{LibGit2.GitHash}) at /Users/osx/buildbot/slave/package_osx64/build/usr/share/julia/stdlib/v1.1/LibGit2/src/oid.jl:220
  zero(::Type{Pkg.Resolve.VersionWeights.VersionWeight}) at /Users/osx/buildbot/slave/package_osx64/build/usr/share/julia/stdlib/v1.1/Pkg/src/resolve/VersionWeights.jl:19
  zero(::Type{Pkg.Resolve.MaxSum.FieldValues.FieldValue}) at /Users/osx/buildbot/slave/package_osx64/build/usr/share/julia/stdlib/v1.1/Pkg/src/resolve/FieldValues.jl:44
  ...
Stacktrace:
 [1] zero(::Array{Array{Float64,2},1}) at ./abstractarray.jl:849
 [2] TrackedArray(::Array{Array{Float64,2},1}) at /Users/iwill/.julia/packages/Flux/GBuXv/src/tracker/lib/array.jl:32
 [3] param(::Array{Array{Float64,2},1}) at /Users/iwill/.julia/packages/Flux/GBuXv/src/tracker/Tracker.jl:103
 [4] predict_rd() at ./none:2
 [5] loss_rd() at ./none:1
 [6] top-level scope at none:0

I think this must have something to do with my wrapping of param().

argtail type issue with backprop

using Flux, DiffEqFlux, DifferentialEquations, Plots

function func(du, u, p, t)
  a, b, c, d = p
  du[1] =  exp(a-b/8.3145/t)*(u[1])^c*(1.0-u[1])^d
end

u0 = [0.05]
tspan = (300.0, 1200.0)
p = [20.0, 150000.0, 1.0, 1.0]
prob = ODEProblem(func, u0, tspan, p)

sol = solve(prob,Rodas4(), saveat = 5, maxiters=1e7)
scatter(sol[1,:])

dudt = Chain(Dense(2,50,relu),
             Dense(50,2,relu),
             x -> [exp(20.0-150000.0/8.3146./x[2]).*x[1]; 1.0])
ps = Flux.params(dudt)

u0 = [0.05;300.0]
n_ode = x -> neural_ode(dudt, x, tspan, Tsit5(), saveat = 300.0:5:1200)

sol_data = Array(sol)
function predict_n_ode()
  n_ode(u0)
end

loss_n_ode() = sum(abs2, predict_n_ode()[1,:] - sol_data[1,:])

data = Iterators.repeated((), 50)
opt = ADAM(0.01)

cb = function()
  display(loss_n_ode())
end

cb()
Flux.train!(loss_n_ode, ps, data, opt, cb=cb)

Problems with custom layers

Hey there!

Thanks for this awesome package! I really love the DiffEq environment (which made me switch to Julia as my language of choice).

I tried to implement a custom layer as a neural ODE. Here is an MWE:

using LinearAlgebra
using DifferentialEquations
using Flux, DiffEqFlux

struct SineLayer{S,T}
    C::S
    W::T
end

function sineLayer(in::Integer, out::Integer)
    C = Array(Diagonal(ones(in)))
    W = param(randn(out, in))
    return SineLayer(C, W)
end

function (a::SineLayer)(x)
    return a.W*sin.(a.C*x)
end

Flux.@treelike SineLayer


model = sineLayer(2,2)
n_ode = x->neural_ode(model, x, (0.0f0, 1.0f0), Tsit5(), saveat = [1.0f0])
init = param([0.0f0, 1.0f0])
lossf() = norm(n_ode(init) .- 2.0)
lossf()
opt = ADAM(0.1)
Flux.train!(lossf, params(model),Iterators.repeated((), 1), opt)

Which throws the following error and stacktrace:

MethodError: *(::Transpose{Float64,Array{Float64,2}}, ::TrackedArray{…,Array{Float64,1}}) is ambiguous. Candidates:
  *(x::AbstractArray{T,2} where T, y::TrackedArray{T,1,A} where A where T) in Tracker at /home/dfki.uni-bremen.de/jmartensen/.julia/packages/Tracker/RRYy6/src/lib/array.jl:383
  *(transA::Transpose{#s623,#s622} where #s622<:AbstractArray{T,2} where #s623, x::AbstractArray{S,1}) where {T, S} in LinearAlgebra at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/matmul.jl:84
Possible fix, define
  *(::Transpose{#s623,#s622} where #s622<:AbstractArray{T,2} where #s623, ::TrackedArray{S,1,A} where A)
(::getfield(Tracker, Symbol("##509#510")){Array{Float64,2},TrackedArray{,Array{Float32,1}}})(::TrackedArray{…,Array{Float64,1}}) at array.jl:391
back_(::Tracker.Grads, ::Tracker.Call{getfield(Tracker, Symbol("##509#510")){Array{Float64,2},TrackedArray{,Array{Float32,1}}},Tuple{Nothing,Tracker.Tracked{Array{Float32,1}}}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:110
back(::Tracker.Grads, ::Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:125
foreach at back.jl:113 [inlined]
back_(::Tracker.Grads, ::Tracker.Call{getfield(Tracker, Symbol("#back#548")){1,typeof(sin),Tuple{TrackedArray{,Array{Float64,1}}}},Tuple{Tracker.Tracked{Array{Float64,1}}}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:113
back(::Tracker.Grads, ::Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:125
(::getfield(Tracker, Symbol("##16#17")){Tracker.Grads})(::Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:113
foreach(::Function, ::Tuple{Tracker.Tracked{Array{Float64,2}},Tracker.Tracked{Array{Float64,1}}}, ::Tuple{TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}}) at abstractarray.jl:1867
back_(::Tracker.Grads, ::Tracker.Call{getfield(Tracker, Symbol("##509#510")){TrackedArray{,Array{Float64,2}},TrackedArray{,Array{Float64,1}}},Tuple{Tracker.Tracked{Array{Float64,2}},Tracker.Tracked{Array{Float64,1}}}}, ::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}) at back.jl:113
back(::Tracker.Grads, ::Tracker.Tracked{Array{Float64,1}}, ::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}) at back.jl:125
#18 at back.jl:140 [inlined]
(::getfield(Tracker, Symbol("##21#23")){getfield(Tracker, Symbol("##18#19")){Tracker.Params,TrackedArray{,Array{Float64,1}}}})(::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}) at back.jl:149
(::DiffEqSensitivity.ODEAdjointSensitivityFunction{Array{Float32,1},Array{Float32,1},Array{Float32,1},ODEFunction{false,getfield(DiffEqFlux, Symbol("#dudt_#30")){SineLayer{Array{Float64,2},TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}},Nothing,UniformScaling{Bool},Nothing,Nothing,Nothing,Array{Float32,1},ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float64,1},ODEFunction{false,getfield(DiffEqFlux, Symbol("#dudt_#30")){SineLayer{Array{Float64,2},TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,getfield(DiffEqFlux, Symbol("#dudt_#30")){SineLayer{Array{Float64,2},TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing})(::Array{Float32,1}, ::Array{Float32,1}, ::Array{Float64,1}, ::Float32) at adjoint_sensitivity.jl:136
ODEFunction at diffeqfunction.jl:193 [inlined]
initialize!(::OrdinaryDiffEq.ODEIntegrator{Tsit5,true,Array{Float32,1},Float32,Array{Float64,1},Float32,Float32,Float32,Array{Array{Float32,1},1},ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float64,1},ODEFunction{true,DiffEqSensitivity.ODEAdjointSensitivityFunction{Array{Float32,1},Array{Float32,1},Array{Float32,1},ODEFunction{false,getfield(DiffEqFlux, Symbol("#dudt_#30")){SineLayer{Array{Float64,2},TrackedArray{,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}},Nothing,UniformScaling{Bool},Nothing,Nothing,Nothing,Array{Float32,1},ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,...

Just using Flux seems to work, though:

lossf2() = norm(model(init) .-2.0)
lossf2()
Flux.train!(lossf2, params(model),Iterators.repeated((), 1), opt)

Likewise, a normal Dense layer does not throw any errors.

My current working environment (Julia 1.1.0) includes:

    Status `~/Documents/code/julia/Test/Project.toml`
  [aae7a2af] + DiffEqFlux v0.5.0
  [0c46a032] + DifferentialEquations v6.4.0
  [634d3b9d] + DrWatson v0.5.3
  [7c1d4256] + DynamicPolynomials v0.3.0
  [61744808] + DynamicalSystems v1.3.0
  [587475ba] + Flux v0.8.3
  [a98d9a8b] + Interpolations v0.12.2
  [c8e1da08] + IterTools v1.1.1
  [d96e819e] + Parameters v0.10.3
  [91a5bcdd] + Plots v0.25.2
    Status `~/Documents/code/julia/Test/Manifest.toml`
  [49dc2e85] + Calculus v0.4.1
  [717857b8] + DSP v0.5.2
  [aae7a2af] + DiffEqFlux v0.5.0
  [0c46a032] + DifferentialEquations v6.4.0
  [634d3b9d] + DrWatson v0.5.3
  [7c1d4256] + DynamicPolynomials v0.3.0
  [61744808] + DynamicalSystems v1.3.0
  [587475ba] + Flux v0.8.3
  [a98d9a8b] + Interpolations v0.12.2
  [c8e1da08] + IterTools v1.1.1
  [d96e819e] + Parameters v0.10.3
  [91a5bcdd] + Plots v0.25.2

Any help is appreciated!

ERROR: LoadError: MethodError: objects of type Tracker.Params are not callable - partial_neural.jl

Hello,
when running partial_neural.jl I found the following issue:

66.70833f0 (tracked)
ERROR: LoadError: MethodError: objects of type Tracker.Params are not callable
Stacktrace:
[1] top-level scope at C:\Users\frank\Documents\GitHub\DiffEqFlux.jl\test\partial_neural.jl:38
in expression starting at C:\Users\frank\Documents\GitHub\DiffEqFlux.jl\test\partial_neural.jl:38

I am not sure why it happens.
Thank you

Unable to Take Gradients with Batch Dimension

Here is a 1D problem where the model is trying to learn the function f(u)=u.^3. The data is a (1,200)-dimensional array, where 200 is the batch size.

using Flux, DiffEqFlux, OrdinaryDiffEq, StatsBase, RecursiveArrayTools
using Distributions

const tspan = (0.0f0,25f0)
const RANGE = (Float32(-8.),Float32(8.))
const BS = 200

target(u) = u.^3

function gen_data(batchsize,target_fun)
    x = Float32.(rand(Uniform(RANGE...,),batchsize))'|>collect
    return x,target(x)
end

dudt = Chain(Dense(1,20,tanh),Dense(20,20,tanh),Dense(20,1))
function n_ode(u0)
    neural_ode(dudt,u0,tspan,Tsit5(),save_start=false,saveat=[tspan[2]],reltol=1e-5,abstol=1e-12)
end
ps = Flux.params(dudt)

loss(x,y) = mean(abs.(n_ode(x)-y))

data = Iterators.repeated(gen_data(BS,target), 10000) 
opt = ADAM(0.001)
cb = function () #callback function to observe training
    tx,ty = gen_data(BS,target)
    display(loss(tx,ty))
end

# Display the ODE with the initial parameter values.
cb()

Flux.train!(loss, ps, data, opt, cb = cb)

The cb() call shows that the forward pass works; however, the reverse pass returns the error:

ERROR: MethodError: no method matching DiffEqSensitivity.ODEAdjointSensitivityFunction
... <- This method error is huge
You might have used a 2d row vector where a 1d column vector was required.
Note the difference between 1d column vector [1,2,3] and 2d row vector [1 2 3].
You can convert to a column vector with the vec() function.

This is the standard way Flux expects batched data.
For example, if the model were just the dudt chain by itself, not integrated with a solver, i.e.
loss(x,y) = mean(abs.(dudt(x)-y)), this works and trains fine.

Additionally, if I try neural_ode_rd instead, even the forward pass (cb()) won't work, and it returns the error:

ERROR: MethodError: no method matching Array{Float32,1}(::Array{Float32,2})

You might have used a 2d row vector where a 1d column vector was required.
Note the difference between 1d column vector [1,2,3] and 2d row vector [1 2 3].
You can convert to a column vector with the vec() function.

Using reverse-mode autodiff with a second order ODE problem

So I was trying reverse-mode autodiff with a second-order ODE problem. Here's what I have so far:

using DifferentialEquations
using Flux
using DiffEqFlux

u0 = Float32[0.; 2.]
du0 = Float32[0.; 0.]
tspan = (0.0f0, 1.0f0)
t = range(tspan[1], tspan[2], length=20)
p = param(Float32[])

model = Chain(Dense(2, 50, tanh), Dense(50, 2))

function f(du::TrackedArray, u::TrackedArray, p, t)
    model(u)
end
function f(du::AbstractArray, u::AbstractArray, p, t)
    Flux.data(model(u))
end

prob = SecondOrderODEProblem{false}(f, du0, u0, tspan, p)

function predict_rd()
    Flux.Tracker.collect(diffeq_rd(p, prob, Tsit5(), u0=(param(du0), param(u0)), saveat=t))
end

correct_pos = Float32.(transpose(hcat(collect(0:0.05:1)[2:end], collect(2:-0.05:1)[2:end])))

loss_n_ode() = sum(abs2, correct_pos .- predict_rd()[1:2, :])

data = Iterators.repeated((), 1000)
opt = ADAM(0.1)

cb = function ()
    println(loss_n_ode())
    loss_n_ode() < 0.01 && Flux.stop()
end

Flux.train!(loss_n_ode, params(model), data, opt, cb=cb)

This fails to run, and the problem seems to be the u0=(param(du0), param(u0)) argument to diffeq_rd. If I omit this argument or use u0=(du0, u0), then the initial conditions of the ODE problem, i.e. the untracked versions, are used. This means the ODE can be solved, but the training doesn't change the neural network parameters.
Any ideas how I can get around this, or am I doing something else wrong?

MethodError: no method matching Float64(::Tracker.TrackedReal{Float64}) when using Array instead of Vector as parameters

Hello,
first of all, I am not sure whether I am posting this in the right place.

I wanted to try using DiffEqFlux.jl to train an ODE system similar to the 1st example
but with an outside effect (something like Robin's BC).
But when I make p an Array instead of a Vector (as is done in the example)
I get an error. I was not able to find out where it comes from.
What am I missing here? Is this expected behavior?
Do I need to manually map the indexes in the parameter Vector inside the model function?
Are the destructure/restructure tools somehow supposed to deal with this?

Thanks for any reply.

# I need to build and train model that looks something like this:
#
# function model(du,u,p,t)
#   du = Array_of_parameters*u + static_coefficient.*vector_of_parameters.*(u.-u0(t))
# end
#
# But when I tweak the 1st DiffEqFlux.jl example to work with Array 
# instead of Vector I get error

using DifferentialEquations

# original example
#p = [1.5,1.0,3.0,1.0]
#function lotka_volterra(du,u,p,t)
#  x, y = u
#  α, β, δ, γ = p
#  du[1] = dx = α*x - β*x*y
#  du[2] = dy = -δ*y + γ*x*y
#end

# this still works
#p = [1.5,1.0,3.0,1.0]
#function lotka_volterra(du,u,p,t)
#  du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
#  du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]
#end

# this version throws > no method matching Float64(::Tracker.TrackedReal{Float64})
p = [1.5 1.0;3.0 1.0]
function lotka_volterra(du,u,p,t)
  du[1] = p[1,1]*u[1] - p[1,2]*u[1]*u[2]
  du[2] = -p[2,1]*u[2] + p[2,2]*u[1]*u[2]
end

u0 = [1.0,1.0]
tspan = (0.0,10.0)

prob = ODEProblem(lotka_volterra,u0,tspan,p)
sol = solve(prob,Tsit5())
using Plots
plot(sol)

using Flux, DiffEqFlux
#p = param([2.2,1.0,2.0,0.4]) # Original Initial Parameter Vector
p = param([2.2 1.0;2.0 0.4]) # Tweaked Initial Parameter Array
params = Flux.Params([p])

function predict_adjoint() # Our 1-layer neural network
  diffeq_adjoint(p,prob,Tsit5(),saveat=0.0:0.1:10.0)
end

loss_adjoint() = sum(abs2,x-1 for x in predict_adjoint())

data = Iterators.repeated((), 100)
opt = ADAM(0.1)
cb = function () #callback function to observe training
  display(loss_adjoint())
  # using `remake` to re-create our `prob` with current parameters `p`
  # display(plot(solve(remake(prob,p=Flux.data(p)),Tsit5(),saveat=0.0:0.1:10.0),ylim=(0,6)))
end

predict_adjoint() # <= MethodError 

# Display the ODE with the initial parameter values.
cb()
Flux.train!(loss_adjoint, params, data, opt, cb = cb)

MethodError: no method matching Float64(::Tracker.TrackedReal{Float64})
Closest candidates are:
Float64(::Real, !Matched::RoundingMode) where T<:AbstractFloat at rounding.jl:194
Float64(::T<:Number) where T<:Number at boot.jl:741
Float64(!Matched::Int8) at float.jl:60
...
convert(::Type{Float64}, ::Tracker.TrackedReal{Float64}) at number.jl:7
setindex!(::Array{Float64,1}, ::Tracker.TrackedReal{Float64}, ::Int64) at array.jl:767
lotka_volterra(::Array{Float64,1}, ::Array{Float64,1}, ::TrackedArray{…,Array{Float64,2}}, ::Float64) at NeuralStuff_tweaked_example.jl:23
(::ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing})(::Array{Float64,1}, ::Array{Float64,1}, ::Vararg{Any,N} where N) at diffeqfunction.jl:230
initialize!(::OrdinaryDiffEq.ODEIntegrator{Tsit5,true,Array{Float64,1},Float64,TrackedArray{…,Array{Float64,2}},Float64,Float64,Float64,Array{Array{Float64,1},1},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,TrackedArray{…,Array{Float64,2}},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),CallbackSet{Tuple{},Tuple{}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,DataStructures.LessThan},DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Nothing,Nothing,Int64,Array{Float64,1},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float64,1}},Array{Float64,1},Float64,Nothing}, ::OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}) at low_order_rk_perform_step.jl:623
#__init#335(::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Array{Float64,1}, ::Array{Float64,1}, ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::Nothing, ::Bool, ::Bool, ::Float64, ::Float64, ::Float64, ::Bool, ::Bool, ::Rational{Int64}, ::Nothing, ::Nothing, ::Rational{Int64}, ::Int64, ::Int64, ::Int64, ::Rational{Int64}, ::Bool, ::Int64, ::Nothing, ::Nothing, ::Int64, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(LinearAlgebra.opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Nothing, ::Bool, ::Bool, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(DiffEqBase.__init), ::ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,TrackedArray{…,Array{Float64,2}},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem}, ::Tsit5, ::Array{Array{Float64,1},1}, ::Array{Float64,1}, ::Array{Any,1}, ::Type{Val{true}}) at solve.jl:356
(::getfield(DiffEqBase, Symbol("#kw##__init")))(::NamedTuple{(:saveat,),Tuple{StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}}}, ::typeof(DiffEqBase.__init), ::ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,TrackedArray{…,Array{Float64,2}},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem}, ::Tsit5, ::Array{Array{Float64,1},1}, ::Array{Float64,1}, ::Array{Any,1}, ::Type{Val{true}}) at none:0
#__init at none:0 [inlined]
#__init at none:0 [inlined]
#__init at none:0 [inlined]
#__solve#334 at solve.jl:4 [inlined]
#__solve at none:0 [inlined]
#solve_call#435(::Base.Iterators.Pairs{Symbol,StepR...
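For reference, the workaround I would try next (a sketch, unverified): keep p as a flat Vector for tracking, as in the working example, and reshape it into a matrix inside the model function:

p = param([2.2, 1.0, 2.0, 0.4])        # flat tracked vector instead of a tracked matrix

function lotka_volterra(du, u, p, t)
    P = reshape(p, 2, 2)               # view the flat parameters as a 2×2 array
    du[1] =  P[1,1]*u[1] - P[1,2]*u[1]*u[2]
    du[2] = -P[2,1]*u[2] + P[2,2]*u[1]*u[2]
end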

HamiltonianProblem not supported

Hi, thanks for the awesome package! Unfortunately I'm having issues using it for my work. I basically want to integrate a set of equations of motion defined by a Hamiltonian H(p, x) that is too complex to differentiate by hand. The code below sets up the problem:

using DifferentialEquations
using DiffEqFlux
using Flux

function H(p, x, params)  # complicated (randomly selected) Hamiltonian
    return params[1] * (1 + x[1]^2)^(1/3) * (1 + p[1]^2)^(1/4)
end

function getprob(params)  # set up problem
    tspan = (0.0, 1.0)
    x0, p0 = [0.01], [0.05]
    return HamiltonianProblem(H, p0, x0, tspan, params)
end

Now I want to differentiate solutions of the equations of motion with respect to the parameter vector. The function below sets this up and tries to run the ODE solver:

function testsolve(p0=[0.1])
    prob = getprob(p0)  # set up problem

    p = param(p0)  # initialize parameter

    function predict_rd()
        Tracker.collect(diffeq_rd(p, prob, Tsit5(), saveat=0.1))
    end
    
    println("test: ", predict_rd())  # try running the diff eq solver
end

When I run testsolve() I get the error

ERROR: MethodError: *(::Tracker.TrackedReal{Float64}, ::ForwardDiff.Dual{ForwardDiff.Tag{DiffEqPhysics.PhysicsTag,Float64},Float64,1}) is ambiguous. Candidates:
  *(a::Tracker.TrackedReal, b::Real) in Tracker at /Users/acoogan/.julia/packages/Tracker/RRYy6/src/lib/real.jl:94
  *(x::Real, y::ForwardDiff.Dual{Ty,V,N} where N where V) where Ty in ForwardDiff at /Users/acoogan/.julia/packages/ForwardDiff/N0wMF/src/dual.jl:140
Possible fix, define
  *(::Tracker.TrackedReal, ::ForwardDiff.Dual{Ty,V,N} where N where V)

followed by a massive stack trace.

It seems like there's a conflict between the datatypes used to define the equations of motion and the parameters tracked by Flux. Is there a way to fix this?

I was not able to get this to work using Zygote rather than ForwardDiff to differentiate H -- hopefully I'm missing something about how to use nested autodifferentiation?

regressing 1-x/x0

Hi Chris,

nice package and thanks for the swift reply last week!
I have modified the neural ODE example to regress a simple affine function y = 1-x/x_0 (here x_0 = datasize-1). I found that the neural ODE implementation is unable to regress this very simple example. I was wondering what might be the problem here.

Thanks a lot,
Rand

using Flux, DiffEqFlux, DifferentialEquations, Plots
using DataFrames, CSV

datasize = 10
Q = range(1.0f0, stop = 0.0f0, length = datasize) |> collect

u0 = Q[1,:]
tspan = (0.0f0,convert(Float32, datasize-1))
t = range(tspan[1],tspan[2],length=datasize)

dudt = Chain(x ->x, Dense(1,10,tanh),
             Dense(10,1))
ps = Flux.params(dudt)
n_ode = x->neural_ode(dudt,x,tspan,Tsit5(),saveat=t,reltol=1e-5,abstol=1e-7)

pred = n_ode(u0) # Get the prediction using the correct initial condition
scatter(t,Q,label="data")
scatter!(t,Flux.data(pred[1,:]),label="prediction")

function predict_n_ode()
  n_ode(u0)
end
loss_n_ode() = sum(abs2,Q.- predict_n_ode())

data = Iterators.repeated((), 2000)
opt = ADAM(0.1)

cb = function () #callback function to observe training
  display(loss_n_ode())
  # plot current prediction against data
  cur_pred = Flux.data(predict_n_ode())
  pl = scatter(t,Q,label="data")
  scatter!(pl,t,cur_pred[1,:],label="prediction")
  display(plot(pl))
end

# Display the ODE with the initial parameter values.
cb()

Flux.train!(loss_n_ode, ps, data, opt, cb = cb)
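One thing worth double-checking in the snippet above (an observation rather than a confirmed diagnosis): Q is a length-10 vector while predict_n_ode() returns a 1×10 matrix, so Q .- predict_n_ode() broadcasts to a 10×10 matrix and the loss no longer compares the trajectory pointwise. A minimal sketch of a shape-consistent loss (Q_row is an illustrative name):

# Reshape the data to the same 1 × datasize layout as the prediction so the loss
# compares point-to-point instead of broadcasting to a 10×10 matrix.
Q_row = reshape(Q, 1, :)
loss_n_ode() = sum(abs2, Q_row .- predict_n_ode())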

Derivative w.r.t initial conditions with saveat for end point

Thanks a lot for your work! I wonder if it could support taking derivatives w.r.t. the initial conditions, such as u0 = param([1.0, 1.0]), rather than just p = param([1.5, 1.0, 3.0, 1.0]).

Today I noticed that the code in README.md cannot run properly because of some updates to upstream packages, so I cannot put together an example of taking the derivative w.r.t. the initial conditions right now. Last time I tried it, it threw an error.

Mixing neural and normal ODEs

Hi! Firstly, thanks for a wonderful package, I expect to see a lot of applications come out of this!

My problem refers to #15. The solution that was posted at the end works. However, if I try to extend the idea to a higher-dimensional ODE, this doesn't work. Here's an example:

using DiffEqFlux, Flux,  DifferentialEquations

# Generate ODE---------------------------------------------------
par = Float32[10.0, 28.0,8/3]
tspan = (0.0,10.0)
Kgain = Chain(Dense(2,10,tanh),
              Dense(10,10,tanh),
              Dense(10,2))

function dudt_(u::TrackedArray,p,tt)
  nn = Kgain(u[1:2])
  Flux.Tracker.collect([nn;[u[1]*u[2] - p[3]*u[3]]])
end
function dudt_(u::AbstractArray,p,tt)
  nn = Flux.data(Kgain(u[1:2]))
  collect([nn;[u[1]*u[2] - p[3]*u[3]]])
end

pr = param(par)
x0 = Float32[0.1;0.1;0.1]
prob = ODEProblem(dudt_,x0,tspan,pr)
diffeq_rd(pr,prob,Tsit5())

This gives the following error:
MethodError: vcat(::TrackedArray{…,Array{Float32,1}}, ::Array{Tracker.TrackedReal{Float32},1}) is ambiguous. Candidates:
  vcat(V::AbstractArray{T,1} where T...) in Base at abstractarray.jl:1233
  vcat(A::Union{AbstractArray{T,1}, AbstractArray{T,2}} where T...) in Base at abstractarray.jl:1296
  vcat(A::Union{UniformScaling, Union{AbstractArray{T,1}, AbstractArray{T,2}} where T}...) in LinearAlgebra at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/uniformscaling.jl:258
  vcat(A::AbstractArray, B::AbstractArray) in Base at abstractarray.jl:1517
  vcat(x::Union{TrackedArray, TrackedReal}, xs::Union{Number, AbstractArray}...) in Tracker at /home/manu/.julia/packages/Tracker/6wcYJ/src/lib/array.jl:167
Possible fix, define
  vcat(::TrackedArray{T,1,A} where A<:AbstractArray{T,1} where T, ::AbstractArray{T,1} where T)

Any help is appreciated! Thanks in advance :)
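One hedged sketch that sidesteps the vcat ambiguity (an illustration, not a canonical fix, and whether Tracker differentiates through it cleanly in the reported package versions is not verified here): assemble the 3-element derivative elementwise with a comprehension and let Flux.Tracker.collect rebuild a tracked array, so vcat between a TrackedArray and an Array of TrackedReal never happens.

function dudt_(u::TrackedArray, p, tt)
  nn = Kgain(u[1:2])
  third = u[1]*u[2] - p[3]*u[3]
  # Build the output elementwise instead of vcat'ing mixed tracked array types.
  Flux.Tracker.collect([i <= 2 ? nn[i] : third for i in 1:3])
end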

Derivatives of the neural_ode

I accidentally stopped the tracking in the neural_ode layer and of course it can't backprop through since what comes out is just a float (@MikeInnes told me this would happen). So an MWE of trying to train with the neural_ode is:

using OrdinaryDiffEq, DiffEqFlux, Flux, Plots
function lotka_volterra(du,u,p,t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end
u0 = [1.0,1.0]
tspan = (0.0,10.0)
p = [1.5,1.0,3.0,1.0]
prob = ODEProblem(lotka_volterra,u0,tspan,p)
ode_data = Array(solve(prob,Tsit5(),saveat=0.1))

dudt = Chain(Dense(2,50,tanh),Dense(50,2))
tspan = (0.0f0,10.0f0)
n_ode = x->neural_ode(dudt,x,tspan,Tsit5(),saveat=0.1)
pred = n_ode(u0)

scatter(0.0:0.1:10.0,ode_data[1,:],label="data")
scatter!(0.0:0.1:10.0,pred[1,:],label="prediction")

function predict_n_ode()
  n_ode(u0)
end

loss_n_ode() = sum(abs2,ode_data .- predict_n_ode())

data = Iterators.repeated((), 100)
opt = ADAM(0.1)
cb = function () #callback function to observe training
  display(loss_n_ode())
  # plot current prediction against data
  cur_pred = predict_n_ode()
  pl = scatter(0.0:0.1:10.0,ode_data[1,:],label="data")
  scatter!(pl,0.0:0.1:10.0,cur_pred[1,:],label="prediction")
  plot(pl)
end

# Display the ODE with the initial parameter values.
cb()

Flux.train!(loss_n_ode, Flux.params(dudt), data, opt, cb = cb)
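
Before (or instead of) the Flux.train! call above, a minimal sanity check for the "stopped tracking" symptom described in this issue is (a sketch):

# If the layer drops tracking, the prediction comes back as a plain Array and the
# loss has no gradient path for Flux.train! to follow.
pred = predict_n_ode()
@show typeof(pred)   # expect a TrackedArray, not a plain Array{Float64,2}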

Type issue when optimizing over initial conditions

Hi! First of all, thank you very much for this package! I've been working on the connection between dynamical systems and machine learning for a couple of years, and this package will be extremely helpful for my research! Actually, what I've been doing is quite similar to your model-zoo example ode.jl, but with a more complex system.

I've been playing with your example ode.jl and found that, by simply including the initial conditions as parameters for Flux to optimize, an error about type conversion arises which I couldn't fix:

using Flux, DiffEqFlux, DifferentialEquations, Plots

## Setup ODE to optimize
function lotka_volterra(du,u,p,t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end

p = [1.5,1.0,3.0,1.0,1.0,1.0]
u0_f(p,t0) = [p[5],p[6]]
tspan = (0.0,10.0)

prob = ODEProblem(lotka_volterra,u0_f,tspan,p)

# Verify ODE solution
sol = solve(prob,Tsit5())
plot(sol)

# Generate data from the ODE
sol = solve(prob,Tsit5(),saveat=0.1)
A = sol[1,:] # length 101 vector
t = 0:0.1:10.0
scatter!(t,A)

# Build a neural network that sets the cost as the difference from the
# generated data and 1

p = param([2.2, 1.0, 2.0, 0.4,1.0,1.5]) # Initial Parameter Vector
function predict_rd() # Our 1-layer neural network
  diffeq_rd(p,prob,Tsit5(),saveat=0.1)[1,:]
end
loss_rd() = sum(abs2,x-1 for x in predict_rd()) # loss function

# Optimize the parameters so the ODE's solution stays near 1

data = Iterators.repeated((), 100)
opt = ADAM(0.1)
cb = function () #callback function to observe training
  display(loss_rd())
  # using `remake` to re-create our `prob` with current parameters `p`
  display(plot(solve(remake(prob,p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6)))
end
# Display the ODE with the initial parameter values.
cb()
Flux.train!(loss_rd, [p], data, opt, cb = cb)

Gives the error:

LoadError: MethodError: Cannot `convert` an object of type typeof(u0_f) to an object of type Flux.Tracker.TrackedReal{Float64}
Closest candidates are:
  convert(::Type{Flux.Tracker.TrackedReal{T}}, !Matched::Flux.Tracker.TrackedReal{T}) where T at /Users/ger/.julia/packages/Flux/8XpDt/src/tracker/lib/real.jl:35
  convert(::Type{Flux.Tracker.TrackedReal{T}}, !Matched::Flux.Tracker.TrackedReal{S}) where {T, S} at /Users/ger/.julia/packages/Flux/8XpDt/src/tracker/lib/real.jl:39
  convert(::Type{Flux.Tracker.TrackedReal{T}}, !Matched::Real) where T at /Users/ger/.julia/packages/Flux/8XpDt/src/tracker/lib/real.jl:37
  ...
in expression starting at /Users/ger/Documents/DiffEqFlux/model-zoo/ode_testing2_for_issues.jl:50
_broadcast_getindex_evalf at broadcast.jl:578 [inlined]
_broadcast_getindex at broadcast.jl:561 [inlined]
getindex at broadcast.jl:511 [inlined]
copy at broadcast.jl:763 [inlined]
materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(convert),Tuple{Base.RefValue{Type{Flux.Tracker.TrackedReal{Float64}}},Base.RefValue{typeof(u0_f)}}}) at broadcast.jl:753
#diffeq_rd#1(::Function, ::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}}, ::Function, ::TrackedArray{…,Array{Float64,1}}, ::ODEProblem{typeof(u0_f),Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::Tsit5) at layers.jl:8
(::getfield(DiffEqFlux, Symbol("#kw##diffeq_rd")))(::NamedTuple{(:saveat,),Tuple{Float64}}, ::typeof(diffeq_rd), ::TrackedArray{…,Array{Float64,1}}, ::ODEProblem{typeof(u0_f),Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::Tsit5) at none:0
predict_rd() at ode_testing2_for_issues.jl:36
loss_rd() at ode_testing2_for_issues.jl:38
(::getfield(Main, Symbol("##45#46")))() at ode_testing2_for_issues.jl:45
top-level scope at none:0

I also found some other issues regarding solution instabilities while optimizing with Flux, which also raise errors, but I guess I should report those in a separate issue.

Thanks!
Germán
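
For reference, the conversion error above goes away if the problem is built with a plain numeric u0 instead of the u0_f function. This is a sketch of that narrower setup: it only restores gradients with respect to the rate parameters, so the initial conditions are not trainable this way.

u0 = [1.0, 1.0]                                   # concrete initial condition
prob = ODEProblem(lotka_volterra, u0, tspan, [1.5, 1.0, 3.0, 1.0])
p = param([2.2, 1.0, 2.0, 0.4])                   # only the rate parameters are tracked
predict_rd() = diffeq_rd(p, prob, Tsit5(), saveat = 0.1)[1, :]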

Problems in "Training a Neural ODE" example

One small problem is that cur_pred inside the cb function is tracked and has to be wrapped in Flux.data. After doing this, calling cb() works fine and correctly shows the plot.
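
Concretely, the one-line fix described above is:

cur_pred = Flux.data(predict_n_ode())   # strip tracking before indexing and plotting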

A bigger problem is that the call to Flux.train! fails after a long wait with a very intimidating error message, included at the bottom.

This is on the CPU, with Julia 1.1, on macOS Mojave.

julia> Flux.train!(loss_n_ode, params, data, opt, cb = cb)

ERROR: BoundsError: attempt to access 0-element StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}} at index [0]
Stacktrace:
 [1] throw_boundserror(::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Tuple{Int64}) at ./abstractarray.jl:484
 [2] checkbounds at ./abstractarray.jl:449 [inlined]
 [3] getindex at ./range.jl:638 [inlined]
 [4] tstop_saveat_disc_handling at /Users/curry/.julia/packages/OrdinaryDiffEq/suHr2/src/solve.jl:383 [inlined]
 [5] #__init#458(::Int64, ::Float64, ::Array{Float32,1}, ::Array{Float32,1}, ::Nothing, ::Bool, ::Nothing, ::Bool, ::Bool, ::Bool, ::Nothing, ::Bool, ::Bool, ::Float32, ::Bool, ::Rational{Int64}, ::Float64, ::Float64, ::Int64, ::Rational{Int64}, ::Int64, ::Int64, ::Rational{Int64}, ::Bool, ::Int64, ::Nothing, ::Nothing, ::Int64, ::Float32, ::Float32, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(LinearAlgebra.opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Nothing, ::Bool, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(DiffEqBase.__init), ::ODEProblem{Array{Float64,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,DiffEqSensitivity.ODEAdjointSensitivityFunction{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,Nothing,DiffEqDiffTools.UJacobianWrapper{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Float32,Array{Float32,1}},Nothing,DiffEqSensitivity.UGradientWrapper{getfield(DiffEqFlux, Symbol("#df#21")),Float32,Array{Float32,1}},Nothing,Nothing,SensitivityAlg{0,true,Val{:central}},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float32,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float32}}}},Nothing,Float64,LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Array{Float32,1},Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},CallbackSet{Tuple{},Tuple{DiscreteCallback{getfield(DiffEqCallbacks, Symbol("##31#34")){Base.RefValue{Union{Nothing, Float32}}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float32,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, 
Symbol("##19#21")){getfield(DiffEqFlux, Symbol("#df#21")),Bool,Array{Float32,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},getfield(DiffEqCallbacks, Symbol("##33#36")){typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float32,1},Base.RefValue{Int64}},Base.RefValue{Union{Nothing, Float32}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float32,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, Symbol("#df#21")),Bool,Array{Float32,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at /Users/curry/.julia/packages/OrdinaryDiffEq/suHr2/src/solve.jl:150
 [6] (::getfield(DiffEqBase, Symbol("#kw##__init")))(::NamedTuple{(:abstol, :reltol, :save_everystep, :saveat),Tuple{Float64,Float64,Bool,Float64}}, ::typeof(DiffEqBase.__init), ::ODEProblem{Array{Float64,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,DiffEqSensitivity.ODEAdjointSensitivityFunction{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,Nothing,DiffEqDiffTools.UJacobianWrapper{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Float32,Array{Float32,1}},Nothing,DiffEqSensitivity.UGradientWrapper{getfield(DiffEqFlux, Symbol("#df#21")),Float32,Array{Float32,1}},Nothing,Nothing,SensitivityAlg{0,true,Val{:central}},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float32,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float32}}}},Nothing,Float64,LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Array{Float32,1},Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},CallbackSet{Tuple{},Tuple{DiscreteCallback{getfield(DiffEqCallbacks, Symbol("##31#34")){Base.RefValue{Union{Nothing, Float32}}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float32,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, Symbol("#df#21")),Bool,Array{Float32,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},getfield(DiffEqCallbacks, Symbol("##33#36")){typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float32,1},Base.RefValue{Int64}},Base.RefValue{Union{Nothing, Float32}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float32,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, 
Symbol("#df#21")),Bool,Array{Float32,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at ./none:0
 [7] #__solve#457(::Base.Iterators.Pairs{Symbol,Real,NTuple{4,Symbol},NamedTuple{(:abstol, :reltol, :save_everystep, :saveat),Tuple{Float64,Float64,Bool,Float64}}}, ::Function, ::ODEProblem{Array{Float64,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,DiffEqSensitivity.ODEAdjointSensitivityFunction{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,Nothing,DiffEqDiffTools.UJacobianWrapper{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Float32,Array{Float32,1}},Nothing,DiffEqSensitivity.UGradientWrapper{getfield(DiffEqFlux, Symbol("#df#21")),Float32,Array{Float32,1}},Nothing,Nothing,SensitivityAlg{0,true,Val{:central}},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float32,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float32}}}},Nothing,Float64,LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Array{Float32,1},Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},CallbackSet{Tuple{},Tuple{DiscreteCallback{getfield(DiffEqCallbacks, Symbol("##31#34")){Base.RefValue{Union{Nothing, Float32}}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float32,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, Symbol("#df#21")),Bool,Array{Float32,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},getfield(DiffEqCallbacks, Symbol("##33#36")){typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float32,1},Base.RefValue{Int64}},Base.RefValue{Union{Nothing, Float32}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float32,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, 
Symbol("#df#21")),Bool,Array{Float32,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at /Users/curry/.julia/packages/OrdinaryDiffEq/suHr2/src/solve.jl:6
 [8] #__solve at ./none:0 [inlined] (repeats 5 times)
 [9] #solve#442(::Base.Iterators.Pairs{Symbol,Real,NTuple{4,Symbol},NamedTuple{(:abstol, :reltol, :save_everystep, :saveat),Tuple{Float64,Float64,Bool,Float64}}}, ::Function, ::ODEProblem{Array{Float64,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,DiffEqSensitivity.ODEAdjointSensitivityFunction{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,Nothing,DiffEqDiffTools.UJacobianWrapper{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Float32,Array{Float32,1}},Nothing,DiffEqSensitivity.UGradientWrapper{getfield(DiffEqFlux, Symbol("#df#21")),Float32,Array{Float32,1}},Nothing,Nothing,SensitivityAlg{0,true,Val{:central}},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float32,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float32}}}},Nothing,Float64,LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Array{Float32,1},Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},CallbackSet{Tuple{},Tuple{DiscreteCallback{getfield(DiffEqCallbacks, Symbol("##31#34")){Base.RefValue{Union{Nothing, Float32}}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float32,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, Symbol("#df#21")),Bool,Array{Float32,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},getfield(DiffEqCallbacks, Symbol("##33#36")){typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float32,1},Base.RefValue{Int64}},Base.RefValue{Union{Nothing, Float32}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float32,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, 
Symbol("#df#21")),Bool,Array{Float32,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5) at /Users/curry/.julia/packages/DiffEqBase/8usQ9/src/solve.jl:39
 [10] (::getfield(DiffEqBase, Symbol("#kw##solve")))(::NamedTuple{(:abstol, :reltol, :save_everystep, :saveat),Tuple{Float64,Float64,Bool,Float64}}, ::typeof(solve), ::ODEProblem{Array{Float64,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,DiffEqSensitivity.ODEAdjointSensitivityFunction{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,Nothing,DiffEqDiffTools.UJacobianWrapper{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Float32,Array{Float32,1}},Nothing,DiffEqSensitivity.UGradientWrapper{getfield(DiffEqFlux, Symbol("#df#21")),Float32,Array{Float32,1}},Nothing,Nothing,SensitivityAlg{0,true,Val{:central}},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float32,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float32}}}},Nothing,Float64,LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Array{Float32,1},Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},CallbackSet{Tuple{},Tuple{DiscreteCallback{getfield(DiffEqCallbacks, Symbol("##31#34")){Base.RefValue{Union{Nothing, Float32}}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float32,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, Symbol("#df#21")),Bool,Array{Float32,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},getfield(DiffEqCallbacks, Symbol("##33#36")){typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float32,1},Base.RefValue{Int64}},Base.RefValue{Union{Nothing, Float32}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float32,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, 
Symbol("#df#21")),Bool,Array{Float32,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5) at ./none:0
 [11] #adjoint_sensitivities#24(::Float64, ::Float64, ::Float64, ::Float64, ::SensitivityAlg{0,true,Val{:central}}, ::Array{Float32,1}, ::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}}, ::Function, ::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float32,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float32}}}}, ::Tsit5, ::getfield(DiffEqFlux, Symbol("#df#21")), ::Array{Float32,1}, ::Nothing) at /Users/curry/.julia/packages/DiffEqSensitivity/wbHxJ/src/adjoint_sensitivity.jl:318
 [12] (::getfield(DiffEqSensitivity, Symbol("#kw##adjoint_sensitivities")))(::NamedTuple{(:sensealg, :saveat),Tuple{SensitivityAlg{0,true,Val{:central}},Float64}}, ::typeof(adjoint_sensitivities), ::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float32,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float32}}}}, ::Tsit5, ::Function, ::Array{Float32,1}, ::Nothing) at ./none:0 (repeats 2 times)
 [13] (::getfield(DiffEqFlux, Symbol("##18#20")){Bool,Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}},Tuple{Tsit5},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float32,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float32}}}}})(::Array{Float64,2}) at /Users/curry/.julia/packages/DiffEqFlux/FVg0B/src/Flux/layers.jl:52
 [14] back_(::Flux.Tracker.Call{getfield(DiffEqFlux, Symbol("##18#20")){Bool,Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}},Tuple{Tsit5},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float32,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float32}}}}},Tuple{Flux.Tracker.Tracked{Array{Float32,1}},Nothing,Nothing}}, ::Array{Float64,2}, ::Bool) at /Users/curry/.julia/packages/Flux/U8AZD/src/tracker/back.jl:23
 [15] back(::Flux.Tracker.Tracked{Array{Float64,2}}, ::Array{Float64,2}, ::Bool) at /Users/curry/.julia/packages/Flux/U8AZD/src/tracker/back.jl:46
 [16] #2 at /Users/curry/.julia/packages/Flux/U8AZD/src/tracker/back.jl:26 [inlined]
 [17] foreach at ./abstractarray.jl:1867 [inlined]
 [18] back_(::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("#back#491")){2,typeof(-),Tuple{Array{Float64,2},TrackedArray{…,Array{Float64,2}}}},Tuple{Nothing,Flux.Tracker.Tracked{Array{Float64,2}}}}, ::Array{Float64,2}, ::Bool) at /Users/curry/.julia/packages/Flux/U8AZD/src/tracker/back.jl:26
 [19] back(::Flux.Tracker.Tracked{Array{Float64,2}}, ::Array{Float64,2}, ::Bool) at /Users/curry/.julia/packages/Flux/U8AZD/src/tracker/back.jl:46
 [20] foreach at /Users/curry/.julia/packages/Flux/U8AZD/src/tracker/back.jl:26 [inlined]
 [21] back_(::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("#back#491")){1,typeof(abs2),Tuple{TrackedArray{…,Array{Float64,2}}}},Tuple{Flux.Tracker.Tracked{Array{Float64,2}}}}, ::Array{Float64,2}, ::Bool) at /Users/curry/.julia/packages/Flux/U8AZD/src/tracker/back.jl:26
 [22] back(::Flux.Tracker.Tracked{Array{Float64,2}}, ::Array{Float64,2}, ::Bool) at /Users/curry/.julia/packages/Flux/U8AZD/src/tracker/back.jl:46
 [23] foreach at /Users/curry/.julia/packages/Flux/U8AZD/src/tracker/back.jl:26 [inlined]
 [24] back_(::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##431#432")){TrackedArray{…,Array{Float64,2}}},Tuple{Flux.Tracker.Tracked{Array{Float64,2}}}}, ::Float64, ::Bool) at /Users/curry/.julia/packages/Flux/U8AZD/src/tracker/back.jl:26
 [25] back(::Flux.Tracker.Tracked{Float64}, ::Int64, ::Bool) at /Users/curry/.julia/packages/Flux/U8AZD/src/tracker/back.jl:46
 [26] back!(::Flux.Tracker.TrackedReal{Float64}) at /Users/curry/.julia/packages/Flux/U8AZD/src/tracker/back.jl:65
 [27] macro expansion at /Users/curry/.julia/packages/Flux/U8AZD/src/optimise/train.jl:21 [inlined]
 [28] macro expansion at /Users/curry/.julia/packages/Juno/nDCSn/src/progress.jl:124 [inlined]
 [29] #train!#5(::getfield(Main, Symbol("##11#12")), ::Function, ::Function, ::Function, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, ::ADAM) at /Users/curry/.julia/packages/Flux/U8AZD/src/optimise/train.jl:68
 [30] (::getfield(Flux.Optimise, Symbol("#kw##train!")))(::NamedTuple{(:cb,),Tuple{getfield(Main, Symbol("##11#12"))}}, ::typeof(Flux.Optimise.train!), ::Function, ::Function, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, ::ADAM) at ./none:0
 [31] top-level scope at none:0

Implicit differential equation solver with internal autodiff of the Jacobian

Hello,

I keep getting the "WARNING: Instability detected. Aborting" error when training. I think this might be because the differential equation I set up is stiff and I need an implicit solver for it. However, when I replace the default Tsit5() solver with Rosenbrock23(), I get a stack-overflow error. Could you help me with this?

Thanks!
Junteng

using DifferentialEquations
using Flux
using DiffEqFlux
using Plots

u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0,1.5f0)

function trueODEfunc(du,u,p,t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
t = range(tspan[1],tspan[2],length=datasize)
prob = ODEProblem(trueODEfunc,u0,tspan)
ode_data = Array(solve(prob,Tsit5(),saveat=t))

dudt = Chain(x -> x.^3,
             Dense(2,50,tanh),
             Dense(50,2))
n_ode(x) = neural_ode(dudt,x,tspan,Rosenbrock23(),saveat=t,reltol=1e-7,abstol=1e-9)

function predict_n_ode()
  n_ode(u0)
end
loss_n_ode() = sum(abs2,ode_data .- predict_n_ode())

data = Iterators.repeated((), 1000)
opt = ADAM(0.1)
cb = function () #callback function to observe training
  display(loss_n_ode())
  # plot current prediction against data
  cur_pred = Flux.data(predict_n_ode())
  pl = scatter(t,ode_data[1,:],label="data")
  scatter!(pl,t,cur_pred[1,:],label="prediction")
  display(plot(pl))
end

# Display the ODE with the initial parameter values.
cb()

ps = Flux.params(dudt)
Flux.train!(loss_n_ode, ps, data, opt, cb = cb)
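
One hedged thing to try with the snippet above (assuming the stack overflow comes from the implicit solver trying to autodiff its Jacobian through the tracked network parameters): ask Rosenbrock23 to build its Jacobian with finite differences instead of ForwardDiff via its autodiff keyword.

# Same layer definition as above, with the solver's internal autodiff turned off.
n_ode(x) = neural_ode(dudt, x, tspan, Rosenbrock23(autodiff=false),
                      saveat=t, reltol=1e-7, abstol=1e-9)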

BoundsError in README example

I have tried to run the example in the README, however diffeq_rd() throws a BoundsError.

This is happening with:
julia: v1.1.0
DiffEqFlux: v0.2.0
DifferentialEquations: v6.3.0
Flux: v0.7.3

The same error occurs when implementing the examples in the Flux model-zoo with these package versions.

Following the instructions on the model-zoo page, using the following package versions:

julia: v1.0.3
DiffEqFlux: v0.2.0

save_start=false argument ignored for neural_ode solve

Tested here: https://github.com/jessebett/DiffEqFlux.jl/blob/master/test/solver_options.jl

If you call neural_ode with save_start=false, it still returns the start. I believe this needed to be overridden internally to get the adjoint to work, but what is returned should respect this argument.

It is very common to run these solves while only requiring the final output to be saved. It would be nice if there were an option like save_only, which in this case would return only the end of the timespan.
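
As a hedged interim workaround (a sketch, not a replacement for fixing save_start; it assumes solver keywords are forwarded to solve as in the other examples here, and dudt/x are placeholders for the user's layer and input), the final state can simply be indexed out of what the layer returns:

final_state = neural_ode(dudt, x, tspan, Tsit5(), save_everystep=false)[:, end]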

Forward solve fails with LinearAlgebra.Adjoint initial values

Here's a quick toy example with a 1D state and a batch dimension.

using Flux
using DiffEqFlux
using MLDataUtils
using Distributions
using OrdinaryDiffEq

D = 1
N = 10000
BS = 100

train_x = rand(Uniform(0,5),(D,N))
train_y = train_x .^ 2
batches = batchview((train_x,train_y),BS)

nn = Chain(Dense(1,10,tanh),Dense(10,10,tanh),Dense(10,1))
tspan = (0.0f0,1.0f0)
model = x-> neural_ode(nn,x,tspan,Tsit5(),atol=1e-6,rtol=1e-6)[:,:,end]

I thought I could just create a single batch element by transposing a singleton array. However LinearAlgebra.Adjoint breaks the forward solve in neural_ode:

julia> x0 = rand(1,1)
1×1 Array{Float64,2}:
 0.49820519072539593

julia> size(x0)
(1, 1)

julia> model(x0)
Tracked 1×1 Array{Float64,2}:
 0.3671767104193268

julia> x1 = rand(1)'
1×1 LinearAlgebra.Adjoint{Float64,Array{Float64,1}}:
 0.2720790231482093

julia> size(x1)
(1, 1)

julia> model(x1)
ERROR: MethodError: Cannot `convert` an object of type Array{Float32,2} to an object of type LinearAlgebra.Adjoint{Float64,Array{Float64,1}}
Closest candidates are:
  convert(::Type{LinearAlgebra.Adjoint{T,S}}, ::LinearAlgebra.Adjoint) where {T, S} at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/adjtrans.jl:186
  convert(::Type{T}, ::T) where T<:AbstractArray at abstractarray.jl:14
  convert(::Type{T}, ::LinearAlgebra.Factorization) 

If I collect it first, it's fine:

julia> x2 = rand(1)'|>collect
1×1 Array{Float64,2}:
 0.22926764437395208

julia> size(x2)
(1, 1)

julia> model(x2)
Tracked 1×1 Array{Float64,2}:
 0.07379717434344342

This is surprising because LinearAlgebra.Adjoint plays well with normal solves:

# f below is the user's in-place ODE right-hand side; any f(du,u,p,t) works for this comparison
u0 = randn(1,1)
tspan = (0.0,1.0)
prob = ODEProblem(f,u0,tspan)
sol = solve(prob,Tsit5(),reltol=1e-8,abstol=1e-8)[:,end]
# 1×1 Array{Float64,2}:
#  1.211417535470834

u0 = randn(1)'
tspan = (0.0,1.0)
prob = ODEProblem(f,u0,tspan)
sol = solve(prob,Tsit5(),reltol=1e-8,abstol=1e-8)[:,end]
# 1×1 LinearAlgebra.Adjoint{Float64,Array{Float64,1}}:
#  -1.1428152078709033
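
Building on the observation above that collect-ing first works, a small hedged convenience wrapper (safe_model is an illustrative name, not part of the package):

using LinearAlgebra
safe_model(x) = model(x isa Adjoint ? collect(x) : x)   # densify transposed inputs before the layer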

Compilation problem

I just downloaded Julia Pro, which comes with Julia 1.0.5. I added DifferentialEquations, Flux, LinearAlgebra, Plots, and DiffEqFlux.

julia> Pkg.status()
Status ~/.juliapro/JuliaPro_v1.0.5-2/environments/v1.0/Project.toml
[c52e3926] Atom v0.11.2
[aae7a2af] DiffEqFlux v0.8.0 [~/.juliapro/JuliaPro_v1.0.5-2/dev/DiffEqFlux]
[0c46a032] DifferentialEquations v6.8.0
[587475ba] Flux v0.9.0
[7073ff75] IJulia v1.20.0
[e5e0dc1b] Juno v0.7.2
[91a5bcdd] Plots v0.27.0
[37e2e46d] LinearAlgebra

However when I do: using DiffEqFlux, I get:

Info: Precompiling DiffEqFlux [aae7a2af-3d4f-5e19-a356-7da93b79d9d0]
ERROR: LoadError: LoadError: syntax: invalid function name "Morris <: GSAMethod"
Stacktrace:
[1] include at ./boot.jl:317 [inlined]
[2] include_relative(::Module, ::String) at ./loading.jl:1044
[3] include at ./sysimg.jl:29 [inlined]
[4] include(::String) at /Users/lukesh/.juliapro/JuliaPro_v1.0.5-2/packages/DiffEqSensitivity/aR25O/src/DiffEqSensitivity.jl:3
[5] top-level scope at none:0
[6] include_relative(::Module, ::String) at /Applications/JuliaPro-1.0.5-2.app/Contents/Resources/julia/Contents/Resources/julia/lib/julia/sys.dylib:?
[7] include(::Module, ::String) at /Applications/JuliaPro-1.0.5-2.app/Contents/Resources/julia/Contents/Resources/julia/lib/julia/sys.dylib:?
[8] top-level scope at none:2
[9] eval at ./boot.jl:319 [inlined]
[10] eval(::Expr) at ./client.jl:393
[11] top-level scope at ./none:3
in expression starting at /Users/lukesh/.juliapro/JuliaPro_v1.0.5-2/packages/DiffEqSensitivity/aR25O/src/morris_sensitivity.jl:1
in expression starting at /Users/lukesh/.juliapro/JuliaPro_v1.0.5-2/packages/DiffEqSensitivity/aR25O/src/DiffEqSensitivity.jl:17
ERROR: LoadError: Failed to precompile DiffEqSensitivity [41bf760c-e81c-5289-8e54-58b1f1f8abe2] to /Users/lukesh/.juliapro/JuliaPro_v1.0.5-2/compiled/v1.0/DiffEqSensitivity/02xYn.ji.
Stacktrace:
[1] error(::String) at /Applications/JuliaPro-1.0.5-2.app/Contents/Resources/julia/Contents/Resources/julia/lib/julia/sys.dylib:?
[2] compilecache(::Base.PkgId, ::String) at /Applications/JuliaPro-1.0.5-2.app/Contents/Resources/julia/Contents/Resources/julia/lib/julia/sys.dylib:?
[3] _require(::Base.PkgId) at /Applications/JuliaPro-1.0.5-2.app/Contents/Resources/julia/Contents/Resources/julia/lib/julia/sys.dylib:?
[4] require(::Base.PkgId) at /Applications/JuliaPro-1.0.5-2.app/Contents/Resources/julia/Contents/Resources/julia/lib/julia/sys.dylib:? (repeats 2 times)
[5] include_relative(::Module, ::String) at /Applications/JuliaPro-1.0.5-2.app/Contents/Resources/julia/Contents/Resources/julia/lib/julia/sys.dylib:?
[6] include(::Module, ::String) at /Applications/JuliaPro-1.0.5-2.app/Contents/Resources/julia/Contents/Resources/julia/lib/julia/sys.dylib:?
[7] top-level scope at none:2
[8] eval at ./boot.jl:319 [inlined]
[9] eval(::Expr) at ./client.jl:393
[10] top-level scope at ./none:3
in expression starting at /Users/lukesh/.juliapro/JuliaPro_v1.0.5-2/dev/DiffEqFlux/src/DiffEqFlux.jl:3
ERROR: Failed to precompile DiffEqFlux [aae7a2af-3d4f-5e19-a356-7da93b79d9d0] to /Users/lukesh/.juliapro/JuliaPro_v1.0.5-2/compiled/v1.0/DiffEqFlux/BdO4p.ji.
Stacktrace:
[1] compilecache(::Base.PkgId, ::String) at ./loading.jl:1203
[2] _require(::Base.PkgId) at ./loading.jl:960
[3] require(::Base.PkgId) at ./loading.jl:858
[4] require(::Module, ::Symbol) at ./loading.jl:853

Trouble running the Neural ODE example on GPU

First off, awesome paper, thank you!

When I try to run the Neural ODE example on the CPU, it works great. However, when I try to switch it over to the GPU using the approach described in the paper, i.e. using CuArrays and x->neural_ode(gpu(dudt),gpu(x),tspan,BS3(),saveat=0.1), I see:

ERROR: LoadError: MethodError: Base._reshape(::CuArray{Float32,1}, ::Tuple{Int64}) is ambiguous. Candidates:
  _reshape(A::GPUArrays.GPUArray{T,N} where N, dims::Tuple{Vararg{Int64,N}} where N) where T in GPUArrays at /home/brookhart/.julia/packages/GPUArrays/t8tJB/src/abstractarray.jl:230
  _reshape(A::GPUArrays.GPUArray{T,1}, dims::Tuple{Integer}) where T in GPUArrays at /home/brookhart/.julia/packages/GPUArrays/t8tJB/src/abstractarray.jl:236
  _reshape(parent::CuArray, dims::Tuple{Vararg{Int64,N}} where N) in CuArrays at /home/brookhart/.julia/packages/CuArrays/PD3UJ/src/array.jl:106
  _reshape(v::AbstractArray{T,1} where T, dims::Tuple{Int64}) in Base at reshapedarray.jl:167
Possible fix, define
  _reshape(::CuArray{T,1}, ::Tuple{Int64})
Stacktrace:
 [1] reshape(::CuArray{Float32,1}, ::Tuple{Int64}) at ./reshapedarray.jl:112
 [2] (::getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}})(::TrackedArray{…,CuArray{Float32,1}}) at /home/brookhart/.julia/packages/DiffEqFlux/1w1tX/src/Flux/utils.jl:14
 [3] #mapleaves#37(::IdDict{Any,Any}, ::Function, ::getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}, ::TrackedArray{…,CuArray{Float32,1}}) at /home/brookhart/.julia/packages/Flux/8XpDt/src/treelike.jl:27
 [4] #mapleaves at ./none:0 [inlined]
 [5] #38 at /home/brookhart/.julia/packages/Flux/8XpDt/src/treelike.jl:27 [inlined]
 [6] _broadcast_getindex_evalf(::getfield(Flux, Symbol("##38#39")){IdDict{Any,Any},getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}}, ::TrackedArray{…,CuArray{Float32,1}}) at ./broadcast.jl:578
 [7] _broadcast_getindex at ./broadcast.jl:551 [inlined]
 [8] (::getfield(Base.Broadcast, Symbol("##19#20")){Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,getfield(Flux, Symbol("##38#39")){IdDict{Any,Any},getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}},Tuple{Tuple{TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}},typeof(tanh)}}}})(::Int64) at ./broadcast.jl:953
 [9] ntuple(::getfield(Base.Broadcast, Symbol("##19#20")){Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,getfield(Flux, Symbol("##38#39")){IdDict{Any,Any},getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}},Tuple{Tuple{TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}},typeof(tanh)}}}}, ::Val{3}) at ./tuple.jl:161
 [10] copy at ./broadcast.jl:953 [inlined]
 [11] materialize at ./broadcast.jl:753 [inlined]
 [12] mapchildren(::Function, ::Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}) at /home/brookhart/.julia/packages/Flux/8XpDt/src/treelike.jl:13
 [13] #mapleaves#37(::IdDict{Any,Any}, ::Function, ::getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}, ::Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}) at /home/brookhart/.julia/packages/Flux/8XpDt/src/treelike.jl:27
 [14] (::getfield(Flux, Symbol("#kw##mapleaves")))(::NamedTuple{(:cache,),Tuple{IdDict{Any,Any}}}, ::typeof(mapleaves), ::Function, ::Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}) at ./none:0
 [15] #38 at /home/brookhart/.julia/packages/Flux/8XpDt/src/treelike.jl:27 [inlined]
 [16] _broadcast_getindex_evalf(::getfield(Flux, Symbol("##38#39")){IdDict{Any,Any},getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}}, ::Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}) at ./broadcast.jl:578
 [17] _broadcast_getindex at ./broadcast.jl:551 [inlined]
 [18] (::getfield(Base.Broadcast, Symbol("##19#20")){Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,getfield(Flux, Symbol("##38#39")){IdDict{Any,Any},getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}},Tuple{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}}})(::Int64) at ./broadcast.jl:953
 [19] ntuple(::getfield(Base.Broadcast, Symbol("##19#20")){Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,getfield(Flux, Symbol("##38#39")){IdDict{Any,Any},getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}},Tuple{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}}}, ::Val{3}) at ./tuple.jl:161
 [20] copy at ./broadcast.jl:953 [inlined]
 [21] materialize at ./broadcast.jl:753 [inlined]
 [22] mapchildren(::Function, ::Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}) at /home/brookhart/.julia/packages/Flux/8XpDt/src/layers/basic.jl:28
 [23] #mapleaves#37(::IdDict{Any,Any}, ::Function, ::getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}, ::Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}) at /home/brookhart/.julia/packages/Flux/8XpDt/src/treelike.jl:27
 [24] mapleaves at /home/brookhart/.julia/packages/Flux/8XpDt/src/treelike.jl:26 [inlined]
 [25] restructure at /home/brookhart/.julia/packages/DiffEqFlux/1w1tX/src/Flux/utils.jl:12 [inlined]
 [26] (::getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}})(::CuArray{Float32,1}, ::CuArray{Float32,1}, ::CuArray{Float32,1}, ::Float32) at /home/brookhart/.julia/packages/DiffEqFlux/1w1tX/src/Flux/neural_de.jl:6
 [27] ODEFunction at /home/brookhart/.julia/packages/DiffEqBase/PvfXM/src/diffeqfunction.jl:106 [inlined]
 [28] initialize!(::OrdinaryDiffEq.ODEIntegrator{BS3,true,CuArray{Float32,1},Float32,CuArray{Float32,1},Float32,Float32,Float32,Array{CuArray{Float32,1},1},ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float32,Float32},true,CuArray{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},BS3,OrdinaryDiffEq.InterpolationData{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArray{Float32,1},1},Array{Float32,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.BS3Cache{CuArray{Float32,1},CuArray{Float32,1},CuArray{Float32,1},OrdinaryDiffEq.BS3ConstantCache{Float32,Float32}}}},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.BS3Cache{CuArray{Float32,1},CuArray{Float32,1},CuArray{Float32,1},OrdinaryDiffEq.BS3ConstantCache{Float32,Float32}},OrdinaryDiffEq.DEOptions{Float32,Float32,Float32,Float32,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),CallbackSet{Tuple{},Tuple{}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float32,DataStructures.LessThan},DataStructures.BinaryHeap{Float32,DataStructures.LessThan},Nothing,Nothing,Int64,Array{Float32,1},Float64,Array{Float32,1}},CuArray{Float32,1},Float32}, ::OrdinaryDiffEq.BS3Cache{CuArray{Float32,1},CuArray{Float32,1},CuArray{Float32,1},OrdinaryDiffEq.BS3ConstantCache{Float32,Float32}}) at /home/brookhart/.julia/packages/OrdinaryDiffEq/6mZB8/src/perform_step/low_order_rk_perform_step.jl:39
 [29] #__init#479(::Float64, ::Array{Float32,1}, ::Array{Float32,1}, ::Nothing, ::Bool, ::Nothing, ::Bool, ::Bool, ::Bool, ::Nothing, ::Bool, ::Bool, ::Float32, ::Bool, ::Rational{Int64}, ::Nothing, ::Nothing, ::Int64, ::Rational{Int64}, ::Int64, ::Int64, ::Rational{Int64}, ::Bool, ::Int64, ::Nothing, ::Nothing, ::Int64, ::Float32, ::Float32, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(LinearAlgebra.opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Nothing, ::Bool, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(DiffEqBase.__init), ::ODEProblem{CuArray{Float32,1},Tuple{Float32,Float32},true,CuArray{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at /home/brookhart/.julia/packages/OrdinaryDiffEq/6mZB8/src/solve.jl:312
 [30] (::getfield(DiffEqBase, Symbol("#kw##__init")))(::NamedTuple{(:saveat,),Tuple{Float64}}, ::typeof(DiffEqBase.__init), ::ODEProblem{CuArray{Float32,1},Tuple{Float32,Float32},true,CuArray{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at ./none:0
 [31] #__solve#478(::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}}, ::Function, ::ODEProblem{CuArray{Float32,1},Tuple{Float32,Float32},true,CuArray{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at /home/brookhart/.julia/packages/OrdinaryDiffEq/6mZB8/src/solve.jl:6
 [32] #__solve at ./none:0 [inlined] (repeats 5 times)
 [33] #solve#425(::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}}, ::Function, ::ODEProblem{CuArray{Float32,1},Tuple{Float32,Float32},true,CuArray{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3) at /home/brookhart/.julia/packages/DiffEqBase/PvfXM/src/solve.jl:39
 [34] #solve at ./none:0 [inlined]
 [35] #_forward#17 at /home/brookhart/.julia/packages/DiffEqFlux/1w1tX/src/Flux/layers.jl:54 [inlined]
 [36] #_forward at ./none:0 [inlined]
 [37] #track#1 at /home/brookhart/.julia/packages/Flux/8XpDt/src/tracker/Tracker.jl:51 [inlined]
 [38] #track at ./none:0 [inlined]
 [39] #diffeq_adjoint#16 at /home/brookhart/.julia/packages/DiffEqFlux/1w1tX/src/Flux/layers.jl:50 [inlined]
 [40] (::getfield(DiffEqFlux, Symbol("#kw##diffeq_adjoint")))(::NamedTuple{(:saveat,),Tuple{Float64}}, ::typeof(diffeq_adjoint), ::TrackedArray{…,CuArray{Float32,1}}, ::ODEProblem{CuArray{Float32,1},Tuple{Float32,Float32},true,TrackedArray{…,CuArray{Float32,1}},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3) at ./none:0
 [41] #neural_ode#23(::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}}, ::Function, ::Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}, ::CuArray{Float32,1}, ::Tuple{Float32,Float32}, ::BS3) at /home/brookhart/.julia/packages/DiffEqFlux/1w1tX/src/Flux/neural_de.jl:8
 [42] (::getfield(DiffEqFlux, Symbol("#kw##neural_ode")))(::NamedTuple{(:saveat,),Tuple{Float64}}, ::typeof(neural_ode), ::Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}, ::CuArray{Float32,1}, ::Tuple{Float32,Float32}, ::BS3) at ./none:0
 [43] (::getfield(Main, Symbol("##5#6")))(::Array{Float32,1}) at /home/brookhart/Documents/Kaggle/titanic/kernel.jl:20
 [44] top-level scope at none:0
 [45] include at ./boot.jl:326 [inlined]
 [46] include_relative(::Module, ::String) at ./loading.jl:1038
 [47] include(::Module, ::String) at ./sysimg.jl:29
 [48] exec_options(::Base.JLOptions) at ./client.jl:267
 [49] _start() at ./client.jl:43

This matches the error seen here: https://github.com/JuliaGPU/CuArrays.jl/issues/161

I'm using Julia 1.1, the latest releases of all packages, and CUDA 10 on a GTX 1070Ti running Ubuntu 18.04.

Thanks!

BoundsError in README example

I have tried to run the example in the README, but diffeq_rd() throws a BoundsError; the error message is attached.

This is happening with:
julia: v1.1.0
DiffEqFlux: v0.2.0
DifferentialEquations: v6.3.0
Flux: v0.7.3

The same error occurs when implementing the examples in the Flux model-zoo with the above versions.

Following the instructions on the model-zoo page, using the following package versions:

julia: v1.0.3
DiffEqFlux: v0.2.0
DifferentialEquations: v5.3.1
Flux: v0.7.2

Both the DiffEqFlux README example and the Flux model-zoo examples work.
As far as I can tell it is a problem with the Julia version, but I am very new to the language, so I am not sure.

DiffEqFlux v0.4.0 will not pre-compile with Flux 0.10.0 in Julia 1.3

This issue is superficially similar to #81 but the resolution there was to use a newer Julia. However, this happens in the latest available stable Julia (v1.3).

Problem
In Julia 1.3 on macOS, running Pkg.add("DiffEqFlux") installs DiffEqFlux v0.4.0 and Flux v0.10.0.

Flux precompiles, but DiffEqFlux does not. DifferentialEquations also has a problem.

Platform
Julia v1.3 installed on macOS; packages installed with Pkg.add.
Status ~/.julia/environments/v1.3/Project.toml
[aae7a2af] DiffEqFlux v0.4.0
[0c46a032] DifferentialEquations v6.6.0
[587475ba] Flux v0.10.0
[7073ff75] IJulia v1.20.2
[91a5bcdd] Plots v0.28.4
[9f7883ad] Tracker v0.2.6

Expected outcome
In Julia 1.2, which installs an older Flux (v0.9.0), these errors do not occur and the packages are usable.

Actual Outcome

Errors are thrown during DiffEqFlux v0.4.0 precompilation.
Warnings are thrown by DifferentialEquations v6.6.0.

Regression
Both the warnings and the errors refer to a missing Tracker dependency. Additionally, some of the messages seem to say this dependency is supposed to be supplied by the Flux package.

However, the Flux NEWS.md says that the Tracker dependency was removed in favor of Zygote in Flux 0.10.0; perhaps that has something to do with why it is missing?

I read an older, closed Flux issue on the Tracker dependency (Flux.jl#695) that applies only to earlier revisions of Flux; the fix there was to add the Tracker package. I did that, but it does not fix the problem.

I filed a related issue on Flux.jl (#975), but I have come to suspect that the problem is DiffEqFlux expecting something that is no longer contained in Flux.jl.

Workaround
I'm not sure whether I can revert to Flux 0.9.0 (and I don't know how to do that; see the hedged Pkg sketch below), but even if I did, that would be less desirable since the @functor macro is not available in that version of Flux.
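
For reference, here is a hedged sketch of that pin-back (assuming a plain Pkg-managed environment; untested against this exact manifest):

using Pkg
Pkg.add(PackageSpec(name = "Flux", version = "0.9.0"))
Pkg.pin(PackageSpec(name = "Flux", version = "0.9.0"))
Pkg.build("Flux")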

Code to reproduce
100% reproducible: add the packages listed above to Julia 1.3, then import DiffEqFlux or DifferentialEquations.

Results

julia> using  DiffEqFlux
[ Info: Precompiling DiffEqFlux [aae7a2af-3d4f-5e19-a356-7da93b79d9d0]
ERROR: LoadError: LoadError: UndefVarError: Tracker not defined
Stacktrace:
 [1] include at ./boot.jl:328 [inlined]
 [2] include_relative(::Module, ::String) at ./loading.jl:1105
 [3] include at ./Base.jl:31 [inlined]
 [4] include(::String) at /Users/cems/.julia/packages/DiffEqSensitivity/DI6VG/src/DiffEqSensitivity.jl:3
 [5] top-level scope at /Users/cems/.julia/packages/DiffEqSensitivity/DI6VG/src/DiffEqSensitivity.jl:14
 [6] include at ./boot.jl:328 [inlined]
 [7] include_relative(::Module, ::String) at ./loading.jl:1105
 [8] include(::Module, ::String) at ./Base.jl:31
 [9] top-level scope at none:2
 [10] eval at ./boot.jl:330 [inlined]
 [11] eval(::Expr) at ./client.jl:425
 [12] top-level scope at ./none:3
in expression starting at /Users/cems/.julia/packages/DiffEqSensitivity/DI6VG/src/adjoint_sensitivity.jl:1
in expression starting at /Users/cems/.julia/packages/DiffEqSensitivity/DI6VG/src/DiffEqSensitivity.jl:14
ERROR: LoadError: Failed to precompile DiffEqSensitivity [41bf760c-e81c-5289-8e54-58b1f1f8abe2] to /Users/cems/.julia/compiled/v1.3/DiffEqSensitivity/02xYn_ml8Hi.ji.
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] compilecache(::Base.PkgId, ::String) at ./loading.jl:1283
 [3] _require(::Base.PkgId) at ./loading.jl:1024
 [4] require(::Base.PkgId) at ./loading.jl:922
 [5] require(::Module, ::Symbol) at ./loading.jl:917
 [6] include at ./boot.jl:328 [inlined]
 [7] include_relative(::Module, ::String) at ./loading.jl:1105
 [8] include(::Module, ::String) at ./Base.jl:31
 [9] top-level scope at none:2
 [10] eval at ./boot.jl:330 [inlined]
 [11] eval(::Expr) at ./client.jl:425
 [12] top-level scope at ./none:3
in expression starting at /Users/cems/.julia/packages/DiffEqFlux/t2FwV/src/DiffEqFlux.jl:3
ERROR: Failed to precompile DiffEqFlux [aae7a2af-3d4f-5e19-a356-7da93b79d9d0] to /Users/cems/.julia/compiled/v1.3/DiffEqFlux/BdO4p_ml8Hi.ji.
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] compilecache(::Base.PkgId, ::String) at ./loading.jl:1283
 [3] _require(::Base.PkgId) at ./loading.jl:1024
 [4] require(::Base.PkgId) at ./loading.jl:922
 [5] require(::Module, ::Symbol) at ./loading.jl:917

A similar result from another package

julia> using DifferentialEquations
┌ Warning: Error requiring Flux from ArrayInterface:
│ UndefVarError: Tracker not defined
│ Stacktrace:
│  [1] getproperty(::Module, ::Symbol) at ./Base.jl:13
│  [2] top-level scope at /Users/cems/.julia/packages/ArrayInterface/qMMsu/src/ArrayInterface.jl:32
│  [3] eval at ./boot.jl:330 [inlined]
│  [4] eval at /Users/cems/.julia/packages/ArrayInterface/qMMsu/src/ArrayInterface.jl:1 [inlined]
│  [5] (::ArrayInterface.var"#9#18")() at /Users/cems/.julia/packages/Requires/9Jse8/src/require.jl:67
│  [6] err(::ArrayInterface.var"#9#18", ::Module, ::String) at /Users/cems/.julia/packages/Requires/9Jse8/src/require.jl:38
│  [7] #8 at /Users/cems/.julia/packages/Requires/9Jse8/src/require.jl:66 [inlined]
│  [8] withpath(::ArrayInterface.var"#8#17", ::String) at /Users/cems/.julia/packages/Requires/9Jse8/src/require.jl:28
│  [9] #7 at /Users/cems/.julia/packages/Requires/9Jse8/src/require.jl:65 [inlined]
│  [10] listenpkg(::ArrayInterface.var"#7#16", ::Base.PkgId) at /Users/cems/.julia/packages/Requires/9Jse8/src/require.jl:13
│  [11] macro expansion at /Users/cems/.julia/packages/Requires/9Jse8/src/require.jl:64 [inlined]
│  [12] __init__() at /Users/cems/.julia/packages/ArrayInterface/qMMsu/src/ArrayInterface.jl:31
│  [13] _include_from_serialized(::String, ::Array{Any,1}) at ./loading.jl:692
│  [14] _require_search_from_serialized(::Base.PkgId, ::String) at ./loading.jl:776
│  [15] _tryrequire_from_serialized(::Base.PkgId, ::UInt64, ::String) at ./loading.jl:707
│  [16] _require_search_from_serialized(::Base.PkgId, ::String) at ./loading.jl:765
│  [17] _tryrequire_from_serialized(::Base.PkgId, ::UInt64, ::String) at ./loading.jl:707
│  [18] _require_search_from_serialized(::Base.PkgId, ::String) at ./loading.jl:765
│  [19] _require(::Base.PkgId) at ./loading.jl:1001
│  [20] require(::Base.PkgId) at ./loading.jl:922
│  [21] require(::Module, ::Symbol) at ./loading.jl:917
│  [22] eval(::Module, ::Any) at ./boot.jl:330
│  [23] eval_user_input(::Any, ::REPL.REPLBackend) at /Users/sabae/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/REPL/src/REPL.jl:86
│  [24] macro expansion at /Users/sabae/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/REPL/src/REPL.jl:118 [inlined]
│  [25] (::REPL.var"#26#27"{REPL.REPLBackend})() at ./task.jl:333
└ @ Requires ~/.julia/packages/Requires/9Jse8/src/require.jl:40
┌ Warning: Error requiring Flux from DiffEqBase:
│ UndefVarError: Tracker not defined
│ Stacktrace:
│  [1] getproperty(::Module, ::Symbol) at ./Base.jl:13
│  [2] top-level scope at /Users/cems/.julia/packages/DiffEqBase/DqkH4/src/init.jl:87
│  [3] eval at ./boot.jl:330 [inlined]
│  [4] eval at /Users/cems/.julia/packages/DiffEqBase/DqkH4/src/DiffEqBase.jl:1 [inlined]
│  [5] (::DiffEqBase.var"#410#434")() at /Users/cems/.julia/packages/Requires/9Jse8/src/require.jl:67
│  [6] err(::DiffEqBase.var"#410#434", ::Module, ::String) at /Users/cems/.julia/packages/Requires/9Jse8/src/require.jl:38
│  [7] #409 at /Users/cems/.julia/packages/Requires/9Jse8/src/require.jl:66 [inlined]
│  [8] withpath(::DiffEqBase.var"#409#433", ::String) at /Users/cems/.julia/packages/Requires/9Jse8/src/require.jl:28
│  [9] #408 at /Users/cems/.julia/packages/Requires/9Jse8/src/require.jl:65 [inlined]
│  [10] listenpkg(::DiffEqBase.var"#408#432", ::Base.PkgId) at /Users/cems/.julia/packages/Requires/9Jse8/src/require.jl:13
│  [11] macro expansion at /Users/cems/.julia/packages/Requires/9Jse8/src/require.jl:64 [inlined]
│  [12] __init__() at /Users/cems/.julia/packages/DiffEqBase/DqkH4/src/init.jl:85
│  [13] _include_from_serialized(::String, ::Array{Any,1}) at ./loading.jl:692
│  [14] _require_search_from_serialized(::Base.PkgId, ::String) at ./loading.jl:776
│  [15] _tryrequire_from_serialized(::Base.PkgId, ::UInt64, ::String) at ./loading.jl:707
│  [16] _require_search_from_serialized(::Base.PkgId, ::String) at ./loading.jl:765
│  [17] _require(::Base.PkgId) at ./loading.jl:1001
│  [18] require(::Base.PkgId) at ./loading.jl:922
│  [19] require(::Module, ::Symbol) at ./loading.jl:917
│  [20] eval(::Module, ::Any) at ./boot.jl:330
│  [21] eval_user_input(::Any, ::REPL.REPLBackend) at /Users/sabae/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/REPL/src/REPL.jl:86
│  [22] macro expansion at /Users/sabae/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/REPL/src/REPL.jl:118 [inlined]
│  [23] (::REPL.var"#26#27"{REPL.REPLBackend})() at ./task.jl:333
└ @ Requires ~/.julia/packages/Requires/9Jse8/src/require.jl:40

FFJORD

using OrdinaryDiffEq
using Distributions
using Statistics # for mean, used in the loss below
using Flux, DiffEqFlux, Tracker

# Neural Network
nn = Chain(Dense(1,1,tanh))
p = Tracker.data(DiffEqFlux.destructure(nn))
DiffEqFlux.restructure(nn,p)([1.0])
tspan = Float32.((0.0, 10.0))
function cnf(du,u,p,t)
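  # The last component of u carries the running change in log-density for the
  # continuous normalizing flow; tr(dm/dz) is estimated below with a random
  # probe vector e via a vector-Jacobian product (a Hutchinson-style estimate).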
  z = @view u[1:end-1]
  m = DiffEqFlux.restructure(nn,p)
  du[1:end-1] = m(z)
  fz, back = Tracker.forward(z->m(z),z')
  e = randn(size(z)[1])
  eJ = back(e)[1]
  eJe = Tracker.data((eJ .* e)[1])
  du[end] = -eJe
end

# Tracker.collectmemaybe(z)   # stray line: z is not defined at this top-level scope

prob = ODEProblem(cnf,nothing,tspan,nothing)

p = param(Float32[0.0, 0.0]) # Initial Parameter Vector
params = Params([p])

function predict_adjoint(x)
    diffeq_adjoint(p,prob,Tsit5(),u0=[x;false],
                   saveat=0.0:0.1:10.0,
                   sensealg=DiffEqFlux.SensitivityAlg(quad=false,
                                backsolve=true,autojacvec=true))
end

function loss_adjoint(xs)
    pz = Normal(0.0, 1.0)
    preds = [predict_adjoint(x)[:,end] for x in xs]
    z = [pred[1] for pred in preds] # TODO better slicing
    delta_logp = [pred[2] for pred in preds]

    logpz = logpdf.(pz, z)
    logpx = logpz - delta_logp
    loss = -mean(logpx)
end

opt = ADAM(0.1)

raw_data = [Float32[rand(Normal(2.0, 0.1)) for i in 1:100]]
data = Iterators.repeated(raw_data, 100);

Flux.train!(loss_adjoint, params, data, opt)

# check whether it looks standard normal
using Plots

preds = [predict_adjoint(r)[:,end] for r in raw_data[1]];

Singular mass matrices not compatible with hardcoded adjoints

I'm trying to use a DAE as a layer in a neural net. When the DAE is specified as an ODEProblem using an ODEFunction with a mass_matrix, DiffEqFlux can solve the problem, but it cannot be used in a neural net via diffeq_adjoint. On the Slack channel, @ChrisRackauckas mentioned that mass matrices are currently not compatible with diffeq_adjoint. I'm reporting this issue here so it can be tracked. Thanks.
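
For context, a minimal sketch (assumed, not the reporter's model) of the kind of setup described above: a DAE written as an ODEProblem through an ODEFunction with a singular mass matrix, which a plain solve handles but which cannot currently be differentiated with diffeq_adjoint.

using OrdinaryDiffEq, LinearAlgebra

function rober(du, u, p, t)
    y1, y2, y3 = u
    du[1] = -0.04y1 + 1e4*y2*y3
    du[2] =  0.04y1 - 1e4*y2*y3 - 3e7*y2^2
    du[3] =  y1 + y2 + y3 - 1.0      # algebraic constraint row
end

M = Diagonal([1.0, 1.0, 0.0])        # the zero entry makes the mass matrix singular
f = ODEFunction(rober, mass_matrix = M)
prob = ODEProblem(f, [1.0, 0.0, 0.0], (0.0, 0.1))
sol = solve(prob, Rodas5())          # the forward solve works; the adjoint path does not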

GPU + Adjoint u0

GPU CI found this issue:

adjoint mode trackedu0: Error During Test at /var/lib/gitlab-runner/builds/4wGNxdk1/0/juliadiffeq/DiffEqFlux-jl/test/neural_de_gpu.jl:106
  Got exception outside of a @test
  GPU compilation of #23(CuArrays.CuKernelState, CUDAnative.CuDeviceArray{Float32,1,CUDAnative.AS.Global}, Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},typeof(identity),Tuple{Base.Broadcast.Extruded{Array{Float32,1},Tuple{Bool},Tuple{Int64}}}}) failed
  KernelError: passing and using non-bitstype argument
  
  Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},typeof(identity),Tuple{Base.Broadcast.Extruded{Array{Float32,1},Tuple{Bool},Tuple{Int64}}}}.
  That type is not isbits, and such arguments are only allowed when they are unused by the kernel.  .args is of type Tuple{Base.Broadcast.Extruded{Array{Float32,1},Tuple{Bool},Tuple{Int64}}} which is not isbits.
      .1 is of type Base.Broadcast.Extruded{Array{Float32,1},Tuple{Bool},Tuple{Int64}} which is not isbits.
        .x is of type Array{Float32,1} which is not isbits.

which was marked broken for now (a9997ce) but should get fixed ASAP.
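
A hedged illustration (using the CuArrays-era packages named in the trace) of the class of failure above: a plain CPU Array ends up inside a broadcast that is compiled into a GPU kernel, so the Broadcasted argument is not isbits and the kernel compiler rejects it.

using CuArrays

du = cu(zeros(Float32, 4))   # device array
Δ  = zeros(Float32, 4)       # host array, e.g. an adjoint seed coming from the AD
du .= identity.(Δ)           # KernelError: passing and using non-bitstype argument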

Scalar saveat with neural ODE has a type issue

MWE:

using OrdinaryDiffEq, StochasticDiffEq, Flux, DiffEqFlux

x = Float32[2.; 0.]
tspan = (0.0f0,25.0f0)
dudt = Chain(Dense(2,50,tanh),Dense(50,2))
Flux.back!(sum(neural_ode(dudt,x,tspan,Tsit5(),saveat=0.1))) # Fails
Flux.back!(sum(neural_ode(dudt,x,tspan,Tsit5(),saveat=0.0:0.1:25.0))) # Works

From the MWE you can see that the workaround is simple, but annoying. It seems like a Float32 handling issue somewhere in the *DiffEq packages.

"MethodError: no method matching gemm!" error when using Rosenbrock23 with diffeq_rd.

Hi, thanks for the great package!

I found a problem when combining diffeq_rd and the Rosenbrock23 solver.
If I run the Lotka-Volterra basic example provided named "Train an ODE to satisfy an objective"
(https://github.com/FluxML/model-zoo/blob/da4156b4a9fb0d5907dcb6e21d0e78c72b6122e0/other/diffeq/ode.jl) everything works as expected.

If I replace Tsit5() with Rosenbrock23(autodiff=false) in the function predict_rd on line 30, it still works.

Now I increase the number of variables from 2 to 30 in the differential equation function by changing the lines marked with <<<<<<< below (note that the additional variables do not affect the solution):

using Flux, DiffEqFlux, DifferentialEquations, Plots

## Setup ODE to optimize
function lotka_volterra(du,u,p,t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
  du[3:end] .= 0e0 # <<<<< ADD
end
# u0 = [1.0,1.0] # <<<<< REMOVE
u0 = ones(30) # <<<<< ADD
tspan = (0.0,10.0)
p = [1.5,1.0,3.0,1.0]
prob = ODEProblem(lotka_volterra,u0,tspan,p)

# Verify ODE solution
sol = solve(prob,Tsit5())
plot(sol)

# Generate data from the ODE
sol = solve(prob,Tsit5(),saveat=0.1)
A = sol[1,:] # length 101 vector
t = 0:0.1:10.0
scatter!(t,A)

# Build a neural network that sets the cost as the difference from the
# generated data and 1

p = param([2.2, 1.0, 2.0, 0.4]) # Initial Parameter Vector
function predict_rd() # Our 1-layer neural network
  diffeq_rd(p,prob,Rosenbrock23(autodiff=false),saveat=0.1)[1,:] # <<<<<<< CHANGED
end
loss_rd() = sum(abs2,x-1 for x in predict_rd()) # loss function

# Optimize the parameters so the ODE's solution stays near 1

data = Iterators.repeated((), 100)
opt = ADAM(0.1)
cb = function () #callback function to observe training
  display(loss_rd())
  # using `remake` to re-create our `prob` with current parameters `p`
  display(plot(solve(remake(prob,p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6)))
end
# Display the ODE with the initial parameter values.
cb()
Flux.train!(loss_rd, [p], data, opt, cb = cb)

These changes produce the following error

ERROR: LoadError: MethodError: no method matching gemm!(::Char, ::Char, ::Tracker.TrackedReal{Float64}, ::SubArray{Tracker.TrackedReal{Float64},2,Array{Tracker.TrackedReal{Float64},2},Tuple{UnitRange{Int64},UnitRange{Int64}},false}, ::SubArray{Tracker.TrackedReal{Float64},2,Array{Tracker.TrackedReal{Float64},2},Tuple{UnitRange{Int64},UnitRange{Int64}},false}, ::Tracker.TrackedReal{Float64}, ::SubArray{Tracker.TrackedReal{Float64},2,Array{Tracker.TrackedReal{Float64},2},Tuple{UnitRange{Int64},UnitRange{Int64}},false})
Closest candidates are:
  gemm!(::AbstractChar, ::AbstractChar, ::Float64, ::Union{AbstractArray{Float64,1}, AbstractArray{Float64,2}}, ::Union{AbstractArray{Float64,1}, AbstractArray{Float64,2}}, ::Float64, ::Union{AbstractArray{Float64,1}, AbstractArray{Float64,2}}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/blas.jl:1111
  gemm!(::AbstractChar, ::AbstractChar, ::Float32, ::Union{AbstractArray{Float32,1}, AbstractArray{Float32,2}}, ::Union{AbstractArray{Float32,1}, AbstractArray{Float32,2}}, ::Float32, ::Union{AbstractArray{Float32,1}, AbstractArray{Float32,2}}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/blas.jl:1111
  gemm!(::AbstractChar, ::AbstractChar, ::Complex{Float64}, ::Union{AbstractArray{Complex{Float64},1}, AbstractArray{Complex{Float64},2}}, ::Union{AbstractArray{Complex{Float64},1}, AbstractArray{Complex{Float64},2}}, ::Complex{Float64}, ::Union{AbstractArray{Complex{Float64},1}, AbstractArray{Complex{Float64},2}}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/blas.jl:1111
  ...

Now, if I go back from Rosenbrock23(autodiff=false) to Tsit5(), the code handles all 30 variables and produces the same results as the 2-variable case (as expected).

To recap:

  • 2 variables + Tsit5: OK
  • 2 variables + Rosenbrock23: OK
  • 30 variables + Rosenbrock23: FAIL (gemm! error)
  • 30 variables + Tsit5: OK

Other info:

  • Same problem with TRBDF2() and in general with solvers for stiff ODEs.
  • It seems to work fine up to 16 variables (i.e. with 17 I get the gemm! error).
  • Found the same behaviour in more complex set-ups.

What am I doing wrong? Thanks in advance.


My system:
Linux 4.15.0-54-generic #58~16.04.1-Ubuntu

Julia Version 1.1.1 (2019-05-16), other packages are

  "Sundials"              => v"3.6.1"
  "Atom"                  => v"0.8.8"
  "Optim"                 => v"0.19.0"
  "Juno"                  => v"0.7.0"
  "TensorFlow"            => v"0.11.0"
  "StatsBase"             => v"0.31.0"
  "DiffEqFlux"            => v"0.5.0"
  "Flux"                  => v"0.8.3"
  "Plots"                 => v"0.25.3"
  "DifferentialEquations" => v"6.6.0"
  "ImageView"             => v"0.9.0"

Zygote compatibility for new Flux update

Updates to DiffEqFlux.jl:

using Zygote
using Requires                     # provides the @require macro used below
using LinearAlgebra: Transpose, Adjoint
gpu_or_cpu(x) = Array
function __init__()
    @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
        gpu_or_cpu(x::CuArrays.CuArray) = CuArrays.CuArray
        gpu_or_cpu(x::Transpose{<:Any,<:CuArrays.CuArray}) = CuArrays.CuArray
        gpu_or_cpu(x::Adjoint{<:Any,<:CuArrays.CuArray}) = CuArrays.CuArray
    end
end

New diffeq_rd():

function diffeq_rd(p,prob,args...;u0=prob.u0,kwargs...) # signature assumed to match the old API; the posted snippet started inside the body
  if typeof(u0) <: AbstractArray && !(typeof(u0) <: TrackedArray)
    if DiffEqBase.isinplace(prob)
      _prob = remake(prob,u0=convert.(recursive_bottom_eltype(p),u0),p=p)
    else
      # use TrackedArray for efficiency of the tape
      _prob = remake(prob,u0=convert(typeof(p),u0),p=p)
    end
  else # u0 is functional, ignore the change
    _prob = remake(prob,u0=u0,p=p)
  end
  solve(_prob,args...;kwargs...)
end

New diffeq_fd():

function diffeq_fd(p,f,n,prob,args...;u0=prob.u0,kwargs...)
  _prob = remake(prob,u0=convert.(eltype(p),u0),p=p)
  f(solve(_prob,args...;kwargs...))
end

Zygote.@adjoint function diffeq_fd(p::AbstractVector,f,n,prob,args...;u0=prob.u0,kwargs...)
  _f = function (p)
    _prob = remake(prob,u0=convert.(eltype(p),u0),p=p)
    f(solve(_prob,args...;kwargs...))
  end
  if n === nothing
    result = DiffResults.GradientResult(p)
    ForwardDiff.gradient!(result, _f, p)
    DiffResults.value(result),Δ -> (Δ .* DiffResults.gradient(result), ntuple(_->nothing, 3+length(args))...)
  else
    y = adapt(typeof(p),zeros(n))
    result = DiffResults.JacobianResult(y,p)
    ForwardDiff.jacobian!(result, _f, p)
    DiffResults.value(result),Δ -> (DiffResults.jacobian(result)' * Δ, ntuple(_->nothing, 3+length(args))...)
  end
end
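
A hedged usage sketch for the proposal above (reusing the Lotka-Volterra prob and parameter vector from the earlier examples and assuming the definitions above are in scope; not part of the original post): f is the functional applied to the solution, and n === nothing selects the scalar/gradient path while an integer n selects the Jacobian path.

loss(p) = diffeq_fd(p, sol -> sum(abs2, Array(sol) .- 1), nothing, prob, Tsit5(), saveat = 0.1)
Zygote.gradient(loss, [1.5, 1.0, 3.0, 1.0])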

neural_de.jl for Zygote:

function neural_ode(model,x,tspan,
                    args...;kwargs...)
  dudt_(u,p,t) = model(u)
  prob = ODEProblem(dudt_,x,tspan)
  return diffeq_adjoint(prob,args...;u0=x,kwargs...)
end

function neural_ode_rd(model,x,tspan,
                       args...;kwargs...)
  dudt_(u,p,t) = model(u)
  prob = ODEProblem(dudt_,x,tspan)
  # TODO could probably use vcat rather than collect here
  solve(prob, args...; kwargs...)
end

function neural_dmsde(model,x,mp,tspan,
                      args...;kwargs...)
  dudt_(u,p,t) = model(u)
  g(u,p,t) = mp.*u
  prob = SDEProblem(dudt_,g,param(x),tspan,nothing)
  # TODO could probably use vcat rather than collect here
  solve(prob, args...; kwargs...) 
end

I'm still having trouble getting the adjoint to work correctly. I'm getting a strange error where one of the last elements of the stack trace is the same as one of the first; it might be an infinite recursion, but I don't understand why that would be the case. Here is the current, NOT WORKING code for the adjoint:

function diffeq_adjoint(prob,args...;u0=prob.u0,kwargs...)
  _prob = remake(prob,u0=u0)
  T = gpu_or_cpu(u0)
  solve(_prob,args...;kwargs...) # adapt(T, )
end

Zygote.@adjoint function diffeq_adjoint(u0,prob,args...;backsolve=true,
                              save_start=true,save_end=true,
                              sensealg=SensitivityAlg(quad=false,backsolve=backsolve),
                              kwargs...)

  T = gpu_or_cpu(u0)
  _prob = remake(prob,u0=u0)

  # Force `save_start` and `save_end` in the forward pass This forces the
  # solver to do the backsolve all the way back to `u0` Since the start aliases
  # `_prob.u0`, this doesn't actually use more memory But it cleans up the
  # implementation and makes `save_start` and `save_end` arg safe.
  sol = solve(_prob,args...;save_start=true,save_end=true,kwargs...)

  no_start = !save_start
  no_end = !save_end
  sol_idxs = 1:length(sol)
  no_start && (sol_idxs = sol_idxs[2:end])
  no_end && (sol_idxs = sol_idxs[1:end-1])
  # If didn't save start, take off first. If only wanted the end, return vector
  only_end = length(sol_idxs) <= 1
  u = sol[sol_idxs]
  only_end && (sol_idxs = 1)
  out = only_end ? sol[end] : reduce((x,y)->cat(x,y,dims=ndims(u)),u.u)
  out, Δ -> begin
    function df(out, u, p, t, i)
      if only_end
        out[:] .= -vec(Δ)
      else
        out[:] .= -reshape(Δ, :, size(Δ)[end])[:, i]
      end
    end
    ts = sol.t[sol_idxs]
    du0, dp = adjoint_sensitivities_u0(sol,args...,df,ts;
                    sensealg=sensealg,
                    kwargs...)
    (dp', reshape(du0,size(u0)), ntuple(_->nothing, 1+length(args))...)
  end
end

I'm also a little confused about whether destructure and restructure are still needed in utils.jl after removing Tracker. If so, I'm unsure how (or whether) they should be modified.
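
For context, a minimal sketch of the role those two utilities play (shown with the current Flux.destructure API rather than the old utils.jl versions): flatten a model's parameters into a single vector and rebuild a callable model from such a vector, which is what lets an ODE right-hand side or an optimizer treat the network weights as one parameter vector p.

using Flux

nn = Chain(Dense(2, 8, tanh), Dense(8, 2))
p0, re = Flux.destructure(nn)   # p0 is a flat Vector of all weights; re rebuilds a model
nn2 = re(p0)                    # same architecture, parameters taken from p0
nn2(rand(Float32, 2))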

Support for complex matrix ODEs

I'm adapting the first and simplest example in the README, but diffeq_rd doesn't seem to work with a complex matrix ODE.
Even something like

function myode(du,u,p,t)
    du .= u
end

or

function myode(u,p,t)
    return u;
end

doesn't seem to work. As the initial condition, I'm using a 2x2 identity matrix with Complex{Float64} entries; t is Float64. solve() works with no issues, but diffeq_rd gives

MethodError: no method matching Array{Float64,1}(::Array{Complex{Float64},2})
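
A hedged reconstruction of the report as a runnable MWE (the exact script was not posted; the tspan and the diffeq_rd call are assumptions following the other examples from this era):

using DifferentialEquations, Flux, DiffEqFlux, LinearAlgebra

myode(du, u, p, t) = (du .= u)
u0 = Matrix{Complex{Float64}}(I, 2, 2)   # 2x2 complex identity
tspan = (0.0, 1.0)
p = param([1.0])
prob = ODEProblem(myode, u0, tspan, p)

solve(prob, Tsit5())           # works
diffeq_rd(p, prob, Tsit5())    # MethodError: no method matching Array{Float64,1}(::Array{Complex{Float64},2})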

BSplines to approximate gradient of data

interp(out,in...) takes the experimental results and the ranges of the inputs, and returns the scaled interpolator together with the approximate gradient at each grid point.

using Interpolations
function interp(out,in...)
  itp = interpolate(out,BSpline(Cubic(Line(OnGrid()))))
  itp = scale(itp, in...) 
  l = length(in)
  evalSteps =  [collect(in[d]) for d in range(1,stop=l)]
  num_in = size(evalSteps,1)
  dims = Tuple([size(i,1) for i in evalSteps])
  grad = zeros(dims...,num_in)
  for i in CartesianIndices(grad)
    tu=Tuple(i)
    if(tu[end]==1)
      t=tu[1 : end-1]
      cords = [evalSteps[d][t[d]] for d in range(1,stop=num_in)]
      g=Interpolations.gradient(itp, cords...)
      grad[t...,:]=g
    end
  end
  return itp, grad
end
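
A hedged usage sketch (hypothetical data, not from the original post): sample sin on a grid and recover the interpolant together with the approximate derivative at the grid points.

xs = 0.0:0.1:6.28
ys = sin.(xs)
itp, grad = interp(ys, xs)
grad[:, 1]   # approximately cos.(xs), the derivative of the data, up to interpolation error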

(Attachment: DiffEqFlux.zip)

Error when running example on README

using DifferentialEquations
function lotka_volterra(du,u,p,t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end
u0 = [1.0,1.0]
tspan = (0.0,10.0)
p = [1.5,1.0,3.0,1.0]
prob = ODEProblem(lotka_volterra,u0,tspan,p)
sol = solve(prob,Tsit5())

using Flux, DiffEqFlux, Plots # Plots is needed for the plot calls in the callback below
p = param([2.2, 1.0, 2.0, 0.4]) # Initial Parameter Vector
params = Flux.Params([p])

function predict_adjoint() # Our 1-layer neural network
  Tracker.collect(diffeq_adjoint(p,prob,Tsit5(),saveat=0.1))
end
loss_adjoint() = sum(abs2,x-1 for x in predict_adjoint())


data = Iterators.repeated((), 100)
opt = ADAM(0.1)
cb = function () #callback function to observe training
  display(loss_adjoint())
  # using `remake` to re-create our `prob` with current parameters `p`
  display(plot(solve(remake(prob,p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6)))
end

# Display the ODE with the initial parameter values.
cb()

Flux.train!(loss_adjoint, params, data, opt, cb = cb)

Error:

julia> Flux.train!(loss_adjoint, params, data, opt, cb = cb)
ERROR: BoundsError: attempt to access 0-element StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}} at index [0]
Stacktrace:
 [1] throw_boundserror(::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Tuple{Int64}) at ./abstractarray.jl:484
 [2] checkbounds at ./abstractarray.jl:449 [inlined]
 [3] getindex at ./range.jl:638 [inlined]
 [4] tstop_saveat_disc_handling at /homes/anantharaman/.julia/packages/OrdinaryDiffEq/qdrLN/src/solve.jl:408 [inlined]
 [5] #__init#334(::Float64, ::Array{Float64,1}, ::Array{Float64,1}, ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::Nothing, ::Bool, ::Bool, ::Float64, ::Float64, ::Float64, ::Bool, ::Bool, ::Rational{Int64}, ::Float64, ::Float64, ::Rational{Int64}, ::Int64, ::Int64, ::Int64, ::Rational{Int64}, ::Bool, ::Int64, ::Nothing, ::Nothing, ::Int64, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(LinearAlgebra.opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Nothing, ::Bool, ::Bool, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}}, ::typeof(DiffEqBase.__init), ::ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,DiffEqSensitivity.ODEAdjointSensitivityFunction{Array{Float64,1},Array{Float64,1},Array{Float64,1},Nothing,Nothing,Nothing,Nothing,Nothing,DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}},Nothing,Nothing,Nothing,Nothing,Array{Float64,1},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},CallbackSet{Tuple{},Tuple{DiscreteCallback{getfield(DiffEqCallbacks, Symbol("##31#34")){Base.RefValue{Union{Nothing, Float64}}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float64,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, Symbol("#df#27")),Bool,Array{Float64,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float64}}},getfield(DiffEqCallbacks, Symbol("##33#36")){typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float64,1},Base.RefValue{Int64}},Base.RefValue{Union{Nothing, Float64}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float64,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, Symbol("#df#27")),Bool,Array{Float64,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float64}}}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5, ::Array{Array{Float64,1},1}, ::Array{Float64,1}, ::Array{Any,1}, ::Type{Val{true}}) at /homes/anantharaman/.julia/packages/OrdinaryDiffEq/qdrLN/src/solve.jl:154
 [6] (::getfield(DiffEqBase, Symbol("#kw##__init")))(::NamedTuple{(:abstol, :reltol, :save_everystep, :save_start, :saveat),Tuple{Float64,Float64,Bool,Bool,Float64}}, ::typeof(DiffEqBase.__init), ::ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,DiffEqSensitivity.ODEAdjointSensitivityFunction{Array{Float64,1},Array{Float64,1},Array{Float64,1},Nothing,Nothing,Nothing,Nothing,Nothing,DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}},Nothing,Nothing,Nothing,Nothing,Array{Float64,1},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},CallbackSet{Tuple{},Tuple{DiscreteCallback{getfield(DiffEqCallbacks, Symbol("##31#34")){Base.RefValue{Union{Nothing, Float64}}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float64,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, Symbol("#df#27")),Bool,Array{Float64,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float64}}},getfield(DiffEqCallbacks, Symbol("##33#36")){typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float64,1},Base.RefValue{Int64}},Base.RefValue{Union{Nothing, Float64}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float64,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, Symbol("#df#27")),Bool,Array{Float64,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float64}}}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5, ::Array{Array{Float64,1},1}, ::Array{Float64,1}, ::Array{Any,1}, ::Type{Val{true}}) at ./none:0 (repeats 5 times)
 [7] #__solve#333(::Base.Iterators.Pairs{Symbol,Real,NTuple{5,Symbol},NamedTuple{(:abstol, :reltol, :save_everystep, :save_start, :saveat),Tuple{Float64,Float64,Bool,Bool,Float64}}}, ::Function, ::ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,DiffEqSensitivity.ODEAdjointSensitivityFunction{Array{Float64,1},Array{Float64,1},Array{Float64,1},Nothing,Nothing,Nothing,Nothing,Nothing,DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}},Nothing,Nothing,Nothing,Nothing,Array{Float64,1},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},CallbackSet{Tuple{},Tuple{DiscreteCallback{getfield(DiffEqCallbacks, Symbol("##31#34")){Base.RefValue{Union{Nothing, Float64}}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float64,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, Symbol("#df#27")),Bool,Array{Float64,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float64}}},getfield(DiffEqCallbacks, Symbol("##33#36")){typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float64,1},Base.RefValue{Int64}},Base.RefValue{Union{Nothing, Float64}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float64,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, Symbol("#df#27")),Bool,Array{Float64,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float64}}}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5) at /homes/anantharaman/.julia/packages/OrdinaryDiffEq/qdrLN/src/solve.jl:4
 [8] #__solve at ./none:0 [inlined]
 [9] #solve#381 at /homes/anantharaman/.julia/packages/DiffEqBase/hia0S/src/solve.jl:39 [inlined]
 [10] (::getfield(DiffEqBase, Symbol("#kw##solve")))(::NamedTuple{(:abstol, :reltol, :save_everystep, :save_start, :saveat),Tuple{Float64,Float64,Bool,Bool,Float64}}, ::typeof(solve), ::ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,DiffEqSensitivity.ODEAdjointSensitivityFunction{Array{Float64,1},Array{Float64,1},Array{Float64,1},Nothing,Nothing,Nothing,Nothing,Nothing,DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}},Nothing,Nothing,Nothing,Nothing,Array{Float64,1},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},CallbackSet{Tuple{},Tuple{DiscreteCallback{getfield(DiffEqCallbacks, Symbol("##31#34")){Base.RefValue{Union{Nothing, Float64}}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float64,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, Symbol("#df#27")),Bool,Array{Float64,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float64}}},getfield(DiffEqCallbacks, Symbol("##33#36")){typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float64,1},Base.RefValue{Int64}},Base.RefValue{Union{Nothing, Float64}},getfield(DiffEqCallbacks, Symbol("##32#35")){getfield(DiffEqSensitivity, Symbol("#time_choice#20")){Array{Float64,1},Base.RefValue{Int64}},getfield(DiffEqSensitivity, Symbol("##19#21")){getfield(DiffEqFlux, Symbol("#df#27")),Bool,Array{Float64,1},Array{Float64,1},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float64}}}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5) at ./none:0
 [11] #adjoint_sensitivities_u0#24(::Float64, ::Float64, ::Float64, ::Float64, ::DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}}, ::Array{Float64,1}, ::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}}, ::Function, ::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats}, ::Tsit5, ::getfield(DiffEqFlux, Symbol("#df#27")), ::Array{Float64,1}, ::Nothing) at /homes/anantharaman/.julia/packages/DiffEqSensitivity/GkTme/src/adjoint_sensitivity.jl:326
 [12] (::getfield(DiffEqSensitivity, Symbol("#kw##adjoint_sensitivities_u0")))(::NamedTuple{(:sensealg, :saveat),Tuple{DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}},Float64}}, ::typeof(DiffEqSensitivity.adjoint_sensitivities_u0), ::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats}, ::Tsit5, ::Function, ::Array{Float64,1}, ::Nothing) at ./none:0 (repeats 2 times)
 [13] (::getfield(DiffEqFlux, Symbol("##24#26")){DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}},Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}},Array{Float64,1},Tuple{Tsit5},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats}})(::Array{Float64,2}) at /homes/anantharaman/.julia/packages/DiffEqFlux/RFBjh/src/Flux/layers.jl:100
 [14] back_(::Tracker.Call{getfield(DiffEqFlux, Symbol("##24#26")){DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}},Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}},Array{Float64,1},Tuple{Tsit5},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats}},Tuple{Tracker.Tracked{Array{Float64,1}},Nothing,Nothing,Nothing}}, ::Array{Float64,2}, ::Bool) at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:35
 [15] back(::Tracker.Tracked{Array{Float64,2}}, ::Array{Float64,2}, ::Bool) at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:58
 [16] #13 at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:38 [inlined]
 [17] foreach at ./abstractarray.jl:1867 [inlined]
 [18] back_(::Tracker.Call{getfield(Tracker, Symbol("##372#374")){Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},TrackedArray{…,Array{Float64,2}},Tuple{CartesianIndex{2}}},Tuple{Tracker.Tracked{Array{Float64,2}},Nothing}}, ::Float64, ::Bool) at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:38
 [19] back(::Tracker.Tracked{Float64}, ::Float64, ::Bool) at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:58
 [20] #13 at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:38 [inlined]
 [21] foreach at ./abstractarray.jl:1867 [inlined]
 [22] back_(::Tracker.Call{getfield(Tracker, Symbol("##344#347")){Int64},Tuple{Tracker.Tracked{Float64},Nothing}}, ::Float64, ::Bool) at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:38
 [23] back(::Tracker.Tracked{Float64}, ::Float64, ::Bool) at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:58
 [24] foreach at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:38 [inlined]
 [25] back_(::Tracker.Call{getfield(Tracker, Symbol("##73#74")){Tracker.TrackedReal{Float64}},Tuple{Tracker.Tracked{Float64}}}, ::Float64, ::Bool) at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:38
 [26] back(::Tracker.Tracked{Float64}, ::Float64, ::Bool) at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:58
 [27] #13 at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:38 [inlined]
 [28] foreach at ./abstractarray.jl:1867 [inlined]
 [29] back_(::Tracker.Call{getfield(Tracker, Symbol("##253#256")),Tuple{Tracker.Tracked{Float64},Tracker.Tracked{Float64}}}, ::Float64, ::Bool) at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:38
 [30] back(::Tracker.Tracked{Float64}, ::Int64, ::Bool) at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:58
 [31] #back!#15 at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:77 [inlined]
 [32] #back! at ./none:0 [inlined]
 [33] #back!#32 at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/lib/real.jl:16 [inlined]
 [34] back!(::Tracker.TrackedReal{Float64}) at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/lib/real.jl:14
 [35] gradient_(::getfield(Flux.Optimise, Symbol("##14#20")){typeof(loss_adjoint)}, ::Tracker.Params) at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:4
 [36] #gradient#24(::Bool, ::Function, ::Function, ::Tracker.Params) at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:164
 [37] gradient at /homes/anantharaman/.julia/packages/Tracker/RRYy6/src/back.jl:164 [inlined]
 [38] macro expansion at /homes/anantharaman/.julia/packages/Flux/qXNjB/src/optimise/train.jl:71 [inlined]
 [39] macro expansion at /homes/anantharaman/.julia/packages/Juno/TfNYn/src/progress.jl:124 [inlined]
 [40] #train!#12(::getfield(Main, Symbol("##59#60")), ::Function, ::Function, ::Tracker.Params, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, ::ADAM) at /homes/anantharaman/.julia/packages/Flux/qXNjB/src/optimise/train.jl:69
 [41] (::getfield(Flux.Optimise, Symbol("#kw##train!")))(::NamedTuple{(:cb,),Tuple{getfield(Main, Symbol("##59#60"))}}, ::typeof(Flux.Optimise.train!), ::Function, ::Tracker.Params, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, ::ADAM) at ./none:0
 [42] top-level scope at none:0

Status:

    Status `~/Work/Training/Project.toml`
  [aae7a2af] DiffEqFlux v0.5.2
  [0c46a032] DifferentialEquations v6.6.0
  [587475ba] Flux v0.8.3
  [f6369f11] ForwardDiff v0.10.3
  [1dea7af3] OrdinaryDiffEq v5.12.0
  [37e2e3b7] ReverseDiff v0.3.1
  [c3572dad] Sundials v3.6.1
