
RomeoV commented on June 4, 2024

Any progress on this?
I'm running into a similar problem.

using Test, TensorCast
x = rand(3, 4)
@test repeat(x, 1, 2) == @cast _[h, (w, 2)] := x[h, w]
# or
@test repeat(x, 1, 2) == @cast _[h, (w, j)] := x[h, w] (j in 1:2)

(Both fail).
Perhaps the check for "only on the left" should be removed? I think the first version is even nicer, though, and it is also supported by Python's einops library.
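For reference, here is what both attempts are meant to compute, written in plain Julia without TensorCast. This is a sketch assuming TensorCast's convention that in a combined index like (w, j) the first index varies fastest (column-major), so column w + 4(j - 1) of the result is x[:, w]; the names y and z are illustrative only:

```julia
x = rand(3, 4)

# Broadcasting against zeros(1, 1, 2) adds a trailing dimension of length 2: y[h, w, j] == x[h, w]
y = x .+ zeros(1, 1, 2)

# Merge (w, j) into one column index, with w varying fastest (column-major order)
z = reshape(y, size(x, 1), size(x, 2) * 2)

z == repeat(x, 1, 2)  # true: the columns of x, tiled twice
```

The broadcast step is the part that supplies the size of j, which the right-hand side x[h, w] alone cannot do.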

from tensorcast.jl.

mcabbott commented on June 4, 2024

Sorry, I haven't yet looked into this. Not so hard, I hope.

RomeoV commented on June 4, 2024

Great. I'm comparing TensorCast.jl to the recent paper published on einops (https://openreview.net/pdf?id=oapKSVM2bcj).
If this is fixed, all examples in Listing 1 can be done with TensorCast.jl and I can provide a pull request with the appropriate tests / examples.

mcabbott commented on June 4, 2024

That sounds great.

Once or twice I've tried to borrow documentation from einops... I see there's an ancient notebook currently committed, but not linked from anywhere. It would be nice to do more. And more tests never hurt either.

Some investigation into this repeat problem:

julia> R = randn(3, 3*4);  # write in-place, easier:

julia> @cast R[r,(n,c)] = M[r,c]^2  # (n in 1:3)  # n can be inferred
ERROR: LoadError: index n appears only on the left
Stacktrace:
 [1] checkallseen ...
 [2] _macro(exone::Expr, extwo::Nothing, exthree::Nothing; call::TensorCast.CallInfo, dict::Dict{Any, Any})
   @ TensorCast ~/.julia/dev/TensorCast/src/macro.jl:199

# with the check on line 199 commented out, it runs:

julia> @cast R[r,(n,c)] = M[r,c]^2  # (n in 1:3)
3×3×4 Array{Float64, 3}:
[:, :, 1] =
 1.0  1.0  1.0
 4.0  4.0  4.0
...

julia> R  # the line above returns the wrong thing, the reshaped array rather than R itself:
3×12 Matrix{Float64}:
 1.0  1.0  1.0  16.0  16.0  16.0  49.0  49.0  49.0  100.0  100.0  100.0
 4.0  4.0  4.0  25.0  25.0  25.0  64.0  64.0  64.0  121.0  121.0  121.0
 9.0  9.0  9.0  36.0  36.0  36.0  81.0  81.0  81.0  144.0  144.0  144.0

# out-of-place again:

julia> @cast R[r,(n,c)] := M[r,c]^2  (n in 1:3)
ERROR: DimensionMismatch: new dimensions (3, 12) must be consistent with array size 12

julia> @pretty @cast R[r,(n,c)] := M[r,c]^2  (n in 1:3)
begin
    @boundscheck ndims(M) == 2 || throw(ArgumentError("expected a 2-tensor M[r, c]"))
    local (ax_c, ax_n, ax_r) = (axes(M, 2), OneTo(3), axes(M, 1))
    local spider = transmute(M, Val((1, nothing, 2)))
    R = reshape(@__dot__(spider ^ 2), (ax_r, star(ax_n, ax_c)))
end

julia> @cast R[r,(n,c)] := M[r,c]^2 + 0n  (n in 1:3)
3×12 Matrix{Int64}:
 1  1  1  16  16  16  49  49  49  100  100  100
 4  4  4  25  25  25  64  64  64  121  121  121
 9  9  9  36  36  36  81  81  81  144  144  144

To make the new dimension, the broadcast needs to include something like .+ 0 .* transpose(ax_n), which is what the + 0n does in the last expression. There must be, or have been, logic for this somewhere...
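The failure and the fix can be reproduced in plain Julia. This is a sketch, not TensorCast's own code: it assumes M = reshape(1:12, 3, 4), which is inferred from the printed results rather than shown above, and it uses reshape in place of transmute to insert the singleton n dimension. The names bad, full and R2 are hypothetical:

```julia
M = reshape(1:12, 3, 4)    # assumed input, inferred from the printed output
ax_n = 1:3                 # the range given by (n in 1:3)

# Stand-in for transmute(M, Val((1, nothing, 2))): a 3×1×4 array, spider[r, 1, c] == M[r, c]
spider = reshape(M, 3, 1, 4)

# What the generated code does: spider .^ 2 has only 12 elements, so a 3×12 reshape fails
bad = try
    reshape(spider .^ 2, 3, 12)
catch err
    err
end
bad isa DimensionMismatch   # true

# The fix: broadcasting against 0 .* transpose(ax_n) gives the n dimension length 3
full = spider .^ 2 .+ 0 .* transpose(ax_n)   # 3×3×4, full[r, n, c] == M[r, c]^2
R = reshape(full, 3, 12)                     # merge (n, c), with n varying fastest

# In-place analogue: the destination supplies the size of n, via a reshape sharing R2's memory
R2 = zeros(Int, 3, 12)
reshape(R2, 3, 3, 4) .= spider .^ 2

R == R2 == repeat(M .^ 2, inner = (1, 3))   # true
```

The in-place line also evaluates to the 3×3×4 reshaped array rather than R2, which matches the return value observed above when the check was disabled.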

According to git blame I added this to docs here between 0.4.0 and 0.4.1, but it doesn't run on either of those versions.

Ah now I found a branch: master...repeat

mcabbott commented on June 4, 2024

Note BTW that @cast _[h, (w, 2)] := x[h, w] will probably never work, but would ideally have a better error.

This is because @cast _[h] := x[h, 1] is already the first column, a constant index. Plausibly @cast _[h, j] := x[h, (j, 1)] (j in 1:2) should then extract some subset of the columns.
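A plain-Julia sketch of what that hypothetical x[h, (j, 1)] form would extract, assuming the same column-major convention: splitting the column index of x into (j, k) with j in 1:2 gives x[h, j + 2(k - 1)], and fixing k = 1 keeps the first two columns. This is why a literal number inside a combined index reads naturally as a fixed index rather than as a size. The names x3 and firstblock are illustrative:

```julia
x = rand(3, 4)

# Split the column index of x into (j, k), j fastest: x3[h, j, k] == x[h, j + 2(k - 1)]
x3 = reshape(x, size(x, 1), 2, :)

# Fixing k = 1, as in x[h, (j, 1)], keeps columns 1:2
firstblock = x3[:, :, 1]

firstblock == x[:, 1:2]   # true
```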

The way to specify sizes used to be something like @cast _[h, (w, j:2)] := x[h, w], or else j:2 after the expression, a notation chosen for compactness. But this was removed when everything was upgraded to allow offsets everywhere.
