
LegolasFlux

CI codecov

Note: Upgrading from LegolasFlux v0.1.x to v0.2?

The only change is an update to Legolas v0.5. Be sure to check out the guidance for updating Legolas to v0.5 along with the rest of Legolas's documentation and tour.


LegolasFlux provides simple functionality to use Legolas.jl's extensible Arrow schemas as a means to serialize Flux models, similar to using Flux's params and loadparams! (instead, we export the analogous functions fetch_weights and load_weights!, which handle layers like BatchNorm correctly for this purpose).

The aim is to serialize only the numeric weights, not the code defining the model. This is a very different approach from e.g. BSON.jl, and hopefully much more robust. Note that in this package, we use weights to refer to the numeric arrays that are modified over the course of training a model; that includes biases as well as means and variances in e.g. BatchNorms (but not e.g. configuration settings).

With this approach, however, if you change the code such that the weights are no longer valid (e.g. add a layer), you will not be able to load back the same model.

Usage

using Flux

function make_my_model()
    return Chain(Dense(1, 10), Dense(10, 10), Dense(10, 1))
end

my_model = make_my_model()
# train it? that part is optional ;)

# Now, let's save it!
using LegolasFlux

model_row = LegolasFlux.ModelV1(; weights = fetch_weights(cpu(my_model)),
                                architecture_version=1)
write_model_row("my_model.model.arrow", model_row)

# Great! Later on, we want to re-load our model weights.
fresh_model = make_my_model()

model_row = read_model_row("my_model.model.arrow")
load_weights!(fresh_model, model_row.weights)
# Now our weights have been loaded back into `fresh_model`.

We can make use of the architecture_version column to specify a version number for the architecture, in order to keep track of which architectures the weights are valid for.
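For instance, loading could be guarded with a small helper like the following sketch (the constant and function names here are illustrative, not part of LegolasFlux):

```julia
# Hypothetical guard (illustrative names, not part of LegolasFlux): refuse
# weights whose recorded architecture_version does not match what this code
# currently expects.
const EXPECTED_ARCHITECTURE_VERSION = 1

function check_architecture_version(row)
    v = row.architecture_version
    v === missing && return true  # no version recorded; accept and let the caller decide
    v == EXPECTED_ARCHITECTURE_VERSION ||
        error("weights saved for architecture v$(v); expected v$(EXPECTED_ARCHITECTURE_VERSION)")
    return true
end
```

One would call check_architecture_version(model_row) before invoking load_weights!.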

See examples/digits.jl for a larger example, which also saves out extra metadata with the model, by using a Legolas schema extension.

LegolasFlux.ModelV1

A LegolasFlux.ModelV1 is the central object of LegolasFlux. It acts as a Tables.jl-compatible row that stores the weights of a Flux model in the weights column, along with an optional architecture_version (which defaults to missing).

ModelV1 is not exported because downstream packages will likely want to define their own rows extending the schema provided by LegolasFlux, and those rows might end up with similar names. See the next section for more on extensibility.

Extensibility

As a Legolas.jl schema, legolas-flux.model is meant to be extended. For example, let's say I had an MNIST classification model that I call Digits. I am very committed to reproducibility, so I store the commit_sha of my model's repo with every training run, and I also wish to save the accuracy and epoch. I might create a DigitsRow schema which extends the legolas-flux.model schema:

using Legolas, LegolasFlux
using Legolas: @schema, @version
@schema "digits-model" DigitsRow
@version DigitsRowV1 > ModelV1 begin
    # re-declare this ModelV1 field as parametric for this schema as well
    weights::(<:Union{Missing,Weights})
    epoch::Union{Missing, Int}
    accuracy::Union{Missing, Float32}
    commit_sha::Union{Missing, String}
end

Now I can use a DigitsRowV1 much like LegolasFlux's ModelV1. It has the same required weights column and optional architecture_version column, as well as the additional epoch, accuracy, and commit_sha columns. As a naming convention, one might name files produced by this row as e.g. training_run.digits.model.arrow.

When writing out a DigitsRowV1, I'll pass the schema version like so

write_model_row(path, my_digits_row, DigitsRowV1SchemaVersion())

so that later, when I call read_model_row on this path, I'll get back a DigitsRowV1 instance.

Note that in this example the schema is called digits-model instead of just, say, digits, since the package Digits might want to create other Legolas schemas as well at some point.

Check out the Legolas.jl repo to see more about how its extensible schema system works, and the example at examples/digits.jl.


legolasflux.jl's Issues

Upgrade to Legolas 0.5

Legolas 0.5 changes how schemas are defined and used. There's a single schema defined here, ModelRow, which needs to be upgraded to use @schema/@version instead of @row.

`read_model_row` does not work for schema extensions of `ModelV1`

read_model_row always returns a ModelV1, which makes it pretty useless for anything that requires fields from a schema extension. In Legolas <0.5, this was fine because extra fields were preserved, but now they are discarded.

Instead, could LegolasFlux get the schema from the table metadata, use Legolas.record_type to get the appropriate record type, and try to construct that instead?

Save optimizer state?

This may not need any code changes, but it would be great to update the example to save out optimizer state to do resumable training. This might require waiting for https://github.com/FluxML/Optimisers.jl to mature a bit.

Or else we need to do some more complicated stuff to handle things like IdDicts in Flux.ADAM. E.g. we could map objectids to something like UUIDs, then save those out next to the arrays; our weights object could hold a Vector{Tuple{UUID, FlatArray{T}}} instead of a Vector{FlatArray{T}}. Then we could write code to map between IdDicts and another Vector{Tuple{UUID, FlatArray{T}}} (or maybe just Dict{UUID, FlatArray{T}}, if Arrow will support that easily) to save out the optimizer internals.

Or... maybe there's a simpler way. Right now, we have a Vector{FlatArray{T}} for the weights, and each array only shows up once (since our fcollect uses an IdSet cache). So that is already basically a mapping between index in that array and objectid of an array. So we could re-use that to also just have another vector of the same length, where each entry is the optimizer state for each array in the weights.
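In pure Julia, the index-as-identifier idea sketched above might look like this (plain vectors stand in for FlatArrays; all names are illustrative):

```julia
# Each weight array appears exactly once, so its position in the weights
# vector is already a stable identifier; optimizer state can live in a
# parallel vector of the same length.
weights = [rand(3), rand(2)]
opt_state = [zeros(3), zeros(2)]  # one state entry per weight array, same order

# Look up an array's slot by identity, then fetch its optimizer state.
index_of = Dict(objectid(w) => i for (i, w) in enumerate(weights))
state_for(w) = opt_state[index_of[objectid(w)]]
```

Serializing opt_state alongside weights would then keep the association implicitly, by position.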

[documentation] example extension does not extend ModelV1, and requires parameter for weights to serialize

The DigitsRowV1 example does not extend ModelV1. When you do add the extension naively with > ModelV1, it does not serialize properly, because ModelV1 has a type parameter for weights that has to be manually carried over:

@schema "digits-model" DigitsRow
@version DigitsRowV1 > ModelV1 begin
    epoch::Union{Missing, Int}
    accuracy::Union{Missing, Float32}
    commit_sha::Union{Missing, String}
end

gives

julia> model_row = DigitsRowV1(; weights=missing)
DigitsRowV1: (weights = missing, architecture_version = missing, epoch = missing, accuracy = missing, commit_sha = missing)

julia> write_model_row(model_path, model_row)
ERROR: MethodError: no method matching zero(::Type{Union{Missing, Weights}})
Closest candidates are:
  zero(::Type{Union{Missing, T}}) where T at missing.jl:105
  zero(::Union{Type{P}, P}) where P<:Dates.Period at ~/.julia/juliaup/julia-1.8.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.8/Dates/src/periods.jl:53
  zero(::StatsBase.Histogram{T, N, E}) where {T, N, E} at ~/.julia/packages/StatsBase/XgjIN/src/hist.jl:562                                  
  ...
Stacktrace:
  [1] default(T::Type)
    @ ArrowTypes ~/.julia/packages/ArrowTypes/E6ePy/src/ArrowTypes.jl:303
  [2] ArrowTypes.ToArrow(x::Vector{Union{Missing, Weights}})
    @ ArrowTypes ~/.julia/packages/ArrowTypes/E6ePy/src/ArrowTypes.jl:353

but when you declare the type parameter on the extension, it works:

julia> @version DigitsRowV2 > ModelV1 begin
           weights::(<:Union{Missing,Weights})
           epoch::Union{Missing, Int}
           accuracy::Union{Missing, Float32}
           commit_sha::Union{Missing, String}
       end

julia> model_row = DigitsRowV2(; weights=missing)
DigitsRowV2{Missing}: (weights = missing, architecture_version = missing, epoch = missing, accuracy = missing, commit_sha = missing)                                     

julia> write_model_row(model_path, model_row)
1-element Vector{DigitsRowV2{Missing}}:
 DigitsRowV2{Missing}: (weights = missing, architecture_version = missing, epoch = missing, accuracy = missing, commit_sha = missing)

TagBot trigger issue

This issue is used to trigger TagBot; feel free to unsubscribe.

If you haven't already, you should update your TagBot.yml to include issue comment triggers.
Please see this post on Discourse for instructions and more details.

If you'd like for me to do this for you, comment TagBot fix on this issue.
I'll open a PR within a few hours, please be patient!

Nicer `show` methods

We should add nicer show methods, e.g.

Base.show(io::IO, v::LegolasFlux.Weights) = print(io, typeof(v), "(…$(length(v)) weights…)")

already makes it much cleaner when printing a row that includes weights.
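The same pattern can be tried out on a stand-in wrapper type (hypothetical, for illustration only):

```julia
# A toy wrapper standing in for LegolasFlux.Weights, to illustrate the
# show-method pattern above without depending on the package.
struct ToyWeights
    v::Vector{Float64}
end
Base.length(w::ToyWeights) = length(w.v)
Base.show(io::IO, w::ToyWeights) = print(io, typeof(w), "(…", length(w), " weights…)")
```

Printing a ToyWeights then shows only the weight count rather than every element.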

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.