

mcabbott commented on June 5, 2024

That's quite a bold heading!

According to Zygote, the following function has different gradients wrt the same inputs (mathematically speaking).

I think you mean the inputs are == yet the gradients are different.

This is true, for example:

julia> Zygote.gradient(x -> real(exp(x * (1+im))), pi/2)
(-4.810477380965351,)

julia> Zygote.gradient(x -> real(exp(x * (1+im))), pi/2 + 0im)
(-4.810477380965351 - 4.810477380965351im,)

julia> pi/2 == pi/2+0im
true

That is, Zygote (really ChainRules) regards the type of the input as specifying the domain of the function, and hence the appropriate cotangent space in which the gradient lives. The fact that x -> real(exp(x * (1+im))) uses complex numbers internally is ignored; when x::Real, it is viewed as an R -> R function.
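To make the domain rule concrete, here is a minimal sketch in plain Julia (no Zygote needed). It assumes the convention that for f(z) = real(g(z)) with g holomorphic, the unprojected gradient (df/da + i df/db, with z = a + ib) is conj(g'(z)). The helper `project_like` is hypothetical; it only imitates what ChainRulesCore's ProjectTo does for scalar cotangents.

```julia
# Hypothetical stand-in for scalar projection (not the ChainRulesCore API):
# the cotangent is made to live in the same space as the primal.
project_like(x::Real, dx::Number) = real(dx)        # domain R: drop the imaginary part
project_like(x::Complex, dx::Number) = complex(dx)  # domain C: keep everything

# f(z) = real(g(z)) with g holomorphic; unprojected gradient is conj(g'(z)).
g(z) = exp(z * (1 + im))
dg(z) = (1 + im) * exp(z * (1 + im))
raw_grad(z) = conj(dg(z))

x = pi / 2
grad_real = project_like(x, raw_grad(x))        # x::Float64, so the answer is real
grad_cplx = project_like(x + 0im, raw_grad(x))  # x::ComplexF64, so the answer is complex
```

The same raw cotangent is computed in both cases; only the type of the primal decides whether the imaginary part survives, which mirrors the two REPL results above.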

Before ChainRules 1.0, Zygote did not do this. It regarded all numbers as living in C, and all matrices as living in C^(N×M). I think this more or less fell out of how it works and Julia's type promotion rules. The fact that this function uses complex numbers internally would lead it to tell you to do gradient descent in a complex direction.

ChainRules applies such projections to almost every step. When a real number propagates forward through some complicated code for 10 steps and then gets promoted to complex for the 11th, the reverse pass projects its gradient back to real immediately, so that these 10 reverse steps involve only real numbers again.
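The sequence described above can be sketched with a hand-written reverse pass. This is an illustration, not Zygote's machinery: the function out = real(exp((2x) * (1 + im))) is a made-up example, and each pullback is written by hand assuming the convention that a holomorphic step multiplies the incoming cotangent by conj of its derivative.

```julia
# Hand-written forward and reverse pass for out = real(exp((2x) * (1 + im))).
# The promotion to complex happens at step 2, so the cotangent is projected
# back to real right there, and step 1 only ever sees real numbers.
function forward_and_back(x::Real)
    y1 = 2x               # step 1: real -> real
    y2 = y1 * (1 + im)    # step 2: real -> complex (promotion)
    y3 = exp(y2)          # step 3: complex -> complex
    out = real(y3)        # step 4: complex -> real

    dout = 1.0
    dy3 = complex(dout)             # pullback of real: embed the cotangent in C
    dy2 = conj(exp(y2)) * dy3       # pullback of exp
    dy1 = real(conj(1 + im) * dy2)  # pullback of *, then project to R since y1 is real
    dx = 2 * dy1                    # purely real arithmetic from here on
    return out, dx
end

out, dx = forward_and_back(0.3)
# Closed form: d/dx [e^(2x) cos(2x)] = 2 e^(2x) (cos(2x) - sin(2x))
exact = 2 * exp(0.6) * (cos(0.6) - sin(0.6))
```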

Similar projections apply to most structured matrix types (such as Diagonal), and also to sparse arrays.
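For sparse arrays, projection roughly means keeping only the cotangent entries at the primal's stored positions. The sketch below assumes that behaviour; `project_sparse` is a hypothetical helper written with the SparseArrays standard library, not the ChainRulesCore implementation.

```julia
using SparseArrays

# Hypothetical helper: restrict a dense cotangent to the stored-entry
# pattern of a sparse primal.
function project_sparse(x::SparseMatrixCSC, dx::AbstractMatrix)
    rows, cols, _ = findnz(x)
    vals = [dx[i, j] for (i, j) in zip(rows, cols)]
    return sparse(rows, cols, vals, size(x)...)
end

x = sparse([1, 2], [1, 2], [1.0, 2.0], 2, 2)  # 2x2 primal with two stored entries
dx = [1.0 2.0; 3.0 4.0]                       # dense cotangent
px = project_sparse(x, dx)                    # off-pattern entries are dropped
```

The Diagonal case is analogous: the projected cotangent would be Diagonal(diag(dx)).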

from zygote.jl.

mohamed82008 commented on June 5, 2024

Moving some of my discussion points over from Slack and polishing them a bit.

ChainRules projects the cotangent of each function input to match that input's type. This can result in loss of information when a dense (or otherwise structured but not sparse, e.g. rank-1) gradient is projected onto a sparse structure. An example is the dot rule here: https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/LinearAlgebra/dense.jl#L14. In my opinion, the projection should only happen on the cotangent of the function's output, so the rule for dot should be:

using LinearAlgebra: dot
import ChainRulesCore: rrule, @thunk, ProjectTo, unthunk, NoTangent
import Zygote

mydot(x, y) = dot(x, y)
function rrule(::typeof(mydot), x::AbstractArray, y::AbstractArray)
    out = dot(x, y)
    project_out = ProjectTo(out)  # project only the output's cotangent
    function dot_pullback(Ω̄)
        ΔΩ = project_out(unthunk(Ω̄))
        x̄ = @thunk(reshape(y .* ΔΩ', axes(x)))  # not projected onto typeof(x)
        ȳ = @thunk(reshape(x .* ΔΩ, axes(y)))   # not projected onto typeof(y)
        return (NoTangent(), x̄, ȳ)
    end
    return out, dot_pullback
end

instead of the current implementation:

function rrule(::typeof(dot), x::AbstractArray, y::AbstractArray)
    project_x = ProjectTo(x)
    project_y = ProjectTo(y)
    function dot_pullback(Ω̄)
        ΔΩ = unthunk(Ω̄)
        x̄ = @thunk(project_x(reshape(y .* ΔΩ', axes(x))))
        ȳ = @thunk(project_y(reshape(x .* ΔΩ, axes(y))))
        return (NoTangent(), x̄, ȳ)
    end
    return dot(x, y), dot_pullback
end

It seems that this was done on purpose to counter another issue, which I am going to call the "too much information" problem. For example, the suggested mydot rrule will return a complex gradient for the first argument if we differentiate dot([1,2.], [3+im, 4.]).

julia> Zygote.pullback(real ∘ mydot, [1,2], Complex.([1,2]))[2](1.0)
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im], ComplexF64[1.0 + 0.0im, 2.0 + 0.0im])

Is this wrong? Some people think so. I don't. The user is clearly mixing complex and real numbers, so it's on them either to project the gradient if they only care about the real part, or to explicitly declare that x is real-only using:

julia> f(x, y) = real(mydot(Complex.(x), y))
f (generic function with 1 method)

julia> Zygote.pullback(f, [1,2], Complex.([1,2]))[2](1.0)
([1.0, 2.0], ComplexF64[1.0 + 0.0im, 2.0 + 0.0im])

To me this is more faithful to the user's intentions, at the risk of returning too much information, which the user can easily discard by calling a projection themselves. So, to summarise, I think projecting the cotangent of a function's inputs leads to information loss and potentially incorrect gradients for some applications.
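A small, stdlib-only illustration of what "information loss" means in the dense-onto-structured case, assuming the ChainRules dot rule quoted above (gradient wrt the first argument is Y times the output cotangent). The matrices are made up for the example.

```julia
using LinearAlgebra

X = Diagonal([1.0, 2.0])
Y = [1.0 2.0; 3.0 4.0]

dX_dense = Y                  # unprojected gradient of dot(X, Y) wrt X (ΔΩ = 1)
dX_proj = Diagonal(diag(Y))   # the same gradient projected onto Diagonal

# Perturbations that stay diagonal: both gradients predict the same change.
V_diag = Diagonal([0.5, -1.0])
d_dense_diag = dot(dX_dense, V_diag)
d_proj_diag = dot(dX_proj, V_diag)

# A dense perturbation: only the unprojected gradient sees the off-diagonal part.
V_full = [0.0 1.0; 1.0 0.0]
d_dense_full = dot(dX_dense, V_full)
d_proj_full = dot(dX_proj, V_full)
```

Whether the dropped part matters depends on whether X is genuinely constrained to be diagonal or Diagonal is only a storage format, which is exactly the point under debate.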


mcabbott commented on June 5, 2024

In a long string of rules, projecting before every rule or after every rule will be broadly similar.

But what you seem to be arguing for is omitting projection at the very first rule. Won't this lead to all kinds of surprises? For example these two functions implement the same thing in slightly different ways... why should the user care if some library changes from one implementation to the other?

julia> Zygote.pullback(x -> real(mydot((3+4im) * x, x)), [1.0, 2.0])[2](1.0)
(ComplexF64[6.0 + 4.0im, 12.0 + 8.0im],)

julia> Zygote.pullback(x -> real((3-4im) * mydot(x, x)), [1.0, 2.0])[2](1.0)
([6.0, 12.0],)

(Using the lower-level pullback here, and above, avoids the projection that gradient applies to the final answer. That final projection is there in case some rules forget to project, e.g. certain non-ChainRules rules within Zygote.)

Edit: in fact this example is even stranger, because the rule for (3+4im) * x does apply projection. With complex input, there is no imaginary part:

julia> Zygote.gradient(x -> real(mydot((3+4im) * x, x)), [1.0, 2.0 .+ 0im])
(ComplexF64[6.0 + 0.0im, 12.0 + 0.0im],)
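The outputs above can be reproduced by hand, which may make the asymmetry easier to see. This sketch assumes the dot rule quoted earlier (first argument gets y .* ΔΩ', second gets x .* ΔΩ) with ΔΩ = 1, and assumes the * rule multiplies the cotangent by conj(c) and then projects to the type of x.

```julia
c = 3 + 4im
x = [1.0, 2.0]

# Path 1: cotangent x for the first argument of mydot, pulled back through
# c * x (multiply by conj(c)) and projected to real because x is real.
via_mul = real.(conj(c) .* x)
# Path 2: cotangent c .* x for the second argument of mydot, never projected.
via_dot = c .* x
mixed = via_mul .+ via_dot          # the first pullback result above

# With complex x, neither path is projected and the imaginary parts cancel.
all_complex = conj(c) .* x .+ c .* x
```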


mohamed82008 commented on June 5, 2024

I think the example above is an argument against projecting at all in the complex case, in that projecting complex numbers to real numbers loses information and can sometimes even violate the distributive property. So if there is a single complex number in the chain of calculations, it should just propagate backward into the gradient, unless the user explicitly calls Complex somewhere in the chain to declare that the input is only the real component. So my new suggestion is to not project complex numbers to real numbers at all.

Concretely, the rules would become:

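using LinearAlgebra: dot
import ChainRulesCore: rrule, @thunk, unthunk, NoTangent
import Zygote

mydot2(x, y) = dot(x, y)

# Same as the ChainRules dot rule, but with no projection at all.
function rrule(::typeof(mydot2), x::AbstractArray, y::AbstractArray)
    out = dot(x, y)
    function dot_pullback(Ω̄)
        ΔΩ = unthunk(Ω̄)
        x̄ = @thunk(reshape(y .* ΔΩ', axes(x)))
        ȳ = @thunk(reshape(x .* ΔΩ, axes(y)))
        return (NoTangent(), x̄, ȳ)
    end
    return out, dot_pullback
end

mymul(x, y) = x * y

# Scalar multiplication with no projection. The conj follows the ChainRules
# convention for complex inputs: x̄ = ΔΩ * conj(∂Ω/∂x).
function rrule(::typeof(mymul), x::Number, y::Number)
    out = mymul(x, y)
    function mul_pullback(Ω̄)
        ΔΩ = unthunk(Ω̄)
        x̄ = ΔΩ * conj(y)
        ȳ = conj(x) * ΔΩ
        return (NoTangent(), x̄, ȳ)
    end
    return out, mul_pullback
end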

julia> Zygote.pullback(x -> real(mymul((3+4im), mydot2(x, x))), [1.0, 2.0])[2](1.0)
# (ComplexF64[6.0 + 0.0im, 12.0 + 0.0im],)

julia> Zygote.pullback(x -> real(mymul((3+4im), Complex(mydot2(x, x)))), [1.0, 2.0])[2](1.0)
# ([6.0, 12.0],)

So basically, calling the constructor of a type is the primal of the projection operation; equivalently, the projection is the pullback of the constructor. This generalises nicely to matrices: if I want the gradient of x -> sum(Diagonal(x)) wrt x::Vector, projection happens naturally as the pullback of the constructor. This is distinctly different from the case where I want the gradient of sum wrt Diagonal(x), where the Diagonal type is just a compact representation of a matrix.
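The two cases can be sketched without any AD package. `diagonal_pullback` is a hypothetical name for the pullback of v -> Diagonal(v); the sketch assumes the cotangent of sum wrt its matrix argument is a matrix of ones.

```julia
using LinearAlgebra

# Pullback of the constructor v -> Diagonal(v): keep only the diagonal of the
# incoming matrix cotangent, the same operation as projecting onto Diagonal.
diagonal_pullback(dD::AbstractMatrix) = diag(dD)

v = [1.0, 2.0, 3.0]
D = Diagonal(v)

# Case 1: gradient of v -> sum(Diagonal(v)) wrt v::Vector.
dsum = ones(3, 3)                  # cotangent of sum wrt its (dense) matrix argument
grad_v = diagonal_pullback(dsum)

# Case 2: gradient of sum wrt D, viewing Diagonal only as compact storage of a
# 3x3 matrix: every entry, including the off-diagonal zeros, gets cotangent 1.
grad_D_dense = ones(size(D))
```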


mcabbott commented on June 5, 2024

projecting complex numbers to real numbers leads to loss of information that can sometimes even violate the distributive property

Yes, it loses information. I'm not sure what distributive property you mean.

But is passing in complex input, in order to specify that you consider the domain to be C rather than R, too much to ask?

It just seems extremely strange to me to want these two implementations of the same R -> R function to behave differently (as they did in earlier Zygote):

f1(x::Real) = real(exp(x * (1+im)))
f2(x::Real) = exp(x) * cos(x)
using Plots; plot(f1, -2, 2); plot!(f2, -2, 2)
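Since both functions equal e^x cos(x) on the reals, their derivatives should agree as well. A quick numerical check with central differences, using only Base Julia (the step size and sample grid are arbitrary choices):

```julia
f1(x::Real) = real(exp(x * (1 + im)))
f2(x::Real) = exp(x) * cos(x)
df_exact(x) = exp(x) * (cos(x) - sin(x))   # d/dx [e^x cos(x)]

central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

xs = range(-2, 2; length = 9)
max_gap = maximum(abs(central_diff(f1, x) - central_diff(f2, x)) for x in xs)
max_err = maximum(abs(central_diff(f1, x) - df_exact(x)) for x in xs)
```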

It should be mentioned that the other kind of standardisation done by the projection machinery is to map structural cotangents to natural ones where possible. Here, by default the gradient of x.re is a structural tangent (re = 1.0, im = nothing), but converting it to Complex(1.0, 0) makes generic rules work:

Zygote.gradient(x -> angle(x.re * x^2), 1+2im)  # (-0.8 + 0.4im,)
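A sketch of both ideas in plain Julia. `to_natural` is a hypothetical helper that represents the structural tangent as a NamedTuple, which is an assumption about representation rather than Zygote's internals. The finite-difference check assumes the gradient convention df/da + i df/db for z = a + ib.

```julia
# Hypothetical helper: structural tangent of a Complex (NamedTuple with
# `nothing` for an absent field) -> natural Complex tangent.
to_natural(t::NamedTuple{(:re, :im)}) = complex(something(t.re, 0.0), something(t.im, 0.0))

structural = (re = 1.0, im = nothing)   # tangent of z -> z.re
natural = to_natural(structural)

# Finite-difference check of the quoted gradient of z -> angle(z.re * z^2).
angle_fn(z) = angle(z.re * z^2)
function fd_grad(f, z; h = 1e-6)
    da = (f(z + h) - f(z - h)) / (2h)            # df/da
    db = (f(z + im * h) - f(z - im * h)) / (2h)  # df/db
    return complex(da, db)
end
fd = fd_grad(angle_fn, 1.0 + 2.0im)
```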


mohamed82008 commented on June 5, 2024

Projecting nothing and structural no tangents is fine. Your example returning different types is not surprising to me, because one function promotes to the complex domain and the other doesn't. I wonder whether anyone actually cares about these subtle cases of returning a complex derivative instead of a real one, when they can easily call real on the resulting gradient. The bigger issue is doing the opposite on purpose: returning a real gradient when the user wanted the full complex gradient, or returning a Diagonal gradient when the user wanted the gradient wrt the full matrix.
