Comments (10)

MikeInnes commented on May 22, 2024

This works on current-release Julia (I'll remove any need for a source build this week, though there'll still be a flag for type-system abuse).

(v1.0) pkg> add Zygote#master IRTools#master

julia> using Zygote, BenchmarkTools

julia> function logsumexp(x::Array{Float64,1})
         A = maximum(x)           # shift by the max for numerical stability
         ema = exp.(x .- A)
         sema = sum(ema)
         log(sema) + A
       end
logsumexp (generic function with 1 method)

julia> const x = rand(100);

julia> @btime logsumexp(x)
  1.515 μs (1 allocation: 896 bytes)
5.119135839482861

julia> @btime Zygote.gradient(logsumexp, x);
  4.882 μs (35 allocations: 7.61 KiB)

Almost all of the time is spent in broadcast; I imagine there's more we could do to be cleverer here, though it's hard to see how we'd match the hand-written version without special knowledge.
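
To sanity-check that claim, one can time the forward broadcast kernel against the full gradient call (a hypothetical measurement sketch, reusing the x and logsumexp defined above):

using BenchmarkTools, Zygote

A = maximum(x)
@btime exp.($x .- $A);                   # the broadcast that dominates
@btime Zygote.gradient(logsumexp, $x);   # the full reverse-mode gradient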

jekbradbury commented on May 22, 2024

Zygote requires (edit: the README now simply says it works faster on) Julia built from source with Mike's patches; I have a Docker image with such a build here: docker pull gcr.io/personal-204923/julia:zygote

awf commented on May 22, 2024

Just to confirm, in case it's useful: this does happen on James's image too.

MikeInnes commented on May 22, 2024

Ideally we shouldn't be crashing Julia's codegen here, as opposed to just throwing an error, but this is probably easily fixed by adding the right gradient definition (of which we have very few right now; * and + with scalars is enough to check that the chain rule is working :)). I'll take a look shortly.
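
For reference, a minimal chain-rule check of the kind described (a sketch, not part of the package's test suite): with only scalar * and + gradients defined, Zygote should recover d/dx (3x + 2) = 3.

using Zygote

f(x) = 3x + 2
Zygote.gradient(f, 5.0)   # expected: (3.0,)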

baggepinnen commented on May 22, 2024

I'm experiencing a segfault with the following code on Julia v0.7 / Zygote v0.1. I can't really tell whether it's related to this issue:

using ForwardDiff, Zygote
using LinearAlgebra: norm   # `norm` is needed below

w = randn(3,3)
x = randn(3)
f(w,x) = w*x
function loss(w,x)
    ∇xf(x) = ForwardDiff.jacobian(x -> f(w,x), x)
    # the loss contains nested differentiation of the model w.r.t. the input
    w -> sum(abs2.(w)) + norm(∇xf(x))
end

l = loss(w,x)
g = Zygote.gradient(l,w)

ngphuoc commented on May 22, 2024

I modified the example in the README to add a loss function, and it crashed with a segfault too:

using Zygote: @grad, Params, gradient

W, b = rand(2, 3), rand(2);

predict(x) = W*x .+ b;

loss(x,y) = sum(abs2, y .- predict(x))

g = gradient(() -> sum(loss([1,2,3], [1, 1])), Params([W, b]))

g[W], g[b]

MikeInnes commented on May 22, 2024

I added a simple gradient definition for maximum, which fixes the original example. I really need to bring over all the definitions that Flux has.

I also just realised that our broadcast gradient is not type-stable, so this benchmark comes out a bit slower than it needs to be.
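
For context, the adjoint of maximum routes the output cotangent to the position of the maximum element. A minimal sketch of such a definition, written with the @adjoint macro from later Zygote versions (the imports earlier in this thread still show the older @grad name), and not necessarily the definition that was actually added:

using Zygote

Zygote.@adjoint function maximum(x::AbstractVector)
    m, i = findmax(x)
    return m, Δ -> begin
        dx = zero(x)   # gradient is zero everywhere...
        dx[i] = Δ      # ...except at the argmax
        (dx,)
    end
end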

chriselrod commented on May 22, 2024

The overhead for small problems still seems high (so my StaticArrays tests aren't fast yet), but...

julia> @btime logsumexp_both(x);
  960.263 ns (3 allocations: 1.78 KiB)

julia> @btime Zygote.gradient(logsumexp, x);
  2.737 μs (35 allocations: 7.61 KiB)

julia> @btime ReverseDiff.gradient!($results, $compiled_tape, $inputs);
  15.900 μs (0 allocations: 0 bytes)

julia> @btime ForwardDiff.gradient!($res, logsumexp, x, $cfg);
  25.158 μs (10 allocations: 87.50 KiB)

This was on regular 1.0.

This is awesome to see, because it is (a) significantly faster, and (b) Zygote didn't require me to redefine logsumexp to be generic.
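
Point (b) is worth spelling out: ForwardDiff and ReverseDiff push Dual/tracked number types through the function, so the ::Array{Float64,1} method above would have to be loosened for them, roughly like this sketch (logsumexp_generic is a hypothetical name):

function logsumexp_generic(x::AbstractVector{<:Real})
    A = maximum(x)
    ema = exp.(x .- A)
    log(sum(ema)) + A
end

Zygote, by contrast, differentiates the Float64-only method as written.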

MikeInnes commented on May 22, 2024

Nice!

I did notice one of the earlier StaticArrays examples not inferring properly. Feel free to open a new issue for anything like that.
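
A hypothetical way to check inference on such an example (Test.@inferred throws if the call's return type can't be inferred):

using Zygote, StaticArrays, Test

sx = @SVector rand(3)
@inferred Zygote.gradient(sum, sx)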

awf commented on May 22, 2024

Very nice. Note that my initial code wasn't quite comparing apples to apples: logsumexp_both computes the function value too. For timing, we should compare Zygote.gradient to:

# Gradient of logsumexp alone: ∇logsumexp(x) = softmax(x)
function grad_logsumexp(x::Array{Float64,1})
  A = maximum(x)
  ema = exp.(x .- A)
  sema = sum(ema)
  # the value itself, log(sema) + A, is deliberately not computed here
  return ema ./ sema
end
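
For completeness, logsumexp_both is not shown in the thread; it presumably returns both the value and the gradient, along these lines (a reconstruction, not the author's exact code):

function logsumexp_both(x::Array{Float64,1})
  A = maximum(x)
  ema = exp.(x .- A)
  sema = sum(ema)
  return log(sema) + A, ema ./ sema   # (value, gradient)
end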
