Comments (2)
As per the `DataLoader` documentation (https://fluxml.ai/Flux.jl/v0.10/data/dataloader/#Flux.Data.DataLoader), passing a single array is intended for unsupervised learning.
From what I can understand of your code (it would help to see how `u_vals` is created to be sure), the 3rd dimension of `u_vals` is the target, i.e. the `Y`s in the documentation linked above. So you probably need to do something like
train_loader = Flux.Data.DataLoader((u_vals[:, 1, :], u_vals), batchsize = b_s, shuffle = false)
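The pairing works because `DataLoader` treats the last dimension of each array as the observation index, so every array in the tuple must have the same size along it. A minimal Base-Julia sketch of what `shuffle = false` batching amounts to, under that assumption. `batch_along_last` is a hypothetical helper, not part of Flux, and `u_vals` here is random stand-in data of the same shape.

```julia
# Hypothetical helper mimicking DataLoader((X, Y); batchsize, shuffle = false):
# both arrays are cut into consecutive batches along their LAST dimension.
function batch_along_last(X, Y, batchsize)
    n = size(X, ndims(X))
    m = size(Y, ndims(Y))
    n == m || throw(DimensionMismatch("X has $n observations, Y has $m"))
    return [(selectdim(X, ndims(X), idx), selectdim(Y, ndims(Y), idx))
            for idx in Iterators.partition(1:n, batchsize)]
end

# Stand-in data: 3 states × 51 saved times × 20 initial conditions
u_vals = randn(3, 51, 20)

# Initial conditions (3 × 20) paired with full trajectories (3 × 51 × 20): 20 observations each
batches = batch_along_last(u_vals[:, 1, :], u_vals, 5)

# A Y with a different last dimension is rejected, e.g. a 3 × 1020 reshape of u_vals
mismatch_rejected = try
    batch_along_last(u_vals[:, 1, :], reshape(u_vals, 3, 51 * 20), 5)
    false
catch err
    err isa DimensionMismatch
end
```

With 20 trajectories and a batch size of 5, this gives 4 batches, each holding a 3 × 5 block of initial conditions and the matching 3 × 51 × 5 block of trajectories.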
from optimization.jl.
`u_vals` is (for this specific example) of size 3 × 51 × 20: the ODE system is 3-dimensional, the solution is saved at 51 values of `t`, and solutions are generated for 20 initial conditions. It was created as follows:
using DifferentialEquations  # ODEProblem, remake, EnsembleProblem, Tsit5, EnsembleThreads
# lorenz!, tspan and t_vals are defined elsewhere (not shown)
n_samples = 20
ic = randn(n_samples, 3)
prob = ODEProblem(lorenz!, ic[1, :], tspan)
prob_func(prob, i, repeat) = remake(prob, u0 = ic[i, :])
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), saveat = t_vals, trajectories = n_samples)
u_vals = Array(sim[1:end])
I made the following changes to my code after reading your suggestion:
using IterTools: ncycle

function predict_traj(pp, init_c)
    Array(solve(prob_nn, Tsit5(), u0 = init_c, p = pp, saveat = t_vals))
end

function loss_batch(pp, batch_ic, batch_traj)
    N_ic = size(batch_traj, 3)
    sum_loss = 0.0
    for k in 1:N_ic
        pred = predict_traj(pp, batch_ic[:, k])
        sum_loss += sum(abs2, batch_traj[:, :, k] .- pred)
    end
    sum_loss
end

b_s = 5
train_loader = Flux.Data.DataLoader((u_vals[:, 1, :], u_vals), batchsize = b_s, shuffle = false)
numEpochs = 100
optfun = OptimizationFunction((θ, p, batch_ic, batch_traj) -> loss_batch(θ, batch_ic, batch_traj),
                              Optimization.AutoZygote())
optprob = OptimizationProblem(optfun, p_M)
res1 = Optimization.solve(optprob, Optimisers.ADAM(0.05), ncycle(train_loader, numEpochs),
                          callback = callback)
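For a sense of how many optimizer steps this runs: `IterTools.ncycle(xs, n)` cycles through `xs` n times, so 20 trajectories with `b_s = 5` give 4 batches per epoch and 4 × 100 = 400 loss evaluations. A Base-Julia sketch of the same iteration pattern, using plain index ranges as a hypothetical stand-in for `train_loader`:

```julia
# Stand-in for train_loader: 20 trajectories split into batches of 5 (as index ranges)
loader = collect(Iterators.partition(1:20, 5))
numEpochs = 100

# Base-Julia equivalent of IterTools.ncycle(loader, numEpochs):
# the loader repeated numEpochs times, flattened into one stream of batches
epochs = Iterators.flatten(Iterators.repeated(loader, numEpochs))

n_steps = count(_ -> true, epochs)   # one optimizer step per batch
```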
This trains successfully. There is still some redundancy in passing both `u_vals[:, 1, :]` and `u_vals`, but this is an improvement.
I suppose it makes more sense to split `u_vals` into two arrays, e.g. `ic = u_vals[:, 1, :]; target = u_vals[:, 2:end, :]`. Including the initial conditions in the loss is simply an artifact of my following the DiffEqFlux tutorials/examples.
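With that split, the prediction still contains the initial time point (assuming `t_vals[1]` is the start of `tspan`, as it must be for `u_vals[:, 1, :]` to hold the initial conditions), so its first column has to be dropped before comparing against `target`. A minimal sketch; `loss_batch_split` and `fake_predict` are hypothetical names, and `fake_predict` only stands in for `predict_traj` so the shapes can be checked without a neural ODE.

```julia
# Hypothetical variant of loss_batch for ic = u_vals[:, 1, :] and target = u_vals[:, 2:end, :].
# `predict(pp, u0)` is assumed to return a 3 × 51 array whose first column is u0.
function loss_batch_split(predict, pp, batch_ic, batch_target)
    sum_loss = zero(eltype(batch_target))
    for k in axes(batch_target, 3)
        pred = predict(pp, batch_ic[:, k])                              # 3 × 51, includes t0
        sum_loss += sum(abs2, batch_target[:, :, k] .- pred[:, 2:end])  # compare t1 onward only
    end
    return sum_loss
end

# Stand-in model: first column is the initial condition, the other 50 are pp .* u0
fake_predict(pp, u0) = hcat(u0, repeat(pp .* u0, 1, 50))

ic = randn(3, 5)                                                           # 5 initial conditions
target = cat([fake_predict(2.0, ic[:, k])[:, 2:end] for k in 1:5]...; dims = 3)  # 3 × 50 × 5
```

Here `loss_batch_split(fake_predict, 2.0, ic, target)` is zero, because the data were generated with the same parameter, and any other value of `pp` gives a positive loss.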
Nevertheless, I wonder whether it would be a good idea for Optimization to be compatible with a `DataLoader` built from only one input array.
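Why a single array is awkward can be shown in plain Julia, assuming the solver's data loop calls the objective as `f(θ, p, batch...)` for each element the iterator yields (my understanding of how Optimization.jl handled data iterators at the time, not verified here). A bare array is then splatted into one argument per element, while a 1-tuple `(batch,)` supplies exactly one. `loss_single` and `call_with_batch` are hypothetical names.

```julia
# Hypothetical objective expecting exactly one data argument (a batch of trajectories)
loss_single(θ, p, batch) = sum(abs2, batch)

# Assumed solver behaviour: each yielded element is splatted into the objective
call_with_batch(f, θ, p, batch) = f(θ, p, batch...)

batch = randn(3, 51, 5)

# A bare array splats into 765 scalar arguments, so no method matches
bare_fails = try
    call_with_batch(loss_single, nothing, nothing, batch)
    false
catch err
    err isa MethodError
end

# Wrapping the array in a 1-tuple yields exactly one extra argument
wrapped = call_with_batch(loss_single, nothing, nothing, (batch,))
```

If that assumption holds, building the loader from a 1-tuple, e.g. `DataLoader((u_vals,), ...)`, may work as a stopgap, provided your Flux version yields 1-tuples for tuple input; that is worth checking before relying on it.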