Comments (6)

fchollet commented on July 18, 2024

The most explicit, idiomatic way would be to create an Initializer subclass.

import keras


@keras.saving.register_keras_serializable("my_package")
class RollInitializer(keras.Initializer):
    def __call__(self, shape, dtype=None):
        # Indices 0..n-1 along the last axis, rotated by half their length.
        n = shape[-1]
        return keras.ops.roll(keras.ops.arange(n, dtype=dtype), shift=n // 2)


@keras.saving.register_keras_serializable("my_package")
class SomeLayer(keras.Layer):
    def build(self, input_shape):
        # Non-trainable integer index variable, filled by RollInitializer.
        self.indices = self.add_variable(
            shape=(input_shape[-1],),
            initializer=RollInitializer(),
            dtype="int32",
            trainable=False,
        )
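To make the initial value concrete, here is the same arithmetic in NumPy. This is a sketch that assumes keras.ops.arange and keras.ops.roll behave like np.arange and np.roll; roll_indices is a hypothetical helper name, not part of Keras.

```python
import numpy as np


def roll_indices(n):
    # Hypothetical helper: same computation as RollInitializer for a last axis
    # of length n, i.e. arange(n) rotated right by n // 2.
    # np.roll stands in for keras.ops.roll here.
    return np.roll(np.arange(n), shift=n // 2)


print(roll_indices(6))  # [3 4 5 0 1 2]
print(roll_indices(5))  # [3 4 0 1 2]
```

The result is always a permutation of 0..n-1 (second half of the indices first), which is what makes it usable as a gather index.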

from keras.

james77777778 commented on July 18, 2024

We can also pass a lambda as the initializer; this is how quantization is set up in Dense:
https://github.com/keras-team/keras/blob/master/keras/src/layers/core/dense.py#L553-L568

In your case:

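import keras


class SomeLayer(keras.Layer):
    def build(self, input_shape):
        n = input_shape[-1]
        # Compute the value eagerly, then hand it to add_variable via a closure.
        indices = keras.ops.roll(keras.ops.arange(n, dtype="int32"), shift=n // 2)

        self.indices = self.add_variable(
            shape=(n,),  # shape must be iterable (e.g. a tuple), not a bare int
            initializer=lambda shape, dtype: indices,  # <-----
            dtype="int32",
            trainable=False,
        )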

Note: keras.initializers.Constant is not recommended for arrays, because the initializer keeps a reference to the value, so it cannot be released:
https://github.com/keras-team/keras/blob/master/keras/src/initializers/constant_initializers.py#L31
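One caveat with the closure approach: the lambda ignores the shape and dtype that add_variable passes in. Below is a slightly more defensive version of the same pattern, sketched in NumPy; fixed_value_initializer is a hypothetical helper, and NumPy arrays stand in for backend tensors.

```python
import numpy as np


def fixed_value_initializer(value):
    """Hypothetical helper: wrap a precomputed array as an initializer callable.

    Same idea as `lambda shape, dtype: indices`, but it checks the requested
    shape and casts to the requested dtype instead of ignoring both.
    """
    value = np.asarray(value)

    def initializer(shape, dtype=None):
        if tuple(shape) != value.shape:
            raise ValueError(f"expected shape {value.shape}, got {tuple(shape)}")
        return value.astype(dtype) if dtype is not None else value

    return initializer


init = fixed_value_initializer(np.roll(np.arange(4), 2))
print(init((4,), dtype="float32"))  # [2. 3. 0. 1.]
```

A mismatched shape then fails loudly at build time rather than producing a variable whose value disagrees with its declared shape.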

LarsKue commented on July 18, 2024

Thanks for the suggestions. As a feature request, could Keras accept the initial value directly as the initializer argument? I think this would make the interface less verbose, e.g., writing

self.constants = self.add_weight(shape=..., initializer=7)
indices = keras.ops.roll(...)
self.indices = self.add_weight(shape=..., initializer=indices)

instead of

self.constants = self.add_weight(shape=..., initializer=keras.initializers.Constant(7))
indices = keras.ops.roll(...)
self.indices = self.add_weight(shape=..., initializer=lambda shape, dtype: indices)

LarsKue commented on July 18, 2024

This could probably be handled inside initializers.get(initializer). I can put together a PR if the feature is wanted.
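A rough sketch of what that lookup could do, written against NumPy rather than Keras. resolve_initializer is a hypothetical stand-in for the extended lookup; it is not the current keras.initializers.get, and it leaves out the string identifiers and config dicts that the real function already handles.

```python
import numbers

import numpy as np


def resolve_initializer(initializer):
    # Hypothetical sketch of the proposed behaviour; NOT current Keras code.
    if callable(initializer):
        # Initializer objects and plain functions pass through unchanged.
        return initializer
    if isinstance(initializer, (numbers.Number, list, tuple, np.ndarray)):
        value = np.asarray(initializer)

        def from_value(shape, dtype=None):
            # Broadcasting lets a scalar fill any shape; an array must match
            # (or broadcast to) the requested shape, otherwise ValueError.
            # .copy() returns a writable array instead of a read-only view.
            return np.broadcast_to(value.astype(dtype or value.dtype), shape).copy()

        return from_value
    raise TypeError(f"Cannot interpret {initializer!r} as an initializer.")


sevens = resolve_initializer(7)((2, 2), dtype="float32")
indices = resolve_initializer(np.roll(np.arange(4), 2))((4,))
```

Using np.broadcast_to keeps the scalar case (initializer=7) and the array case in one code path; the trade-off is that an array which merely broadcasts to the requested shape is accepted, so an exact-shape check would be the stricter choice.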

grasskin commented on July 18, 2024

Thank you @LarsKue! Do you think we should add this @fchollet?

nkovela1 commented on July 18, 2024

@fchollet Do you have any thoughts on this feature request?
