
Vaibhavdixit02 commented on June 21, 2024

As per the DataLoader documentation https://fluxml.ai/Flux.jl/v0.10/data/dataloader/#Flux.Data.DataLoader, passing only a single array is suited for unsupervised learning.
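To make the single-array versus tuple distinction concrete, here is a simplified, Base-only stand-in for how a DataLoader-style iterator slices its input. It is not Flux's implementation. `simple_batches` and `obs` are hypothetical helpers that mirror only the documented behaviour: the last dimension of each array is the observation dimension, a single array yields array batches, and a tuple yields tuples of batches.

```julia
# Hypothetical helpers (not Flux's code) mimicking DataLoader's documented behaviour:
# the LAST dimension of each array is the observation dimension.

# Select observations `idx` along the last dimension, keeping the other dimensions.
obs(x::AbstractArray, idx) = selectdim(x, ndims(x), idx)

# A single array gives batches that are plain arrays (the "unsupervised" case).
function simple_batches(x::AbstractArray; batchsize::Int)
    n = size(x, ndims(x))
    return [obs(x, i:min(i + batchsize - 1, n)) for i in 1:batchsize:n]
end

# A tuple gives batches that are tuples of slices taken at the same indices,
# so (by assumption, as with DataLoader) every array needs the same number of observations.
function simple_batches(data::Tuple; batchsize::Int)
    n = size(first(data), ndims(first(data)))
    all(x -> size(x, ndims(x)) == n, data) ||
        throw(DimensionMismatch("all arrays need the same number of observations"))
    return [map(x -> obs(x, i:min(i + batchsize - 1, n)), data) for i in 1:batchsize:n]
end

X = rand(3, 20)       # e.g. 20 initial conditions
Y = rand(3, 51, 20)   # e.g. 20 trajectories with 51 time points each
single = simple_batches(X; batchsize = 5)        # 4 batches, each 3 x 5
paired = simple_batches((X, Y); batchsize = 5)   # 4 batches, each (3 x 5, 3 x 51 x 5)
```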

From what I can understand of your code (it would help to see how u_vals is created to be sure), u_vals holds the targets along its 3rd dimension, i.e. the Ys in the documentation linked above. So you probably need to do something like

train_loader = Flux.Data.DataLoader((u_vals[:,1, :],  reshape(u_vals, size(u_vals, 1), size(u_vals, 2)*size(u_vals, 3) ), batchsize = b_s, shuffle = false)
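Because batching happens along the last dimension, one way to sanity-check a tuple like this is to compare the observation counts of its arrays. This sketch uses a random stand-in for `u_vals` with the 3 x 51 x 20 shape reported in the reply; `nobs` is a hypothetical helper, not a Flux function.

```julia
# Hypothetical helper: observations counted along the last dimension.
nobs(x::AbstractArray) = size(x, ndims(x))

u_vals = rand(3, 51, 20)   # random stand-in with the thread's 3 x 51 x 20 shape

# The reshape folds time into the last dimension ...
flat = reshape(u_vals, size(u_vals, 1), size(u_vals, 2) * size(u_vals, 3))
(nobs(u_vals[:, 1, :]), nobs(flat))     # (20, 1020): counts differ

# ... whereas keeping the trajectory index last keeps the counts aligned.
(nobs(u_vals[:, 1, :]), nobs(u_vals))   # (20, 20)
```

With the reshape the two arrays disagree on how many observations there are, which a loader that slices every array at the same indices cannot reconcile; keeping the trajectory index last lines them up.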

from optimization.jl.

sentz2 commented on June 21, 2024

u_vals is (for this specific example) of size 3 x 51 x 20. The ODE system is 3-dimensional, the solution is saved at 51 values of t, and solutions are generated for 20 initial conditions. It was created as follows:

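using DifferentialEquations  # ODEProblem, EnsembleProblem, Tsit5, EnsembleThreads
# lorenz!, tspan and t_vals (the 51 save times) are defined elsewhere in the script

n_samples = 20
ic = randn(n_samples, 3)  # one random initial condition per row

prob = ODEProblem(lorenz!, ic[1,:], tspan)
# give trajectory i its own initial condition
prob_func(prob, i, repeat) = remake(prob, u0 = ic[i,:])
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)

sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), saveat=t_vals, trajectories = n_samples)
u_vals = Array(sim[1:end])  # states x save times x trajectories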
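To see how the 3 x 51 x 20 layout arises and what `u_vals[:, 1, :]` picks out, here is a dependency-free sketch. It assumes a closed-form stand-in, u(t) = u0 * exp(-t), instead of integrating `lorenz!`, and a hypothetical time span of (0, 5); only the array layout is meant to match the ensemble solve above.

```julia
n_samples = 20
t_vals = range(0.0, 5.0; length = 51)   # 51 save points; the span (0, 5) is assumed
ic = randn(n_samples, 3)                 # one initial condition per row, as above

# Hypothetical stand-in for one trajectory: a 3 x 51 matrix (states x times).
trajectory(u0) = u0 .* exp.(-t_vals')
trajs = [trajectory(ic[i, :]) for i in 1:n_samples]

# Stack along a new third dimension -> states x times x trajectories.
u_vals = cat(trajs...; dims = 3)

size(u_vals)             # (3, 51, 20)
u_vals[:, 1, :] == ic'   # true: the first time slice holds the initial conditions
```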

I made the following changes to my code after reading your suggestion:

function predict_traj(pp, init_c)
    # solve the neural ODE (prob_nn, defined elsewhere) from init_c with parameters pp
    Array(solve(prob_nn, Tsit5(), u0 = init_c, p = pp, saveat = t_vals))
end

function loss_batch(pp, batch_ic, batch_traj)
    N_ic = size(batch_traj, 3)  # number of trajectories in this batch
    sum_loss = 0.0
    for k in 1:N_ic
        pred = predict_traj(pp, batch_ic[:,k])
        sum_loss += sum(abs2, batch_traj[:,:,k] .- pred)
    end
    sum_loss
end
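The batching logic of `loss_batch` can be checked without an ODE solver by swapping in a predictor with a known solution. In this sketch `predict_stub` (the solution of du/dt = -pp*u) and the 0.7 "true" rate are hypothetical; the loop is the same as in `loss_batch`.

```julia
t_vals = range(0.0, 5.0; length = 51)   # assumed save times

# Hypothetical stand-in for predict_traj: returns a 3 x 51 matrix.
predict_stub(pp, init_c) = init_c .* exp.(-pp .* t_vals')

# Same structure as loss_batch: one prediction per column of batch_ic,
# compared against the matching trajectory batch_traj[:, :, k].
function loss_batch_stub(pp, batch_ic, batch_traj)
    sum_loss = 0.0
    for k in 1:size(batch_traj, 3)
        pred = predict_stub(pp, batch_ic[:, k])
        sum_loss += sum(abs2, batch_traj[:, :, k] .- pred)
    end
    return sum_loss
end

# A batch of 5 trajectories generated with rate 0.7.
batch_ic = randn(3, 5)
batch_traj = cat([predict_stub(0.7, batch_ic[:, k]) for k in 1:5]...; dims = 3)

loss_batch_stub(0.7, batch_ic, batch_traj)   # 0.0 at the generating parameter
loss_batch_stub(0.2, batch_ic, batch_traj)   # positive elsewhere
```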

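using Flux, Optimization, Optimisers
using IterTools: ncycle

b_s = 5
# pair each initial condition (3 x 20) with its full trajectory (3 x 51 x 20);
# DataLoader batches both along the last (trajectory) dimension
train_loader = Flux.Data.DataLoader((u_vals[:,1,:], u_vals), batchsize = b_s, shuffle = false)

numEpochs = 100
# the arguments after (θ, p) receive the two arrays of each batch
optfun = OptimizationFunction((θ, p, batch_ic, batch_traj) -> loss_batch(θ, batch_ic, batch_traj),
                              Optimization.AutoZygote())
optprob = OptimizationProblem(optfun, p_M)  # p_M: initial parameters, defined elsewhere
res1 = Optimization.solve(optprob, Optimisers.ADAM(0.05), ncycle(train_loader, numEpochs),
                          callback = callback)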
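For a sense of how many updates this runs: with 20 trajectories and `b_s = 5`, each pass over `train_loader` gives 4 batches, and `ncycle` repeats that `numEpochs` times. The sketch reproduces the count with Base iterators in place of `IterTools.ncycle` and index ranges in place of real batches; it assumes one optimizer step per batch.

```julia
n_traj    = 20    # trajectories in u_vals
b_s       = 5     # batch size, as above
numEpochs = 100

# One epoch of batches as index ranges (shuffle = false keeps them in order).
batches = [i:min(i + b_s - 1, n_traj) for i in 1:b_s:n_traj]

# Base equivalent of ncycle(batches, numEpochs): the epoch repeated back to back.
data = Iterators.flatten(Iterators.repeated(batches, numEpochs))

length(batches)          # 4 batches per epoch
count(_ -> true, data)   # 400 batches in total
```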

This trains successfully. There is still some redundancy in passing both u_vals[:, 1, :] and u_vals, since the initial conditions are already the first time slice of u_vals, but this is an improvement.

I suppose it makes more sense to split u_vals into two arrays, such as ic = u_vals[:,1,:] and target = u_vals[:,2:end,:]. Including the initial conditions in the loss is simply an artifact of my following the DiffEqFlux tutorials/examples.
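With that split, the prediction needs one adjustment: assuming `t_vals` starts at the initial time, `predict_traj` returns the initial condition as its first column, so that column has to be dropped before comparing with `target`. The sketch uses a constant-trajectory stand-in for `u_vals` and a hypothetical `predict_stub` with the same 3 x 51 output shape.

```julia
# Stand-in data: every trajectory stays at its initial condition (hypothetical).
ic0 = randn(3, 20)
u_vals = cat([repeat(ic0[:, k], 1, 51) for k in 1:20]...; dims = 3)  # 3 x 51 x 20

ic     = u_vals[:, 1, :]        # 3 x 20: initial conditions
target = u_vals[:, 2:end, :]    # 3 x 50 x 20: states after t0 only

# Hypothetical predictor with predict_traj's output shape (3 x 51, t0 first).
predict_stub(init_c) = repeat(init_c, 1, 51)

function loss_split(batch_ic, batch_target)
    total = 0.0
    for k in 1:size(batch_target, 3)
        pred = predict_stub(batch_ic[:, k])
        # Drop the t0 column so the 3 x 51 prediction lines up with the 3 x 50 target.
        total += sum(abs2, batch_target[:, :, k] .- pred[:, 2:end])
    end
    return total
end

loss_split(ic, target)   # 0.0 for this constant stand-in
```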

Nevertheless, I wonder whether it would be a good idea for Optimization.jl to also accept a DataLoader built from a single array.
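For comparison, with a single-array DataLoader each batch would be one 3 x 51 x b_s array, and the loss could read the initial conditions from the first time slice instead of receiving them separately. This Base-only sketch (hypothetical `loss_single` and `predict_stub`) covers only the loss side; it does not show how Optimization.jl would pass such a batch to the loss, which is the open question here.

```julia
# Hypothetical stand-in for predict_traj: 3 x 51 output with t0 first.
predict_stub(init_c) = repeat(init_c, 1, 51)

# The whole batch is one states x times x trajectories array.
function loss_single(batch)
    total = 0.0
    for k in 1:size(batch, 3)
        traj = batch[:, :, k]
        pred = predict_stub(traj[:, 1])   # initial condition taken from the data itself
        total += sum(abs2, traj .- pred)
    end
    return total
end

ic0 = randn(3, 5)
batch = cat([repeat(ic0[:, k], 1, 51) for k in 1:5]...; dims = 3)
loss_single(batch)   # 0.0 for this constant stand-in
```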

