fluxtraining.jl's People

Contributors

awadell1, carlolucibello, christiangnrd, darsnack, expandingman, github-actions[bot], kronosthelate, lorenzoh, moelf, nrhodes, rejuvyesh, romeov, timholy, touchesir, visr, vnegi10, yuehhua

fluxtraining.jl's Issues

Simpler `Learner` API

Currently, Learner takes 4 positional arguments plus any number of callbacks, which are also passed as positional arguments. This is not the cleanest API, so an additional, more concise method should be added to construct Learners.

This new method could dispatch to the current method, but should be favored in the documentation. The implementation would look something like:

function Learner(model, lossfn; optim=ADAM(), data=nothing, callbacks=[], kwargs...)
    return Learner(model, data, optim, lossfn, callbacks...; kwargs...)
end

This also brings consistency with FastAI.tasklearner, which already takes callbacks as a keyword argument.
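Here is a package-free sketch of the proposed dispatch. It uses a hypothetical ToyLearner in place of FluxTraining's Learner so it runs without Flux, and the :adam symbol is only a placeholder for an optimiser. The keyword method collects everything except the model and loss function and forwards it to the existing positional form:

```julia
# Hypothetical stand-in for FluxTraining's Learner; field names are illustrative only.
struct ToyLearner
    model
    data
    optim
    lossfn
    callbacks::Vector{Any}
    # Current-style positional form: callbacks are trailing positional arguments.
    ToyLearner(model, data, optim, lossfn, callbacks...) =
        new(model, data, optim, lossfn, collect(Any, callbacks))
end

# Proposed concise form: model and loss function are positional, everything else
# is a keyword. It simply dispatches to the positional method above.
ToyLearner(model, lossfn; optim = :adam, data = nothing, callbacks = []) =
    ToyLearner(model, data, optim, lossfn, callbacks...)

learner = ToyLearner(identity, abs2; data = ([1.0], [2.0]), callbacks = [:metrics, :logger])
```

Because the keyword method only takes two positional arguments and the positional one requires at least four, the two forms cannot be confused during dispatch.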

Checkpointer for best checkpoint only

Motivation and description

Currently, we already have the Checkpointer struct, which saves the model every epoch.
However, this can blow up in file size, and usually we are only interested in the best, and maybe last checkpoints.

Possible Implementation

From epoch two on, we have to consider the (logical) checkpoints (prev, current, best).
If prev != best, delete prev.
If metric(current) > metric(best), delete path[best].

We could implement this by giving the Checkpointer callback a models field: a Dict{Symbol, NamedTuple} that maps :latest (the previous checkpoint) and :best to (path, metric) named tuples.
Then, implementing the above logic would roughly be:

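function update!(cb::Checkpointer, new_path, new_metric; ord::Base.Order.Ordering=Base.Order.Forward)
    # delete the previous checkpoint unless it is also the best one
    if haskey(cb.models, :latest) && cb.models[:latest].path != cb.models[:best].path
        rm(cb.models[:latest].path)
    end

    cb.models[:latest] = (; path=new_path, metric=new_metric)

    # keep the new checkpoint as best if it beats the current best under `ord`
    if !haskey(cb.models, :best)
        cb.models[:best] = cb.models[:latest]
    elseif Base.Order.lt(ord, cb.models[:best].metric, cb.models[:latest].metric)
        rm(cb.models[:best].path)
        cb.models[:best] = cb.models[:latest]
    end
end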
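Below is a runnable sketch of the same keep-latest-and-best logic. It uses a hypothetical BestCheckpointTracker that is not part of FluxTraining, and dummy files in a temporary directory stand in for real model checkpoints. With the default Base.Order.Forward a higher metric counts as better; passing Base.Order.Reverse would suit a loss instead:

```julia
# Hypothetical helper type (not part of FluxTraining): one saved checkpoint.
const CkptEntry = NamedTuple{(:path, :metric), Tuple{String, Float64}}

# Keeps at most two files on disk: the latest checkpoint and the best one so far.
struct BestCheckpointTracker
    models::Dict{Symbol, CkptEntry}
end
BestCheckpointTracker() = BestCheckpointTracker(Dict{Symbol, CkptEntry}())

function update!(cb::BestCheckpointTracker, new_path::AbstractString, new_metric::Real;
                 ord::Base.Order.Ordering = Base.Order.Forward)
    # Delete the previous checkpoint unless it is also the best one.
    if haskey(cb.models, :latest) && cb.models[:latest].path != cb.models[:best].path
        rm(cb.models[:latest].path)
    end
    cb.models[:latest] = (path = String(new_path), metric = Float64(new_metric))
    # Under Forward ordering a larger metric is better; Reverse suits losses.
    if !haskey(cb.models, :best)
        cb.models[:best] = cb.models[:latest]
    elseif Base.Order.lt(ord, cb.models[:best].metric, cb.models[:latest].metric)
        rm(cb.models[:best].path)
        cb.models[:best] = cb.models[:latest]
    end
    return cb
end

# Simulate four epochs; each file is a stand-in for a saved model.
dir = mktempdir()
tracker = BestCheckpointTracker()
for (epoch, acc) in enumerate([0.50, 0.70, 0.60, 0.65])
    path = joinpath(dir, "checkpoint_epoch$(epoch).bson")
    write(path, "dummy weights")
    update!(tracker, path, acc)
end
remaining = readdir(dir)
```

After the four epochs, only the best checkpoint (epoch 2) and the latest one (epoch 4) are left on disk.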

Hard error on EarlyStopping()

Cheers. When I train a model with FluxTraining.fit!(learner, epochs) and an early-stopping condition is met, I get a hard error that terminates the Julia script, so any code placed after the fit! call never runs. I believe this is unintended behavior; please kindly verify. Thanks in advance.

Code is as follows (early-stopping parameters purposely set to small numbers):

ms = [accuracy,
      t.Metric(LibML.IoU, device=gpu, name="IoU"),
]

cbs = [ToGPU(),
       StopOnNaNLoss(),
       Checkpointer(modelsfolder),
       EarlyStopping(1),
       EarlyStopping(NumberSinceBest(1)),
       EarlyStopping(Threshold(0.5)),
       Metrics(ms...),
       LogMetrics(TensorBoardBackend(tbfolder)),
       ]

learner = t.Learner(model, lossFunction;
                    data=(trainset, validset),
                    optimizer=modelOptimizer,
                    callbacks=cbs,
)

epochs = 100
FluxTraining.fit!(learner, epochs)
@info "project finished"

Error message as follows:

ERROR: CancelFittingException("Stop triggered by EarlyStopping.Patience(1) stopping criterion. ")
Stacktrace:
 [1] on(::FluxTraining.Events.EpochEnd, phase::ValidationPhase, cb::EarlyStopping, learner::FluxTraining.Protected{Learner})
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/earlystopping.jl:72
 [2] _on(e::FluxTraining.Events.EpochEnd, p::ValidationPhase, cb::EarlyStopping, learner::Learner)
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/callback.jl:254
 [3] handle(runner::FluxTraining.LinearRunner, event::FluxTraining.Events.EpochEnd, phase::ValidationPhase, learner::Learner)
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/execution.jl:12
 [4] (::FluxTraining.var"#handlefn#81"{Learner, ValidationPhase})(e::FluxTraining.Events.EpochEnd)
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:102
 [5] runepoch(epochfn::FluxTraining.var"#71#72"{…}, learner::Learner, phase::ValidationPhase)
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:106
 [6] epoch!
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:22 [inlined]
 [7] fit!(learner::Learner, nepochs::Int64, ::Tuple{MLUtils.DataLoader{…}, MLUtils.DataLoader{…}})
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:169
 [8] fit!(learner::Learner, nepochs::Int64)
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:174
 [9] top-level scope
   @ ~/projects/pascalvoc-segmentation/8-training.jl:123
Some type information was truncated. Use `show(err)` to see complete types.
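Until this changes, one workaround is to catch the cancellation around fit! so the rest of the script still runs. The sketch below assumes the exception type is FluxTraining.CancelFittingException, as the trace shows. The exception type and training loop in it are hypothetical stand-ins, so it runs without FluxTraining:

```julia
# Hypothetical stand-in for FluxTraining.CancelFittingException.
struct CancelFittingSketch <: Exception
    msg::String
end

# Toy training loop: signals early stopping by throwing, like the EarlyStopping callback.
function toy_fit!(losses::Vector{Float64}; patience::Int = 1)
    best, since_best = Inf, 0
    for (epoch, loss) in enumerate(losses)
        if loss < best
            best, since_best = loss, 0
        else
            since_best += 1
        end
        since_best >= patience && throw(CancelFittingSketch("stopped at epoch $epoch"))
    end
    return best
end

# Swallow only the cancellation; any other error still propagates.
function fit_or_stop!(losses; kwargs...)
    try
        return toy_fit!(losses; kwargs...)
    catch e
        e isa CancelFittingSketch || rethrow()
        @info "Training stopped early" reason = e.msg
        return nothing
    end
end

fit_or_stop!([1.0, 0.8, 0.9, 0.7])
@info "project finished"   # reached even though training stopped early
```

The isa check keeps real errors visible, so only the intended early stop is silenced.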


Allow to disable callbacks for some events

I really like the standardization of logging etc. that FluxTraining comes with. However, for my applications, the logging is much too verbose. Logging the loss after every gradient step significantly slows down my training if the network is fairly small.

In those settings, it would be helpful if the default callbacks could be told to run at a lower frequency, e.g. every N training examples or only at the end of each epoch. Is something like that currently supported without implementing a custom callback?
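One possible approach is sketched below without any packages. Throttled is a hypothetical wrapper, not FluxTraining API, that forwards only every N-th call to a logging function. Inside a custom callback, the per-step logging could presumably be wrapped the same way:

```julia
# Hypothetical wrapper (not FluxTraining API): forwards only every `every`-th call.
mutable struct Throttled{F}
    f::F
    every::Int
    count::Int
end
Throttled(f; every::Int) = Throttled(f, every, 0)

function (t::Throttled)(args...)
    t.count += 1
    # Run the wrapped function only on calls every, 2every, 3every, ...
    return t.count % t.every == 0 ? t.f(args...) : nothing
end

logged = Float64[]
logloss = Throttled(loss -> push!(logged, loss); every = 10)
for step in 1:35
    logloss(Float64(step))   # only steps 10, 20 and 30 reach the logger
end
```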

Break out Schedule

Does it make sense to break out Schedule from FluxTraining.jl? It seems like you hit upon a cool Julia package to use for scheduling, and we could use it as the base for implementing several common LR schedules. Would be nice to be able to write something like Cyclic(period = n) etc.

I can give this a shot in a repo if you think it makes sense.
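As a rough idea of what such a schedule could look like, here is a package-free sketch of a triangular cyclic schedule. The name Cyclic and its lo/hi/period keywords are hypothetical, not an existing API:

```julia
# Hypothetical triangular cyclic schedule: rises linearly from lo to hi over the
# first half of each period, then falls back to lo over the second half.
struct Cyclic
    lo::Float64
    hi::Float64
    period::Int
end
Cyclic(; lo = 1e-4, hi = 1e-2, period) = Cyclic(lo, hi, period)

function (s::Cyclic)(t::Integer)
    # Position within the current cycle, in [0, 1).
    x = mod(t - 1, s.period) / s.period
    frac = x < 0.5 ? 2x : 2(1 - x)
    return s.lo + (s.hi - s.lo) * frac
end

s = Cyclic(lo = 0.0, hi = 1.0, period = 4)
s.(1:5)   # one full cycle, then back to the start
```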

Support for Transfer-Learning/Layer-Freezing

Motivation and description

A common practice in machine learning is to take a pre-trained model and fine-tune it on a particular dataset. This typically involves freezing the weights in some layers while fitting the output layer(s) on the new data.

Unfortunately, this functionality appears to be incompatible with the current implementation of the ToDevice callback, based on the following code:

function on(::EpochBegin, ::Phase, cb::ToDevice, learner)
    model!(learner, cb.movemodelfn(learner.model))
end

function model!(learner, model)
    learner.model = model
    learner.params = setupoptimstate(model, learner.optimizer)
end

setupoptimstate(model, ::Flux.Optimise.AbstractOptimiser) = Flux.params(model)

setupoptimstate(model, optim) = Optimisers.setup(optim, model)

This essentially means that learner.params is set to the parameters of the full model at the start of each epoch. Thus, even if we try to freeze the layers manually with Flux.freeze!(learner.params.layers[1:end-1]), this will be undone by ToDevice.

Possible Implementation

One solution that would work with Flux's new explicit optimizers would be to create a callback to freeze layers after ToDevice is executed. An example is given below:

mutable struct LayerFreezing{F} <: FluxTraining.Callback
    accessor::F
end

function FluxTraining.stateaccess(freezer::LayerFreezing)
    return (;params = FluxTraining.Write())
end

function FluxTraining.on(
    event::FluxTraining.EpochBegin, 
    phase::FluxTraining.AbstractTrainingPhase, 
    freezer::LayerFreezing, 
    learner)
    Flux.freeze!(freezer.accessor(learner.params))
end

FluxTraining.runafter(::LayerFreezing) = (FluxTraining.ToDevice,)

However, perhaps we should consider whether it's necessary for ToDevice to move the model to the GPU at the start of every epoch. Maybe we could extend the Callback interface to allow for some one-time setup code to run before the first epoch is executed?
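The toy model below illustrates the ordering problem without any packages. Per-layer frozen flags stand in for the optimiser state that ToDevice rebuilds at every EpochBegin. It only shows why the freezing step has to run after ToDevice, which is what the runafter declaration above expresses. None of the names are Flux or FluxTraining API:

```julia
# Hypothetical toy: one "frozen" flag per layer stands in for optimiser state.
struct FreezeDemo
    frozen::Vector{Bool}
end

# Mimics ToDevice + setupoptimstate: rebuilds the state, so nothing is frozen.
reset_state!(l::FreezeDemo) = (fill!(l.frozen, false); l)

# Mimics LayerFreezing: freeze every layer except the last one.
freeze_all_but_last!(l::FreezeDemo) = (l.frozen[1:end-1] .= true; l)

# Run the EpochBegin handlers in the given order and report the resulting flags.
function epoch_begin!(l::FreezeDemo, handlers)
    foreach(h -> h(l), handlers)
    return copy(l.frozen)
end

wrong = epoch_begin!(FreezeDemo(fill(false, 3)), (freeze_all_but_last!, reset_state!))
right = epoch_begin!(FreezeDemo(fill(false, 3)), (reset_state!, freeze_all_but_last!))
```

In the first order, the reset undoes the freeze. In the second order, the first two layers stay frozen.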

TagBot trigger issue

This issue is used to trigger TagBot; feel free to unsubscribe.

If you haven't already, you should update your TagBot.yml to include issue comment triggers.
Please see this post on Discourse for instructions and more details.

Hyperparameter logging to `TBLogger` is all messed up

I am not really sure what is going on, since all the other logged values look fine. I have a simple supervised training flow with the standard phases. Here is the construction of my Learner for reference.

lossfn = Flux.Losses.logitcrossentropy

# define schedule and optimizer
es = length(trainloader)
schedule = Interpolator(Step(0.001, 0.5, [20, 10, 20]), es)
# this is a patched ADAMW not Flux.ADAMW
optim = ADAMW(0.001, (0.9, 0.999), 1e-4)

# callbacks
logger = TensorBoardBackend("tblogs")
schcb = Scheduler(LearningRate => schedule)
hlogcb = LogHyperParams(logger)
mlogcb = LogMetrics(logger)
valcb = Metrics(Metric(accuracy; phase = ValidationPhase))

# setup learner object
learner = Learner(m, lossfn;
                  data = (trainloader, valloader),
                  optimizer = optim,
                  callbacks = [schcb, ToGPU(), hlogcb, mlogcb, valcb])

This is what my learning rate log looks like:

[Screenshot, 2022-04-06: learning-rate curve as logged to TensorBoard]

I'm not sure if #107 is related. Running the scheduler without FluxTraining.jl looks fine:

julia> using ParameterSchedulers

julia> s = Interpolator(Step(0.001, 0.5, [20, 10, 20]), 77000 ÷ 32)
Interpolator{Step{Float64, Vector{Int64}}, Int64, ParameterSchedulers.var"#26#27"}(Step{Float64, Vector{Int64}}(0.001, 0.5, [20, 10, 20]), 2406, ParameterSchedulers.var"#26#27"())

julia> using UnicodePlots
[ Info: Precompiling UnicodePlots [b8865327-cd53-5732-bb35-84acbb429228]

julia> t = 1:(77000 ÷ 32)*50 |> collect;

julia> lineplot(t, s.(t))
          ┌────────────────────────────────────────┐ 
    0.001 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⢹⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠘⠒⠒⠒⠒⡆⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
   0.0002 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
          └────────────────────────────────────────┘ 
          ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀200000⠀ 
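To get a reference curve to compare against the TensorBoard log, the expected per-iteration learning rate can be computed without any packages. This sketch makes three assumptions, none of them checked against the scheduler package. First, Step(0.001, 0.5, [20, 10, 20]) means 0.001 for 20 epochs, then 0.0005 for 10, then 0.00025 for 20, which appears consistent with the plot above. Second, the last value is held afterwards. Third, every iteration within an epoch uses that epoch's value. step_lr and iter_lr are hypothetical helpers:

```julia
# Hypothetical step-decay schedule: `start` for step_sizes[1] epochs, then
# start*decay for step_sizes[2] epochs, and so on.
function step_lr(epoch::Integer; start = 0.001, decay = 0.5, step_sizes = [20, 10, 20])
    boundaries = cumsum(step_sizes)              # [20, 30, 50] for the defaults
    k = count(<(epoch), boundaries)              # number of boundaries already passed
    k = min(k, length(step_sizes) - 1)           # assumption: hold the last value afterwards
    return start * decay^k
end

# Per-iteration value: every iteration in an epoch uses that epoch's value.
iter_lr(t::Integer, iters_per_epoch::Integer; kw...) = step_lr(cld(t, iters_per_epoch); kw...)

iters_per_epoch = 77000 ÷ 32                     # 2406, as in the REPL session above
lrs = [iter_lr(t, iters_per_epoch) for t in 1:50 * iters_per_epoch]
```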

Improve printing during training

With the single metric accuracy, the output tables (which I love) look like this:

Epoch 11 TrainingPhase(): 100%|███████████████████████████████████████████| Time: 0:00:00
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │  11.0 │ 0.25969 │  0.92827 │
└───────────────┴───────┴─────────┴──────────┘
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │  11.0 │ 0.26323 │  0.92731 │
└─────────────────┴───────┴─────────┴──────────┘

I suggest putting them into the same table and making the Epoch column of element type Int64, so that it looks like this:

Epoch 11 TrainingPhase(): 100%|███████████████████████████████████████████| Time: 0:00:00
┌─────────────────┬───────┬─────────┬──────────┐
│         Phase   │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ TrainingPhase   │  11   │ 0.25969 │  0.92827 │
│ ValidationPhase │  11   │ 0.26323 │  0.92731 │
└─────────────────┴───────┴─────────┴──────────┘
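Here is a package-free sketch of the proposed layout. metrics_table is a hypothetical helper; FluxTraining's own printer is not reproduced here. It puts all phases of an epoch into one table and prints the epoch as an integer:

```julia
using Printf

# Hypothetical helper: one row per phase, epoch shown as an integer.
function metrics_table(rows)
    header = ["Phase", "Epoch", "Loss", "Accuracy"]
    body = [[string(phase), string(Int(epoch)), @sprintf("%.5f", loss), @sprintf("%.5f", acc)]
            for (phase, epoch, loss, acc) in rows]
    allrows = vcat([header], body)
    # Each column is as wide as its longest entry (header included).
    widths = [maximum(length(r[j]) for r in allrows) for j in 1:4]
    rule(l, m, r) = l * join(("─"^(w + 2) for w in widths), m) * r
    fmt(r) = "│" * join((" " * lpad(r[j], widths[j]) * " " for j in 1:4), "│") * "│"
    lines = [rule("┌", "┬", "┐"), fmt(header), rule("├", "┼", "┤"),
             fmt.(body)..., rule("└", "┴", "┘")]
    return join(lines, "\n")
end

rows = [("TrainingPhase", 11.0, 0.25969, 0.92827),
        ("ValidationPhase", 11.0, 0.26323, 0.92731)]
println(metrics_table(rows))
```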

recurrent example for docs

Motivation and description

Dealing with recurrent networks raises a lot of questions, because they work rather differently from the stateless case.

I think it would be extremely helpful to have explicit examples: one for sequence-to-sequence and one for sequence-to-one.

Possible Implementation

I might come back and contribute this, but as I'm posting this I still don't think I'm doing this the intended way...

Question regarding ProgressPrinter

Hi, first off: wonderful package :)

I have some issues with the ProgressPrinter not showing up, even when using the default callbacks.

learner = Learner(model, loss; optimizer=opt, callbacks=[ToGPU()], usedefaultcallbacks=true)
FluxTraining.fit!(learner, epochs, (dl, val_dl)) # where dl and val_dl are both Flux.DataLoader objects

Do I need to do something specific when constructing the Learner that I have missed? From the code, it seems like I would need to give it a Progress object; do I have to construct that myself? What requirements does my data iterator have to fulfill for the progress bar to show up with the default callbacks?

`LogHyperparams` only works if there is a `Scheduler`

LogHyperparams relies on :hyperparams being defined as a key in learner.cbstate, but this key is only initialized by Scheduler. So using a hyperparameter-logging callback without a scheduling callback results in an error.
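One possible fix is for the logging callback to create the entry on first use instead of assuming that a Scheduler already did. In the package-free sketch below, a plain Dict stands in for learner.cbstate, and log_hyperparams! is a hypothetical function:

```julia
# Hypothetical storage layout: hyperparameter name => values recorded so far.
const HyperparamStore = Dict{Symbol, Vector{Float64}}

# Create :hyperparams if no Scheduler has done so, then log each latest value.
function log_hyperparams!(logged::Vector, cbstate::AbstractDict)
    hp = get!(() -> HyperparamStore(), cbstate, :hyperparams)
    for (name, vals) in hp
        isempty(vals) || push!(logged, (name, last(vals)))
    end
    return logged
end

cbstate = Dict{Symbol, Any}()          # stand-in for learner.cbstate
log_hyperparams!(Any[], cbstate)       # no Scheduler: logs nothing instead of erroring
```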

Scheduler applies schedules per batch by default

The documentation implies that the first argument given to Schedule is measured in epochs, but the schedule is actually applied per batch.

Oh, and nice work with the package btw. It seems quite easy to use while still being very flexible.

Schedule callbacks to run asynchronously and in parallel

Using the callback dependency graph, it's possible to determine which callbacks access what state and which callbacks need to run before.

Hence it should be possible to run callbacks that access unrelated state in parallel, asynchronously, or both. For example, you may have an integration with an experiment tracking backend like Weights and Biases that needs to perform network requests. These could run in the background so as not to slow down the training loop.

In practice, using the dependency graph together with whether each state access is a read or a write, a Dagger.jl DAG could be constructed and run asynchronously. Callbacks that write state, like ToGPU, would still run synchronously.

There already exists an extension interface for how callbacks are executed in FluxTraining.CallbackExecutor. The default (and so far only) implementation performs a topological sort and executes the callbacks serially.
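The sketch below shows the simplest version of this idea without any packages, using Threads.@spawn instead of Dagger.jl. Callbacks are assumed to arrive already topologically sorted. Writers run synchronously, and read-only callbacks run in tasks. Outstanding tasks are awaited before the next writer runs, so a reader never sees half-updated state. CallbackSpec and run_callbacks! are hypothetical names, not the CallbackExecutor interface:

```julia
# Hypothetical description of a callback: name, whether it writes learner state, and its work.
struct CallbackSpec
    name::Symbol
    writes::Bool
    run::Function
end

# `specs` is assumed to be topologically sorted already.
function run_callbacks!(specs)
    pending = Task[]
    for cb in specs
        if cb.writes
            # Let running readers finish before mutating the state they may read.
            foreach(wait, pending)
            empty!(pending)
            cb.run()
        else
            push!(pending, Threads.@spawn cb.run())
        end
    end
    return pending   # the caller decides when to wait, e.g. at the end of the epoch
end

# Usage: a slow "network logger" runs in the background after the writers.
events = Symbol[]
lk = ReentrantLock()
record(x) = lock(() -> push!(events, x), lk)
specs = [CallbackSpec(:togpu, true, () -> record(:togpu)),
         CallbackSpec(:metrics, true, () -> record(:metrics)),
         CallbackSpec(:remote_logger, false, () -> (sleep(0.05); record(:remote_logger)))]
pending = run_callbacks!(specs)
foreach(wait, pending)
```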

Metrics wraps Metric(s) in Metric(s)

Sorry for the confusing title.

Perhaps easiest to just point to the place in the code:

struct Metrics <: Callback
    metrics::Tuple
    function Metrics(metrics...)
        return new(Tuple(m isa AbstractMetric ? m : Metric(m) for m in (Loss(), metrics...)))
    end
end

The problem seems to be that Metric is not an AbstractMetric. If one follows the example from the docstring (metrics = Metrics(Metric(Flux.mse, device = gpu), Metric(Flux.mae, device = gpu))), things fail with an error that some Metric is not callable: each Metric gets wrapped a second time, so its metricfn is itself a Metric.
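Here is a minimal, package-free reproduction of the double wrapping, using hypothetical stand-in types. If the wrapper type does not subtype the abstract type that the constructor checks, an already-wrapped metric gets wrapped again. Subtyping it, presumably Metric <: AbstractMetric here, leaves the metric alone:

```julia
abstract type AbstractMetricSketch end

# Like Metric in this issue: does NOT subtype the abstract type.
struct UnsubtypedMetric
    metricfn
end

# Fixed variant: subtypes the abstract type.
struct SubtypedMetric <: AbstractMetricSketch
    metricfn
end

# Mirrors the check in Metrics: wrap anything that isn't already an AbstractMetricSketch.
wrapmetric(m, Wrapper) = m isa AbstractMetricSketch ? m : Wrapper(m)

double = wrapmetric(UnsubtypedMetric(sum), UnsubtypedMetric)   # wrapped twice
single = wrapmetric(SubtypedMetric(sum), SubtypedMetric)       # left as-is
```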

Cannot recover model saved with Checkpointer()

Cheers,

Checkpointer() saves the entire trained model with BSON. For several models, I am able to recover them with BSON.@load modeladdress model, but in one case I cannot recover the saved checkpoint; the error message is shown below. Any hints on how to recover the saved data?

In any case, it seems advisable to change the callback from saving the model object itself to saving only its state, e.g. the output of Flux.state(model).

Thanks.

ERROR: type CodeInfo has no field pure
Stacktrace:
  [1] getproperty(ci::Core.CodeInfo, s::Symbol)
    @ Base ./deprecated.jl:326
  [2] 
    @ BSON ~/.julia/packages/BSON/DOYqe/src/anonymous.jl:58
  [3] newstruct_raw(cache::IdDict{Any, Any}, T::Type, d::Dict{Symbol, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/extensions.jl:169
  [4] (::BSON.var"#49#50")(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/extensions.jl:184
  [5] raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:92
  [6] (::BSON.var"#23#24"{IdDict{Any, Any}, Module})(x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:98
  [7] applychildren!(f::BSON.var"#23#24"{IdDict{Any, Any}, Module}, x::Vector{Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/BSON.jl:26
  [8] raise_recursive
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:98 [inlined]
--- the last 3 lines are repeated 1 more time ---
 [12] newstruct_raw(cache::IdDict{Any, Any}, ::Type{Core.TypeName}, d::Dict{Symbol, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/anonymous.jl:146
 [13] (::BSON.var"#49#50")(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/extensions.jl:184
 [14] raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:92
 [15] (::BSON.var"#18#21"{IdDict{Any, Any}, Module})(x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:82
 [16] applychildren!(f::BSON.var"#18#21"{IdDict{Any, Any}, Module}, x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/BSON.jl:19
 [17] _raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:82
 [18] raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:93
 [19] (::BSON.var"#23#24"{IdDict{Any, Any}, Module})(x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:98
 [20] applychildren!(f::BSON.var"#23#24"{IdDict{Any, Any}, Module}, x::Vector{Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/BSON.jl:26
 [21] raise_recursive
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:98 [inlined]
 [22] (::BSON.var"#17#20"{IdDict{Any, Any}, Module})(x::Vector{Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:80
 [23] applychildren!(f::BSON.var"#17#20"{IdDict{Any, Any}, Module}, x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/BSON.jl:19
 [24] _raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:80
--- the last 7 lines are repeated 3 more times ---
 [46] raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:93
 [47] (::BSON.var"#49#50")(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/extensions.jl:182
 [48] raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:92
 [49] (::BSON.var"#23#24"{IdDict{Any, Any}, Module})(x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:98
 [50] applychildren!(f::BSON.var"#23#24"{IdDict{Any, Any}, Module}, x::Vector{Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/BSON.jl:26
 [51] raise_recursive
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:98 [inlined]
 [52] (::BSON.var"#18#21"{IdDict{Any, Any}, Module})(x::Vector{Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:82
 [53] applychildren!(f::BSON.var"#18#21"{IdDict{Any, Any}, Module}, x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/BSON.jl:19
 [54] _raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:82
 [55] (::BSON.var"#49#50")(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/extensions.jl:183
--- the last 8 lines are repeated 2 more times ---
 [72] raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:92
 [73] (::BSON.var"#19#22"{IdDict{Any, Any}, Module})(x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:86
 [74] applychildren!(f::BSON.var"#19#22"{IdDict{Any, Any}, Module}, x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/BSON.jl:19
 [75] _raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:86
 [76] raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:93
 [77] raise_recursive
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:103 [inlined]
 [78] load
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:108 [inlined]
 [79] load(x::String)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:108
 [80] macro expansion
    @ ~/.julia/packages/BSON/DOYqe/src/BSON.jl:50 [inlined]

`Recorder` does not work with models with non-`Array` inputs.

Since Recorder tries to count the number of samples separately from the number of steps, it fails to determine the batch size if the model input xs is not an array.

Possible solutions:

  1. remove the sample-counting
  2. add batchsize detection cases for other input types like Tuples and Dicts

I prefer option 1, since the feature is not currently used and it is not possible to exhaustively implement extra cases for batch-size detection. Recorder should focus on counting the steps.
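For comparison, option 2 might look like the dispatch-based sketch below, where batchsize_of is a hypothetical name. It assumes observations are stored along the last array dimension. The catch-all method shows why this approach can never be exhaustive:

```julia
# Hypothetical batch-size detection by input type (option 2 above).
batchsize_of(x::AbstractArray) = size(x, ndims(x))             # observations along last dim
batchsize_of(x::Tuple) = batchsize_of(first(x))                 # assumes all entries agree
batchsize_of(x::NamedTuple) = batchsize_of(first(values(x)))
batchsize_of(x::AbstractDict) = batchsize_of(first(values(x)))
batchsize_of(x) = nothing                                        # unknown input type: give up
```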

Allow restricting phases during which a `Metric` runs

This would allow running expensive metrics only during validation phases.

Currently, you can pass functions or a Metric to the Metrics callback:

cb = Metrics(
    accuracy,
    Metric(Flux.mse),
    Metric(expensivemetricfn)
)

This feature would add a phase argument to Metric:

cb = Metrics(
    accuracy,
    Metric(Flux.mse),
    Metric(expensivemetricfn, phase = AbstractValidationPhase)
)

In case expensivemetricfn takes a long time to evaluate, this saves time during training while giving important information during the (usually much shorter) validation phases.
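Here is a package-free sketch of how such a phase argument could filter metrics. The phase types are stand-ins for FluxTraining's hierarchy. PhasedMetric and evaluate_metrics are hypothetical names, and maxerr plays the role of the expensive metric:

```julia
# Stand-ins for FluxTraining's phase hierarchy (hypothetical, for illustration).
abstract type Phase end
abstract type AbstractTrainingPhase <: Phase end
abstract type AbstractValidationPhase <: Phase end
struct TrainingPhase <: AbstractTrainingPhase end
struct ValidationPhase <: AbstractValidationPhase end

# Hypothetical metric wrapper with a phase restriction; the default runs in every phase.
struct PhasedMetric{F}
    metricfn::F
    phase::Type{<:Phase}
end
PhasedMetric(metricfn; phase::Type{<:Phase} = Phase) = PhasedMetric(metricfn, phase)

# Evaluate only the metrics whose phase restriction matches the current phase.
function evaluate_metrics(metrics, phase::Phase, yhat, y)
    return [m.metricfn(yhat, y) for m in metrics if phase isa m.phase]
end

mse(yhat, y) = sum(abs2, yhat .- y) / length(y)
maxerr(yhat, y) = maximum(abs, yhat .- y)   # pretend this one is expensive
metrics = [PhasedMetric(mse), PhasedMetric(maxerr; phase = AbstractValidationPhase)]
```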

Error displaying EarlyStopper

Show method for EarlyStopping throws an error:

julia> using FluxTraining

julia> FluxTraining.EarlyStopping(1)
Error showing value of type EarlyStopping:
ERROR: type EarlyStopping has no field stopper
Stacktrace:
  [1] getproperty(x::EarlyStopping, f::Symbol)
    @ Base ./Base.jl:33
  [2] show(io::IOContext{Base.TTY}, cb::EarlyStopping)
    @ FluxTraining ~/.julia/packages/FluxTraining/LfCE3/src/callbacks/earlystopping.jl:56
  [3] show(io::IOContext{Base.TTY}, #unused#::MIME{Symbol("text/plain")}, x::EarlyStopping)
    @ Base.Multimedia ./multimedia.jl:47
  [4] (::REPL.var"#38#39"{REPL.REPLDisplay{REPL.LineEditREPL}, MIME{Symbol("text/plain")}, Base.RefValue{Any}})(io::Any)
    @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:220
  [5] with_repl_linfo(f::Any, repl::REPL.LineEditREPL)
    @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:462
  [6] display(d::REPL.REPLDisplay, mime::MIME{Symbol("text/plain")}, x::Any)
    @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:213
  [7] display(d::REPL.REPLDisplay, x::Any)
    @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:225
  [8] display(x::Any)
    @ Base.Multimedia ./multimedia.jl:328
  [9] (::Media.var"#15#16"{EarlyStopping})()
    @ Media ~/.julia/packages/Media/ItEPc/src/compat.jl:28
 [10] hookless(f::Media.var"#15#16"{EarlyStopping})
    @ Media ~/.julia/packages/Media/ItEPc/src/compat.jl:14
 [11] render(#unused#::Media.NoDisplay, x::EarlyStopping)
    @ Media ~/.julia/packages/Media/ItEPc/src/compat.jl:27
 [12] render(x::EarlyStopping)
    @ Media ~/.julia/packages/Media/ItEPc/src/system.jl:160
 [13] display(#unused#::Media.DisplayHook, x::EarlyStopping)
    @ Media ~/.julia/packages/Media/ItEPc/src/compat.jl:9
 [14] display(x::Any)
    @ Base.Multimedia ./multimedia.jl:328
 [15] #invokelatest#2
    @ ./essentials.jl:708 [inlined]
 [16] invokelatest
    @ ./essentials.jl:706 [inlined]
 [17] print_response(errio::IO, response::Any, show_value::Bool, have_color::Bool, specialdisplay::Union{Nothing, AbstractDisplay})
    @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:247
 [18] (::REPL.var"#40#41"{REPL.LineEditREPL, Pair{Any, Bool}, Bool, Bool})(io::Any)
    @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:231
 [19] with_repl_linfo(f::Any, repl::REPL.LineEditREPL)
    @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:462
 [20] print_response(repl::REPL.AbstractREPL, response::Any, show_value::Bool, have_color::Bool)
    @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:229
 [21] (::REPL.var"#do_respond#61"{Bool, Bool, REPL.var"#72#82"{REPL.LineEditREPL, REPL.REPLHistoryProvider}, REPL.LineEditREPL, REPL.LineEdit.Prompt})(s::REPL.LineEdit.MIState, buf::Any, ok::Bool)
    @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:798
 [22] #invokelatest#2
    @ ./essentials.jl:708 [inlined]
 [23] invokelatest
    @ ./essentials.jl:706 [inlined]
 [24] run_interface(terminal::REPL.Terminals.TextTerminal, m::REPL.LineEdit.ModalInterface, s::REPL.LineEdit.MIState)
    @ REPL.LineEdit /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/LineEdit.jl:2441
 [25] run_frontend(repl::REPL.LineEditREPL, backend::REPL.REPLBackendRef)
    @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:1126
 [26] (::REPL.var"#44#49"{REPL.LineEditREPL, REPL.REPLBackendRef})()
    @ REPL ./task.jl:411

Thanks for the awesome package! Should I open a PR to fix this?
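For whoever picks this up: the trace shows show reading a field (stopper) that the struct does not have. Below is a package-free sketch of a show method that only touches fields the struct actually has. EarlyStoppingSketch, its criteria field, and Patience are hypothetical stand-ins, since the real field names aren't shown here:

```julia
# Hypothetical stand-ins; FluxTraining's real field names may differ.
struct Patience
    n::Int
end

struct EarlyStoppingSketch
    criteria::Vector{Any}
end

# Only read fields that exist on the struct (the reported bug accesses a missing field).
function Base.show(io::IO, cb::EarlyStoppingSketch)
    print(io, "EarlyStopping(")
    join(io, cb.criteria, ", ")
    print(io, ")")
end

shown = sprint(show, EarlyStoppingSketch(Any[Patience(1)]))
```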

Record time trained, training loss, validation loss and performance

For my application, I would love to be able to record the time trained, training loss, validation loss and classification performance at given time intervals in the training loop. However, History currently seems to store only the number of epochs, steps, and steps in the current epoch.

Would there be a way to make History extendable, so that users can record anything they want?

A final detail: I want to record these stats only after the training time has grown by a constant factor, so that when I plot e.g. training loss against a logarithmic time scale, I get roughly evenly spaced points. I am not sure how to make that happen, and I do not expect it to be built-in functionality; I am just mentioning it in case it would be simple enough to implement.
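The factor-based recording can be sketched without any packages. LogSpacedRecorder and maybe_record! are hypothetical names, and the loop below uses step numbers and fake losses in place of wall-clock time and real metrics. The threshold is multiplied by factor after each snapshot, so snapshots land at roughly equal spacing on a log time axis:

```julia
# Hypothetical recorder: one snapshot each time elapsed time has grown by `factor`.
const Snapshot = NamedTuple{(:time, :trainloss, :validloss), Tuple{Float64, Float64, Float64}}

mutable struct LogSpacedRecorder
    factor::Float64          # assumed > 1, otherwise the threshold never advances
    next_time::Float64       # assumed > 0
    records::Vector{Snapshot}
end
LogSpacedRecorder(; factor = 2.0, first_time = 1.0) =
    LogSpacedRecorder(factor, first_time, Snapshot[])

# Returns true if a snapshot was stored for time `t`.
function maybe_record!(r::LogSpacedRecorder, t, trainloss, validloss)
    t < r.next_time && return false
    push!(r.records, (time = Float64(t), trainloss = Float64(trainloss),
                      validloss = Float64(validloss)))
    # Move the threshold geometrically past the current time.
    while r.next_time <= t
        r.next_time *= r.factor
    end
    return true
end

# In practice `t` could be time() - start, with losses from the training/validation phases.
rec = LogSpacedRecorder()
for t in 1:100
    maybe_record!(rec, t, 1 / t, 2 / t)
end
```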

Quickstart tutorial broken

The example Training an image classifier currently uses the following code:

xs, ys = (
    # convert each image into h*w*1 array of floats 
    [Float32.(reshape(img, 28, 28, 1)) for img in Flux.Data.MNIST.images()],
    # one-hot encode the labels
    [Float32.(Flux.onehot(y, 0:9)) for y in Flux.Data.MNIST.labels()],
)

However,

(Project) pkg> st Flux
      Status `C:\Users\Dennis Bal\ProjectFolder\Project.toml`
  [587475ba] Flux v0.13.0

julia> using Flux

julia> Flux.Data.MNIST
ERROR: UndefVarError: MNIST not defined
Stacktrace:
 [1] getproperty(x::Module, f::Symbol)
   @ Base .\Base.jl:35
 [2] top-level scope
   @ REPL[16]:1

So the example is broken. As a side note, I think the example would benefit from using MLUtils instead of DataLoaders.jl and MLDataPattern. Also, Flux already exports DataLoader, so there is no need to import it explicitly.

Still, I took a look at the docs and tried to get started, writing the following code, which works with Flux's built-in training functionality:

julia> using Flux

julia> using Flux: onehotbatch, onecold

julia> using FluxTraining

julia> using MLUtils: flatten, unsqueeze

julia> using MLDatasets

julia> labels = 0:9
0:9

julia> traindata = MNIST.traindata(Float32) |> x->(unsqueeze(x[1], 3), onehotbatch(x[2], labels));

julia> size.(traindata)
((28, 28, 1, 60000), (10, 60000))

julia> trainloader = DataLoader(traindata, batchsize=128);

julia> validdata = MNIST.testdata(Float32) |> x->(unsqueeze(x[1], 3), onehotbatch(x[2], labels)); 

julia> size.(validdata)
((28, 28, 1, 10000), (10, 10000))

julia> validloader = DataLoader(validdata, batchsize=128);

julia> predict = Chain(flatten, Dense(28^2, 10))
Chain(
  MLUtils.flatten,
  Dense(784 => 10),                     # 7_850 parameters
)

julia> lossfunc(x, y) = Flux.Losses.logitbinarycrossentropy(predict(x), y)
lossfunc (generic function with 1 method)

julia> optimizer=ADAM()
ADAM(0.001, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}())

julia> callbacks = [Metrics(accuracy)]
1-element Vector{Metrics}:
 Metrics(Loss(), Metric(Accuracy))

julia> learner = Learner(predict, lossfunc; optimizer, callbacks)
Learner()

At this point, I start checking loss and training with Flux's train!:

julia> lossfunc(validdata...)
0.7624986f0

julia> Flux.train!(lossfunc, Flux.params(predict), trainloader, optimizer)

julia> lossfunc(validdata...)
0.11266354f0

julia> Flux.train!(lossfunc, Flux.params(predict), trainloader, optimizer)

julia> lossfunc(validdata...)
0.08880948f0

julia> Flux.train!(lossfunc, Flux.params(predict), trainloader, optimizer)

julia> lossfunc(validdata...)
0.0801171f0

Training works without problems. However, when I try to train my learner, it seems that a single float, not an array, is passed to predict:

julia> fit!(learner, 1, (traindata, validdata))
Epoch 1 TrainingPhase() ...
ERROR: MethodError: no method matching flatten(::Float32)
Closest candidates are:
  flatten(::AbstractArray) at C:\Users\usrname\.julia\packages\MLUtils\QTRw7\src\utils.jl:424  
Stacktrace:
  [1] macro expansion
    @ C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:0 [inlined]     
  [2] _pullback(ctx::Zygote.Context, f::typeof(flatten), args::Float32)
    @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:9        
  [3] macro expansion
    @ C:\Users\usrname\.julia\packages\Flux\18YZE\src\layers\basic.jl:53 [inlined]
  [4] _pullback
    @ C:\Users\usrname\.julia\packages\Flux\18YZE\src\layers\basic.jl:53 [inlined]
  [5] _pullback(::Zygote.Context, ::typeof(Flux.applychain), ::Tuple{typeof(flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, ::Float32)
    @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:0        
  [6] _pullback
    @ C:\Users\usrname\.julia\packages\Flux\18YZE\src\layers\basic.jl:51 [inlined]
  [7] _pullback(ctx::Zygote.Context, f::Chain{Tuple{typeof(flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, args::Float32)
    @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:0        
  [8] _pullback
    @ C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:54 [inlined]
  [9] _pullback(ctx::Zygote.Context, f::FluxTraining.var"#70#72"{FluxTraining.var"#handlefn#78"{Learner, TrainingPhase}, FluxTraining.PropDict{Any}, Learner}, args::Chain{Tuple{typeof(flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
    @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:0        
 [10] _pullback
    @ C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:70 [inlined]
 [11] _pullback(::Zygote.Context, ::FluxTraining.var"#73#74"{FluxTraining.var"#70#72"{FluxTraining.var"#handlefn#78"{Learner, TrainingPhase}, FluxTraining.PropDict{Any}, Learner}, Chain{Tuple{typeof(flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}})
    @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:0        
 [12] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface.jl:352       
 [13] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface.jl:75        
 [14] _gradient(f::FluxTraining.var"#70#72"{FluxTraining.var"#handlefn#78"{Learner, TrainingPhase}, FluxTraining.PropDict{Any}, Learner}, #unused#::ADAM, m::Chain{Tuple{typeof(flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:70      
 [15] (::FluxTraining.var"#69#71"{Learner})(handle::FluxTraining.var"#handlefn#78"{Learner, TrainingPhase}, state::FluxTraining.PropDict{Any})
    @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:53      
 [16] runstep(stepfn::FluxTraining.var"#69#71"{Learner}, learner::Learner, phase::TrainingPhase, initialstate::NamedTuple{(:xs, :ys), Tuple{Float32, Float32}})
    @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:133     
 [17] step!
    @ C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:51 [inlined]
 [18] (::FluxTraining.var"#67#68"{Learner, TrainingPhase, Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}})(#unused#::Function)
    @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:24      
 [19] runepoch(epochfn::FluxTraining.var"#67#68"{Learner, TrainingPhase, Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}}, learner::Learner, phase::TrainingPhase)     
    @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:105     
 [20] epoch!
    @ C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:22 [inlined]
 [21] fit!(learner::Learner, nepochs::Int64, ::Tuple{Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}, Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}})
    @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:168     
 [22] top-level scope
    @ REPL[51]:1

I am completely stuck as to what is going wrong. Pointers would be appreciated, but the main issue is making the example functional again, and updating the packages it uses to load data and the utility functions I take from MLUtils.

To improve the reliability of this package, could doctests be used to ensure that the documentation examples actually run in the future?
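Frames [16] and [18] of the trace hint at the cause. The data passed to `epoch!` has type `Tuple{Array{Float32, 4}, Flux.OneHotArray{…}}`, while the step's initial state is `(xs::Float32, ys::Float32)`. That pattern matches a loop that iterates over the `(X, Y)` tuple itself. Each raw array is then treated as one batch and destructured element by element.

The plain-Julia sketch below reproduces the effect. It uses hypothetical stand-in arrays and a simplified loop, not FluxTraining's code. Presumably the fix is to pass an iterator of batches instead, for example `[traindata]` or an `MLUtils.DataLoader`.

```julia
# Hypothetical stand-ins for the tutorial's data (shapes are illustrative only).
X = rand(Float32, 28, 28, 1, 8)           # images
Y = Float32.(rand(Bool, 10, 8))           # labels (stand-in for a one-hot matrix)
traindata = (X, Y)

# Simplified training loop: every element of `dataiter` is treated as one batch
# and destructured into (xs, ys), as the trace suggests happens per step.
function firstbatch_types(dataiter)
    for batch in dataiter
        xs, ys = batch                    # destructuring goes through `iterate`
        return (typeof(xs), typeof(ys))
    end
end

firstbatch_types(traindata)     # (Float32, Float32): the 4-d array itself became the "batch"
firstbatch_types([traindata])   # (Array{Float32, 4}, Matrix{Float32}): one batch of full arrays
```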

Use Optimisers.jl

With Flux.jl 0.13 moving to use the explicit optimisers in Optimisers.jl, I think FluxTraining.jl should also use those as a default.

This would also allow easier integration with alternative ADs like PyCallChainRules.jl; see rejuvyesh/PyCallChainRules.jl#19.

@ToucheSir can this be done in a backward-compatible way, i.e. supporting Flux v0.12 and below or does Optimisers.jl depend on Flux v0.13?

Collaborating on a FastAI port?

Hi @lorenzoh,

A group of us interested/working with DL in Julia have been trying to get a community initiative off the ground for creating a fast.ai (v2) equivalent in the Flux ecosystem. The initial plan is to port the tutorials in the fastai v2 book so that folks new to the community are able to get up-and-running quickly. This would also help to identify bugs and pain points in Flux to pass on to the core team.

Given how a) actually usable and b) close to the fast.ai Python API FluxTraining and co. are, I was wondering if this package wouldn't be a great foundation for such an initiative! I'll also note that we've connected with and received support from the Flux core team, so this may also be an opportunity to influence some development upstream :)

If you're interested, most of the discussion is happening in
https://julialang.zulipchat.com/#narrow/stream/237432-ml-ecosystem-coordination. We also have a list of ideas for general areas of improvement in Julia ML/DL at https://github.com/JuliaCommunity/ML-Coordination-Tracker.

Thanks for your work on these packages as well! I've been trying out DataLoaders.jl with great success so far.

How not to have printing callbacks?

Is there a way of constructing a Learner without certain callbacks?

julia> Learner(predict, lossfn; callbacks = [Metrics(accuracy)]).callbacks
FluxTraining.Callbacks(FluxTraining.SafeCallback[Metrics(Loss(), Metric(Accuracy)), ProgressPrinter(), MetricsPrinter(), StopOnNaNLoss(), Recorder()], FluxTraining.LinearRunner(), {5, 5} directed simple Int64 graph, false)

julia> Learner(predict, lossfn).callbacks
FluxTraining.Callbacks(FluxTraining.SafeCallback[ProgressPrinter(), MetricsPrinter(), StopOnNaNLoss(), Recorder(), Metrics(Loss())], FluxTraining.LinearRunner(), {5, 5} directed simple Int64 graph, false)

Or, failing that, is there a way to remove callbacks after construction?

CUDA memory leak for Flux.Optimizer

(This issue has been moved here from FluxML/Flux.jl#2261)

I have a somewhat complicated training setup and have recently started encountering CUDA-out-of-memory issues which only show up after a number of epochs.

I have managed to construct a minimum working example here:

using Flux
using FastAI
using MLUtils
using FastAI.FluxTraining

function main()
    DEVICE = gpu
    model = Chain(Dense(32*32*3=>2048), Dense(2048=>6), Dense(6=>32*32*3))

    make_data_sample_test(i) = (rand(Float32, 32*32*3),
                                rand(Float32, 32*32*3))
    data = mapobs(make_data_sample_test, 1:1024)
    dl     = DataLoader(data; batchsize=32, collate=true)
    dl_val = DataLoader(data; batchsize=32, collate=true)  # same synthetic data, used for validation

    loss = Flux.Losses.logitbinarycrossentropy
    opt = Flux.Adam(3e-4)
    learner = FastAI.Learner(model, loss;
                             optimizer=opt,
                             data=(dl, dl_val),
                             callbacks=[FluxTraining.ToGPU(), ])

    for _ in 1:5
      FluxTraining.epoch!(learner, FluxTraining.TrainingPhase())
      @show length(opt.state)
    end
end

main()

After about 50 epochs (~1 minute on my laptop), CUDA fails to allocate any more memory.
This seems to be because the optimizer's `state` accumulates GPU arrays over time.

The issue can be fixed by replacing opt = Flux.Adam() with opt = Optimisers.Adam(). However, I think we should fix the problem for the Flux optimizer, since it seems to be "officially" supported.

@DrChainsaw has suggested in the other issue that the problem is that the ToDevice callback is not applied to the optimizer parameters. However, I haven't looked into the specifics or how one would fix that. Any insights?
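The sketch below models the suspected mechanism in plain Julia. It uses a toy optimiser that, like Flux's legacy `Adam`, stores its moment buffers in an `IdDict` keyed by the parameter arrays themselves. Suppose a device-transfer step produces fresh parameter arrays each epoch. The old keys are then never looked up again, but they are never evicted either, so the state grows by one entry per parameter per epoch.

`ToyAdam`, `toy_update!` and `simulate` are hypothetical names. This is a simplified model, not Flux's implementation.

```julia
# Toy optimiser whose state is keyed by parameter *identity*, mimicking the
# IdDict used by Flux's legacy (implicit-parameter) optimisers. Not Flux code.
struct ToyAdam
    eta::Float32
    state::IdDict{Any,Any}
end
ToyAdam(eta = 3f-4) = ToyAdam(eta, IdDict{Any,Any}())

function toy_update!(opt::ToyAdam, p::AbstractArray, g::AbstractArray)
    # Fetch (or lazily create) the moment buffers for this exact array object.
    m, v = get!(opt.state, p) do
        (zero(p), zero(p))
    end
    eta = opt.eta
    @. m = 0.9f0 * m + 0.1f0 * g
    @. v = 0.999f0 * v + 0.001f0 * g^2
    @. p -= eta * m / (sqrt(v) + 1f-8)
    return p
end

# Run `nepochs` epochs and record how many entries the optimiser state holds.
# With `recreate = true`, every parameter is copied into a new array at the start
# of each epoch, standing in for a transfer that allocates fresh device arrays.
function simulate(nepochs; recreate = true)
    opt = ToyAdam()
    params = [rand(Float32, 4, 4), rand(Float32, 4)]
    counts = Int[]
    for _ in 1:nepochs
        if recreate
            params = map(copy, params)
        end
        for p in params
            toy_update!(opt, p, ones(Float32, size(p)))
        end
        push!(counts, length(opt.state))
    end
    return counts
end

simulate(5)                     # [2, 4, 6, 8, 10]: stale entries pile up
simulate(5; recreate = false)   # [2, 2, 2, 2, 2]: same arrays, constant state
```

Explicit-state optimisers such as those in Optimisers.jl keep state alongside the model rather than in an identity-keyed table, which would be consistent with the workaround above.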

Switch to ParameterSchedulers.jl

ParameterSchedulers.jl started in response to the limitations of Animations.jl for hyper-parameter scheduling. The API for ParameterSchedulers.jl has started to stabilize, so maybe we can consider swapping the backend?

The following snippet is some code that worked for a previous version of FluxTraining.jl. Perhaps an enterprising user can adapt it into a complete PR? If not, I'm filing this so that I remember to come back when I have the time.

"""
    Scheduler(schedules...)

Callback for hyperparameter scheduling.
Takes pairs of hyperparameter types and schedules from ParameterSchedulers.jl.

## Example
```julia
lrschedule = Exp(0.1, 0.5)
scheduler = Scheduler(
    LearningRate => lrschedule
)
```
"""
mutable struct Scheduler <: Callback
    schedules::Dict{Type{<:HyperParameter}, ParameterSchedulers.AbstractSchedule}
    step::Int
    Scheduler(args...; kwargs...) = new(Dict(args...; kwargs...), 1)
end

Base.show(io::IO, scheduler::Scheduler) =
    print(io, "Scheduler(", join(keys(scheduler.schedules), ", "), ")")

function FluxTraining.stateaccess(scheduler::Scheduler)
    # TODO: implement proper merging of permissions
    if length(keys(scheduler.schedules)) == 0
        hpstateaccess = (;)
    else
        hpstateaccess = merge(FluxTraining.stateaccess.(keys(scheduler.schedules))...)
    end
    return (data = Read(), cbstate = (; hyperparams = Write(), history = Read()),
            hpstateaccess...)
end


function FluxTraining.init!(scheduler::Scheduler, learner)
    if !haskey(learner.cbstate, :hyperparams)
        learner.cbstate.hyperparams = ValueHistories.MVHistory()
    end
    scheduler.step = 1

    return scheduler
end


function FluxTraining.on(::StepBegin, phase::AbstractTrainingPhase, scheduler::Scheduler, learner)
    for (H, schedule) in scheduler.schedules
        value = schedule(scheduler.step)
        FluxTraining.sethyperparameter!(learner, H, value)
        push!(
            learner.cbstate.hyperparams,
            Symbol(H),
            learner.cbstate.history[phase].steps,
            value)
    end
    scheduler.step += 1
end
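The callback above only relies on a schedule being callable with an integer step. At each `StepBegin` it evaluates every schedule at the current step, sets the hyperparameter, logs the value and advances the counter.

Here is a self-contained sketch of that contract. `ToyExp`, `ToyScheduler` and `on_stepbegin!` are hypothetical names; `ToyExp` is a hand-written exponential decay with value λ·γ^(t-1) at step t, not ParameterSchedulers.jl's type. A plain `Dict` stands in for the learner's hyperparameters.

```julia
# Hypothetical stand-in for a schedule: a callable mapping a 1-based step to a value.
struct ToyExp
    λ::Float64   # initial value
    γ::Float64   # decay factor per step
end
(s::ToyExp)(t::Integer) = s.λ * s.γ^(t - 1)

# Minimal scheduler: holds schedules per hyperparameter name and a step counter.
mutable struct ToyScheduler
    schedules::Dict{Symbol,Any}
    step::Int
end
ToyScheduler(pairs::Pair...) = ToyScheduler(Dict{Symbol,Any}(pairs...), 1)

# What the callback does on StepBegin: evaluate, set, log, advance.
function on_stepbegin!(sched::ToyScheduler, hyperparams::Dict{Symbol,Float64}, history)
    for (name, schedule) in sched.schedules
        value = schedule(sched.step)
        hyperparams[name] = value
        push!(history, (name, sched.step, value))
    end
    sched.step += 1
    return hyperparams
end

sched = ToyScheduler(:learningrate => ToyExp(0.1, 0.5))
hyperparams = Dict(:learningrate => 0.0)
history = Tuple{Symbol,Int,Float64}[]
for _ in 1:3
    on_stepbegin!(sched, hyperparams, history)
end
# `history` now holds learning rates of 0.1, 0.05 and 0.025 for steps 1 to 3
```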

Read access to `Learner.model` disallowed, cannot override via `stateaccess()`

Expected Behavior

I can override default state access restrictions using the stateaccess function to read learner.model.weight within the context of FluxTraining.step!(metric::MyMetric) to do some computation during the validation phase.

What I want to do is write a custom metric that has access to learner.model.weight, e.g. something like:

function FluxTraining.step!(metric::MyMetric, learner, phase)
    if phase isa metric.P
        metric.last = l1_metric(learner.model.weight)
        OnlineStats.fit!(metric.statistic, metric.last)
    else
        metric.last = nothing
    end
end

Error

I can't get this to work and I am unsure if it is a mistake on my end or if this is a bug.

Package info

[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FluxTraining = "7bf95e4d-ca32-48da-9824-f0dc5310474f"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"

System info

Pop! OS 22.04 LTS x86_64 with Julia version 1.9.0

Stacktrace

julia> include("mwe.jl")
main (generic function with 1 method)

julia> main()
Epoch 1 TrainingPhase() ...
┌─────────────────┬───────┬─────────┐
│           Phase │ Epoch │    Loss │
├─────────────────┼───────┼─────────┤
│ TrainingPhase() │   1.0 │ 0.23645 │
└─────────────────┴───────┴─────────┘
Epoch 1 ValidationPhase() ...
ERROR: FluxTraining.ProtectedException("Read access to Learner.model disallowed.")
Stacktrace:
  [1] getfieldperm(data::Learner, field::Symbol, perm::Nothing)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/protect.jl:63
  [2] getproperty(protected::FluxTraining.Protected{Learner}, field::Symbol)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/protect.jl:18
  [3] step!(metric::MyMetric{Number}, learner::FluxTraining.Protected{Learner}, phase::ValidationPhase)
    @ Main ~/code/TinnitusStimulusFitter.jl/scripts/stimuli_modeling/mwe.jl:50
  [4] on(#unused#::FluxTraining.Events.StepEnd, phase::ValidationPhase, metrics::Metrics, learner::FluxTraining.Protected{Learner})
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/metrics.jl:74
  [5] _on(e::FluxTraining.Events.StepEnd, p::ValidationPhase, cb::Metrics, learner::Learner)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/callback.jl:254
  [6] handle(runner::FluxTraining.LinearRunner, event::FluxTraining.Events.StepEnd, phase::ValidationPhase, learner::Learner)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/execution.jl:12
  [7] (::FluxTraining.var"#handlefn#82"{Learner, ValidationPhase})(e::FluxTraining.Events.StepEnd)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:129
  [8] runstep(stepfn::FluxTraining.var"#79#80"{Learner}, learner::Learner, phase::ValidationPhase, initialstate::NamedTuple{(:xs, :ys), Tuple{Matrix{Float32}, Matrix{Float32}}})
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:134
  [9] step!(learner::Learner, phase::ValidationPhase, batch::Tuple{Matrix{Float32}, Matrix{Float32}})
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:84
 [10] (::FluxTraining.var"#71#72"{Learner, ValidationPhase, MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random._GLOBAL_RNG, Val{nothing}}})(#unused#::Function)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:24
 [11] runepoch(epochfn::FluxTraining.var"#71#72"{Learner, ValidationPhase, MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random._GLOBAL_RNG, Val{nothing}}}, learner::Learner, phase::ValidationPhase)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:105
 [12] epoch!
    @ ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:22 [inlined]
 [13] main()
    @ Main ~/code/TinnitusStimulusFitter.jl/scripts/stimuli_modeling/mwe.jl:100
 [14] top-level scope
    @ REPL[3]:1

Minimum Working Example

using Flux
using FluxTraining
using OnlineStats
using LinearAlgebra

"""
Custom type for my metric.
This is basically a duplicate of the standard Metric.
"""
mutable struct MyMetric{T} <: FluxTraining.AbstractMetric
    statistic::OnlineStats.OnlineStat{T}
    _statistic::Any
    name::Any
    device::Any
    P::Any
    last::Union{Nothing, T}
end

"""
Outer constructor for MyMetric.
"""
function MyMetric(;
        statistic = OnlineStats.Mean(Float32),
        device = cpu,
        phase = ValidationPhase,
        name = "MyMetric")
    return MyMetric(statistic, deepcopy(statistic), name, device, phase, nothing)
end

"""
Reset MyMetric back to the initial value.
"""
function FluxTraining.reset!(metric::MyMetric{T}) where T
    metric.statistic = deepcopy(metric._statistic)
end

"""
We will use the L1 norm as an "example function"
that requires access to model weights.
"""
function l1_metric(W::Matrix)
    return norm(W, 1) / size(W, 1)
end

"""
Compute the metric by taking the L1 norm of the model weight matrix.
"""
function FluxTraining.step!(metric::MyMetric, learner, phase)
    if phase isa metric.P
        metric.last = l1_metric(learner.model.weight)
        OnlineStats.fit!(metric.statistic, metric.last)
    else
        metric.last = nothing
    end
end

function Base.show(io::IO, metric::MyMetric{T}) where {T}
    print(io, "Metric(", metric.name, ")")
end

FluxTraining.runafter(::MyMetric) = (FluxTraining.Recorder,)
FluxTraining.stepvalue(metric::MyMetric) = metric.last
FluxTraining.metricname(metric::MyMetric) = metric.name

function FluxTraining.epochvalue(metric::MyMetric)
    if isnothing(metric.last)
        nothing
    else
        OnlineStats.value(metric.statistic)
    end
end

function FluxTraining.stateaccess(::MyMetric)
    return (
        model = FluxTraining.Read(),
        params = FluxTraining.Read(),
        cbstate = (metricsstep = FluxTraining.Write(), metricsepoch = FluxTraining.Write(), history = FluxTraining.Read()),
        step = FluxTraining.Read(),
    )
end

function main()
    in_dim = 10
    out_dim = 1
    n_samples = 64
    model = Dense(in_dim => out_dim, identity; bias=false)

    X = rand(in_dim, n_samples) |> f32
    y = rand(out_dim, n_samples) |> f32
    train_dataloader = Flux.DataLoader((X, y))
    val_dataloader = deepcopy(train_dataloader)

    callbacks = [FluxTraining.Metrics(MyMetric())]
    opt_state = Flux.Adam(1f-4)

    learner = FluxTraining.Learner(model, Flux.mse; callbacks = callbacks, optimizer = opt_state)

    for i = 1:3
        FluxTraining.epoch!(learner, FluxTraining.TrainingPhase(), train_dataloader)
        FluxTraining.epoch!(learner, ValidationPhase(), val_dataloader)
    end
end
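The stack trace hints at why the override has no effect. Frame [5] shows `_on` being called for the `Metrics` callback, and frame [4] then receives a `FluxTraining.Protected{Learner}`. So the permissions in force are presumably those declared for `Metrics`, not those returned by `stateaccess(::MyMetric)`.

Below is a simplified plain-Julia model of such a permission wrapper, using hypothetical names (`ToyProtected`, `ToyLearner`, `can_read`); it is not FluxTraining's code. It shows that a wrapped object only exposes the fields allowed by whoever created the wrapper.

```julia
struct ToyProtectedException <: Exception
    msg::String
end

# Wraps an object and only allows reading the fields listed in `perms`.
struct ToyProtected{T}
    data::T
    perms::Tuple{Vararg{Symbol}}
end

function Base.getproperty(p::ToyProtected, field::Symbol)
    # `getfield` bypasses this method, so the wrapper can reach its own fields.
    field in getfield(p, :perms) ||
        throw(ToyProtectedException("Read access to $field disallowed."))
    return getproperty(getfield(p, :data), field)
end

mutable struct ToyLearner
    model::Any
    step::Any
end

# True if reading `field` through the wrapper succeeds.
function can_read(p::ToyProtected, field::Symbol)
    try
        getproperty(p, field)
        return true
    catch e
        e isa ToyProtectedException || rethrow()
        return false
    end
end

learner = ToyLearner((weight = [1.0 2.0; 3.0 4.0],), (loss = 0.5,))

# Permissions chosen by the outer, metrics-like callback: `step` only.
outer = ToyProtected(learner, (:step,))
# What the inner metric would need if its own declaration were honoured.
inner_wish = ToyProtected(learner, (:step, :model))
```

In this model, declaring wider permissions for the inner object changes nothing. Only the permissions of the wrapper that is actually passed in matter: `can_read(outer, :model)` is `false`.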

`Scheduler` causes cycle in execution DAG?

I have the following script:

lossfn = Flux.Losses.logitcrossentropy

# define schedule and optimizer
initial_lr = 0.1
schedule = Step(initial_lr, 0.5, 20)
optim = Flux.Optimiser(Momentum(initial_lr), WeightDecay(1e-3))

# callbacks
logger = TensorBoardBackend("tblogs")
schcb = Scheduler(LearningRate => schedule)
hlogcb = LogHyperParams(logger)
mlogcb = LogMetrics(logger)
valcb = Metrics(Metric(accuracy; phase = TrainingPhase, name = "train_acc"),
                Metric(accuracy; phase = ValidationPhase, name = "val_acc"))

# setup learner object
learner = Learner(m, lossfn;
                  data = (trainloader, valloader),
                  optimizer = optim,
                  callbacks = [ToGPU(), mlogcb, valcb])

Any time I add schcb to the list of callbacks passed to the Learner, I get an error from FluxTraining that there is a cycle in the DAG. This did not happen in previous versions of FluxTraining (though I haven't been able to bisect the change yet).

`ignore(f)` is deprecated

perms = Zygote.ignore() do

┌ Warning: `ignore(f)` is deprecated, use `ChainRulesCore.ignore_derivatives(f)` instead.
│   caller = _on(e::FluxTraining.Events.BackwardEnd, p::TrainingPhase, cb::ProgressPrinter, learner::Learner) at callback.jl:251
└ @ FluxTraining ~/.julia/packages/FluxTraining/iBFSd/src/callbacks/callback.jl:251

Training loop profiler

Using Events as hooks into the training loop, it's possible to create a profiler for training loops that measures not only the time spent executing events but also the time spent between them, i.e. in the training loop itself.

This would allow more easily identifying possible performance bottlenecks, like:

  • waiting on the data iterator (identifiable as time spent between the end of StepEnd and StepBegin)
  • moving data to the GPU during the step (does it matter, or is it just as fast as doing it asynchronously in the background?)

Thoughts on implementation | This could be implemented as a callback, though you would need two callbacks, one running before all other callbacks and one after them (to measure callback times), which is unwieldy. This solution may also not play well with the asynchronous callback scheduler proposed in #85.
In my opinion, the better solution is to implement a callback execution context that does the timing before and after it runs the callbacks. It would wrap another callback execution context, and thus would also play nicely with the asynchronous callback scheduler, as it would measure only the time spent on the synchronous part.

Interpretation | Events that specify start and stop points like StepBegin and StepEnd could be treated as a layer in the profiling stack. Possibly an existing package for visualizing flamegraphs could be reused to make sense of the profiling data.
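Here is a minimal sketch of the execution-context idea in plain Julia. It uses hypothetical types (`InnerRunner`, `TimingRunner`) and a hypothetical `handle` function, not FluxTraining's actual runner API.

The timing runner wraps another runner and records two things:
  • how long the wrapped runner spends on each event
  • the gap between the end of one event and the start of the next, which is time spent in the training loop itself

```julia
# Plain runner: calls every handler for an event, in order.
struct InnerRunner
    handlers::Vector{Function}
end
handle(r::InnerRunner, event::Symbol) = foreach(h -> h(event), r.handlers)

# Timing wrapper around another runner.
mutable struct TimingRunner{R}
    inner::R
    incallbacks::Dict{Symbol,Vector{Float64}}  # seconds spent handling each event type
    between::Vector{Float64}                   # seconds spent in the loop between events
    lastend::UInt64                            # time_ns() when the previous event finished
end
TimingRunner(inner) = TimingRunner(inner, Dict{Symbol,Vector{Float64}}(), Float64[], UInt64(0))

function handle(r::TimingRunner, event::Symbol)
    t0 = time_ns()
    # Time since the previous event finished = time spent in the training loop.
    r.lastend == 0 || push!(r.between, (t0 - r.lastend) / 1e9)
    handle(r.inner, event)
    t1 = time_ns()
    push!(get!(r.incallbacks, event, Float64[]), (t1 - t0) / 1e9)
    r.lastend = t1
    return nothing
end

runner = TimingRunner(InnerRunner(Function[e -> nothing]))
for _ in 1:3
    handle(runner, :StepBegin)
    sleep(0.01)                # stand-in for the forward/backward pass
    handle(runner, :StepEnd)
end
```

In this example, the gaps that follow `StepBegin` contain the simulated step work. The gaps after `StepEnd` are where waiting on a data iterator would show up.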

Unnecessary softmax in accuracy?

Just stumbled on this as I was poking around:

function accuracy(y_pred, y)
    return mean(onecold(cpu(softmax(y_pred))) .== onecold(cpu(y)))
end

That softmax should not be needed: as far as I know, onecold only does an argmax plus some plumbing under the hood, and softmax does not change which entry is largest.
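Here is a quick plain-Julia check of the argument. Softmax exponentiates each entry of a column and divides by the same positive sum, so it preserves the order of the entries. It therefore cannot change the column-wise argmax, which is what `onecold` returns for a matrix when the labels are the default `1:size(y, 1)`.

The helpers `softmax_cols` and `argmax_cols` are hand-written stand-ins for NNlib's `softmax` and Flux's `onecold`. This keeps the example dependent only on the `Statistics` standard library.

```julia
using Statistics: mean

# Column-wise softmax (shifted by the column max for numerical stability).
function softmax_cols(x::AbstractMatrix)
    e = exp.(x .- maximum(x; dims = 1))
    return e ./ sum(e; dims = 1)
end

# Column-wise argmax: the core of `onecold` with default labels.
argmax_cols(x::AbstractMatrix) = [argmax(col) for col in eachcol(x)]

acc_with_softmax(ŷ, y) = mean(argmax_cols(softmax_cols(ŷ)) .== argmax_cols(y))
acc_without_softmax(ŷ, y) = mean(argmax_cols(ŷ) .== argmax_cols(y))

logits = randn(Float32, 10, 64)            # model outputs: 10 classes, 64 samples
labels = zeros(Float32, 10, 64)            # one-hot targets
for (j, k) in enumerate(rand(1:10, 64))
    labels[k, j] = 1
end

acc_with_softmax(logits, labels) == acc_without_softmax(logits, labels)
```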

docs oddities

At the very top of this doc page https://fluxml.ai/FluxTraining.jl/dev/i/?id=documents%2Fdocs%2Fcallbacks%2Fusage.md

using FluxTraining
using FluxTraining: Callback, Read, Write, stateaccess
model, data, lossfn = nothing, (), nothing, nothing

Is that intended?

Also, if from that page I follow a couple of links, I land on
https://fluxml.ai/FluxTraining.jl/dev/i/?id=documents%2FREADME.md&id=documents%2Fdocs%2Fcallbacks%2Fcustom.md&id=documents%2Fdocs%2Ftutorials%2Ftraining.md
but then I don't see any buttons for closing all those panes or for going back. Even the browser back button has no effect.

One last thing: when viewing the docstring of a type or method, I don't see a link to jump to its source code.

Callbacks Quality of Life improvements

Collecting some nice-to-have changes to callbacks. Some of these are breaking so they should be part of the next minor release.

  • CustomCallback: Pass event and phase through to the wrapped function. Not having access to this information is unnecessarily limiting. BREAKING
  • ToDevice: Change default behavior so that every array that is in state at StepBegin is moved. Add an option to only move specific keys. This is a good default when implementing custom training Phases, as the current implementation assumes there are always xs and ys (this is from before the new training loop API was introduced). BREAKING
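Here is a minimal sketch of the proposed `ToDevice` default. It assumes the step state is a `NamedTuple` (FluxTraining's actual state container may differ) and uses a hypothetical `movestate` helper. Every array is passed through the device function, and an optional `only` argument restricts the move to specific keys.

```julia
# Move every array in `state` with `todevice`; if `only` is given, move just those keys.
function movestate(todevice, state::NamedTuple; only = nothing)
    names = keys(state)
    moved = map(names, values(state)) do name, value
        if value isa AbstractArray && (only === nothing || name in only)
            todevice(value)
        else
            value  # non-arrays, and arrays outside `only`, are left untouched
        end
    end
    return NamedTuple{names}(moved)
end

# Hypothetical "device" transfer: converts to Float32 so the effect is visible on the CPU.
todevice(x) = Float32.(x)

state = (xs = [1.0, 2.0], ys = [3.0], step = 5)
movestate(todevice, state)                  # xs and ys become Vector{Float32}; step stays 5
movestate(todevice, state; only = (:xs,))   # only xs is converted
```

Moving by type rather than by fixed field names is what lets custom phases add their own state keys without having to touch the callback.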
