owulveryck commented on June 19, 2024

For a start, a simple version should be implemented for axis = 1 only; an onnx.ErrNotImplemented error should be returned if the attribute has any other value.
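
A minimal sketch of that guard in Go. The `errNotImplemented` type and the `checkSoftmaxAxis` function are hypothetical stand-ins: onnx-go's actual `onnx.ErrNotImplemented` may carry different fields, and the real operator would read `axis` from the node's attributes rather than take it as a parameter.

```go
package main

import (
	"errors"
	"fmt"
)

// errNotImplemented is a hypothetical stand-in for onnx.ErrNotImplemented;
// the real onnx-go type may have different fields.
type errNotImplemented struct {
	Operator       string
	AttributeName  string
	AttributeValue interface{}
}

func (e *errNotImplemented) Error() string {
	return fmt.Sprintf("onnx: operator %s not implemented for attribute %s=%v",
		e.Operator, e.AttributeName, e.AttributeValue)
}

// checkSoftmaxAxis accepts only axis == 1 and rejects every other value
// with a not-implemented error.
func checkSoftmaxAxis(axis int64) error {
	if axis != 1 {
		return &errNotImplemented{Operator: "Softmax", AttributeName: "axis", AttributeValue: axis}
	}
	return nil
}

func main() {
	fmt.Println(checkSoftmaxAxis(1)) // <nil>

	err := checkSoftmaxAxis(0)
	var nie *errNotImplemented
	// errors.As lets a caller detect the "not implemented" case specifically.
	fmt.Println(errors.As(err, &nie), err)
}
```

Returning a typed error (rather than a plain string) lets a backend test runner tell "unsupported attribute" apart from a genuine failure.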

from onnx-go.

owulveryck commented on June 19, 2024

WIP in the branch softmax-issue-46

owulveryck commented on June 19, 2024

Instead of using Gorgonia's Softmax operator out of the box, its code could be copied and adapted.
This would allow implementing softmax for axis != 1.

Commit 7b45c9b partially implements softmax. The trivial test passes:

go test -run=ONNX/TestSoftmaxExample -v
=== RUN   TestONNX
=== RUN   TestONNX/TestSoftmaxExample
--- PASS: TestONNX (0.01s)
    --- PASS: TestONNX/TestSoftmaxExample (0.00s)
PASS

The other tests fail with this error:

test_structure.go:89: Node Σ[1](%1) :: Matrix float32, has 2 dimensions(Shape: ()). Input shape is (3, 1, 5), which has 3 dimensions

This is probably linked to what ONNX expects:

Input does not need to explicitly be a 2D vector; rather, it will be coerced into one. For an arbitrary n-dimensional tensor input \in [a_0, a_1, ..., a_{k-1}, a_k, ..., a_{n-1}] and k is the axis provided, then input will be coerced into a 2-dimensional tensor with dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. For the default case where axis=1, this means the input tensor will be coerced into a 2D tensor of dimensions [a_0, a_1 * ... * a_{n-1}], where a_0 is often the batch size. In this situation, we must have a_0 = N and a_1 * ... * a_{n-1} = D. Each of these dimensions must be matched correctly, or else the operator will throw errors.
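
The coercion rule above is a product of dimensions on each side of `axis`. A minimal sketch, applied to the failing test's `(3, 1, 5)` input; the function name `coerce2D` is hypothetical and negative axes are not handled:

```go
package main

import (
	"errors"
	"fmt"
)

// coerce2D returns the 2-D shape [a_0*...*a_{k-1}, a_k*...*a_{n-1}] that the
// quoted ONNX text describes for axis k. An empty product is 1.
func coerce2D(shape []int, axis int) (rows, cols int, err error) {
	if axis < 0 || axis > len(shape) {
		return 0, 0, errors.New("axis out of range")
	}
	rows, cols = 1, 1
	for i, d := range shape {
		if i < axis {
			rows *= d // dimensions before the axis form N
		} else {
			cols *= d // the axis and everything after it form D
		}
	}
	return rows, cols, nil
}

func main() {
	for _, axis := range []int{0, 1, 2} {
		r, c, _ := coerce2D([]int{3, 1, 5}, axis)
		fmt.Printf("axis=%d -> (%d, %d)\n", axis, r, c)
	}
	// axis=0 -> (1, 15)
	// axis=1 -> (3, 5)
	// axis=2 -> (3, 5)
}
```

With the default axis=1, the `(3, 1, 5)` input becomes a `(3, 5)` matrix, which a 2-D softmax can consume.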

owulveryck commented on June 19, 2024

A reshape does the trick, and softmax now works.
The only test still failing is the one with large numbers, which produces NaN:

go test -run=ONNX/TestSoftmax -v
=== RUN   TestONNX
=== RUN   TestONNX/TestSoftmaxDefaultAxis
=== RUN   TestONNX/TestSoftmaxAxis1
=== RUN   TestONNX/TestSoftmaxAxis0
=== RUN   TestONNX/TestSoftmaxAxis2
=== RUN   TestONNX/TestSoftmaxExample
=== RUN   TestONNX/TestSoftmaxLargeNumber
--- FAIL: TestONNX (0.01s)
    --- PASS: TestONNX/TestSoftmaxDefaultAxis (0.00s)
    --- PASS: TestONNX/TestSoftmaxAxis1 (0.00s)
    --- PASS: TestONNX/TestSoftmaxAxis0 (0.00s)
    --- PASS: TestONNX/TestSoftmaxAxis2 (0.00s)
    --- PASS: TestONNX/TestSoftmaxExample (0.00s)
    --- FAIL: TestONNX/TestSoftmaxLargeNumber (0.00s)
        test_structure.go:78:
                Error Trace:    test_structure.go:135
                Error:          Expected must not be NaN
                Messages:       the two tensors should be equal.
FAIL
exit status 1
FAIL    github.com/owulveryck/onnx-go/backend/x/gorgonnx        0.032s
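
Why the large-number case yields NaN: without stabilization, `exp` overflows to `+Inf` (above roughly 88 for float32, 709 for float64), and `+Inf / +Inf` is NaN. A minimal float64 sketch; `naiveSoftmax` is a hypothetical helper, not the gorgonnx code, and the input rows are assumed to match the ONNX backend case behind TestSoftmaxLargeNumber (the thread does not show them):

```go
package main

import (
	"fmt"
	"math"
)

// naiveSoftmax computes exp(x_i) / sum_j exp(x_j) directly, with no
// stabilization, so exp overflows to +Inf for large inputs.
func naiveSoftmax(x []float64) []float64 {
	out := make([]float64, len(x))
	var sum float64
	for i, v := range x {
		out[i] = math.Exp(v)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum // +Inf / +Inf is NaN
	}
	return out
}

func main() {
	fmt.Println(naiveSoftmax([]float64{10000, 10001, 10002, 10003})) // [NaN NaN NaN NaN]
	fmt.Println(naiveSoftmax([]float64{0, 1, 2, 3}))                 // small values are fine
}
```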

owulveryck commented on June 19, 2024

Using stabilization (subtracting the max before exponentiating) does not seem to help much:

diff --git a/backend/x/gorgonnx/softmax.go b/backend/x/gorgonnx/softmax.go
index 11adc69..604d9af 100644
--- a/backend/x/gorgonnx/softmax.go
+++ b/backend/x/gorgonnx/softmax.go
@@ -22,7 +22,7 @@ func (s *softmax) apply(g *Graph, n *Node) error {
                return err
        }
        a := children[0].gorgoniaNode
-       var reshaped *gorgonia.Node
+       var max, reshaped *gorgonia.Node
        if len(a.Shape()) > 2 {
                if s.axis > len(a.Shape()) {
                        return errors.New("softmax cannot be applied on an axis > len(shape()) of the input")
@@ -43,8 +43,19 @@ func (s *softmax) apply(g *Graph, n *Node) error {
        } else {
                reshaped = a
        }
+       if max, err = gorgonia.Max(reshaped); err != nil {
+               return err
+       }
+       a2, b2, err := gorgonia.Broadcast(reshaped, max, gorgonia.NewBroadcastPattern(nil, []byte{0, 1}))
+       if err != nil {
+               return err
+       }
+       output, err := gorgonia.Sub(a2, b2)
+       if err != nil {
+               return err
+       }
        var exp, sum *gorgonia.Node
-       if exp, err = gorgonia.Exp(reshaped); err == nil {
+       if exp, err = gorgonia.Exp(output); err == nil {
                axis := 1
                if exp.IsScalar() {
                        axis = 0

go test -run=ONNX/TestSoftmaxLarge -v
=== RUN   TestONNX
=== RUN   TestONNX/TestSoftmaxLargeNumber
--- FAIL: TestONNX (0.01s)
    --- FAIL: TestONNX/TestSoftmaxLargeNumber (0.00s)
        test_structure.go:78: 
                Error Trace:    test_structure.go:135
                Error:          Max difference between +Inf and 0.032058604 allowed is 1e-06, but difference was +Inf
                Messages:       the two tensors should be equal.
FAIL
exit status 1
FAIL    github.com/owulveryck/onnx-go/backend/x/gorgonnx        0.012s
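
One plausible reading of that result (an assumption, not confirmed in the thread): `gorgonia.Max(reshaped)` called with no axes reduces over the whole tensor, so a single global max is subtracted from every row. Rows far below that max underflow to `exp(...) = 0`, which breaks the normalization. The sketch below contrasts a global shift with a per-row shift on plain float64 slices; all helper names are hypothetical and this is not the gorgonnx implementation:

```go
package main

import (
	"fmt"
	"math"
)

// softmaxShifted computes exp(x_i - shift) / sum_j exp(x_j - shift) for one row.
func softmaxShifted(row []float64, shift float64) []float64 {
	out := make([]float64, len(row))
	var sum float64
	for i, v := range row {
		out[i] = math.Exp(v - shift)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

// rowMax returns the largest value in a row.
func rowMax(row []float64) float64 {
	m := math.Inf(-1)
	for _, v := range row {
		m = math.Max(m, v)
	}
	return m
}

// softmaxRowMax shifts each row by its own maximum, so every row's
// largest exponent is exp(0) = 1 and the sum is at least 1.
func softmaxRowMax(x [][]float64) [][]float64 {
	out := make([][]float64, len(x))
	for r, row := range x {
		out[r] = softmaxShifted(row, rowMax(row))
	}
	return out
}

// softmaxGlobalMax shifts every row by the maximum of the whole tensor.
func softmaxGlobalMax(x [][]float64) [][]float64 {
	g := math.Inf(-1)
	for _, row := range x {
		g = math.Max(g, rowMax(row))
	}
	out := make([][]float64, len(x))
	for r, row := range x {
		out[r] = softmaxShifted(row, g)
	}
	return out
}

func main() {
	x := [][]float64{{0, 1, 2, 3}, {10000, 10001, 10002, 10003}}
	fmt.Println(softmaxRowMax(x))
	fmt.Println(softmaxGlobalMax(x)) // first row: every exp underflows to 0, and 0/0 = NaN
}
```

Here the global-max version turns the first row into NaN, whereas the run above reported +Inf, so the exact failure path inside gorgonnx may differ. The per-row version returns about 0.0320586 as the first element of both rows, the value the test expects.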

owulveryck commented on June 19, 2024

Implemented by PR #56.
