Hi! Thanks for this interesting work! I just tried the front page example and it turned out not to work for me. Taking the gradient fails with:
julia> gs = gradient(p -> sum(Lux.apply(model, x, p, st)[1]), ps)[1]
ERROR: Compiling Tuple{NNlibCUDA.var"##cudnnBNForward!#87", Nothing, Float32, Float32, Float32, Bool, Bool, Bool, typeof(NNlibCUDA.cudnnBNForward!), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32}: try/catch is not supported.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] instrument(ir::IRTools.Inner.IR)
@ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/reverse.jl:121
[3] #Primal#19
@ ~/.julia/packages/Zygote/Y6SC4/src/compiler/reverse.jl:202 [inlined]
[4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
@ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/reverse.jl:315
[5] _generate_pullback_via_decomposition(T::Type)
@ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/emit.jl:101
[6] #s3043#1206
@ ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:28 [inlined]
[7] var"#s3043#1206"(::Any, ctx::Any, f::Any, args::Any)
@ Zygote ./none:0
[8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
@ Core ./boot.jl:580
[9] _pullback
@ ~/.julia/packages/NNlibCUDA/i1IW9/src/cudnn/batchnorm.jl:48 [inlined]
[10] _pullback(::Zygote.Context, ::NNlibCUDA.var"#cudnnBNForward!##kw", ::NamedTuple{(:eps, :training), Tuple{Float32, Bool}}, ::typeof(NNlibCUDA.cudnnBNForward!), ::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Float32)
@ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
[11] _pullback (repeats 2 times)
@ ~/.julia/packages/NNlibCUDA/i1IW9/src/cudnn/batchnorm.jl:37 [inlined]
[12] _pullback
@ ~/.julia/packages/NNlibCUDA/i1IW9/src/cudnn/batchnorm.jl:31 [inlined]
[13] _pullback(::Zygote.Context, ::NNlibCUDA.var"##batchnorm#85", ::Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol}, NamedTuple{(:eps, :training), Tuple{Float32, Bool}}}, ::typeof(NNlibCUDA.batchnorm), ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Float32)
@ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
[14] _pullback
@ ~/.julia/packages/NNlibCUDA/i1IW9/src/cudnn/batchnorm.jl:30 [inlined]
[15] _pullback
@ ~/.julia/packages/Lux/HkXlk/src/layers/normalize.jl:114 [inlined]
[16] _pullback(::Zygote.Context, ::BatchNorm{true, true, typeof(identity), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, ::NamedTuple{(:μ, :σ², :training), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Bool}})
@ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
[17] macro expansion
@ ~/.julia/packages/Lux/HkXlk/src/layers/basic.jl:0 [inlined]
[18] _pullback
@ ~/.julia/packages/Lux/HkXlk/src/layers/basic.jl:330 [inlined]
[19] _pullback(::Zygote.Context, ::typeof(Lux.applychain), ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{BatchNorm{true, true, typeof(identity), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, BatchNorm{true, true, typeof(identity), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}, ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:μ, :σ², :training), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Bool}}, NamedTuple{(), Tuple{}}, NamedTuple{(:μ, :σ², :training), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Bool}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}})
@ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
[20] _pullback
@ ~/.julia/packages/Lux/HkXlk/src/layers/basic.jl:328 [inlined]
[21] _pullback(::Zygote.Context, ::Chain{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{BatchNorm{true, true, typeof(identity), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, BatchNorm{true, true, typeof(identity), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:μ, :σ², :training), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Bool}}, NamedTuple{(), Tuple{}}, NamedTuple{(:μ, :σ², :training), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Bool}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}})
@ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
[22] _pullback
@ ~/.julia/packages/Lux/HkXlk/src/core.jl:61 [inlined]
[23] _pullback(::Zygote.Context, ::typeof(Lux.apply), ::Chain{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{BatchNorm{true, true, typeof(identity), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, BatchNorm{true, true, typeof(identity), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:μ, :σ², :training), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Bool}}, NamedTuple{(), Tuple{}}, NamedTuple{(:μ, :σ², :training), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Bool}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}})
@ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
[24] _pullback
@ ./REPL[10]:1 [inlined]
[25] _pullback(ctx::Zygote.Context, f::var"#1#2", args::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}})
@ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
[26] _pullback(f::Function, args::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}})
@ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface.jl:34
[27] pullback(f::Function, args::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}})
@ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface.jl:40
[28] gradient(f::Function, args::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}})
@ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface.jl:75
[29] top-level scope
@ REPL[10]:1
[30] top-level scope
@ ~/.julia/packages/CUDA/qAl31/src/initialization.jl:52