Comments (5)

CyberZHG commented on July 19, 2024

The code has been tested under 2.0.0-beta1. I'm not able to reproduce the error. 😿

from keras-radam.

virtualdvid commented on July 19, 2024

It seems to be a Keras issue. I'm literally copying and pasting the code into my notebook and using the class when compiling the model. I did the same with Adam from keras_optimizers and got the same error.

As a workaround I put super(RAdam, self).__init__(name='RAdam', **kwargs) in the constructor, and now I'm getting another issue lol:

AttributeError: 'RAdam' object has no attribute 'lr'

CyberZHG commented on July 19, 2024

The code was updated an hour ago. See #2.

virtualdvid commented on July 19, 2024

That's the version I'm using, hmm, how weird. I'll let you know if I find a solution. Thanks!

virtualdvid commented on July 19, 2024

There we go, it's working now. Here's the code in case someone runs into the same issue:

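# Assumed imports (not shown in the original post); switch to `keras` if you are not on tf.keras.
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Optimizer


class RAdam(Optimizer):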
    """RAdam optimizer.
    # Arguments
        lr: float >= 0. Learning rate.
        beta_1: float, 0 < beta < 1. Generally close to 1.
        beta_2: float, 0 < beta < 1. Generally close to 1.
        epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
        decay: float >= 0. Learning rate decay over each update.
        weight_decay: float >= 0. Weight decay for each param.
    # References
        - [Adam - A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980v8)
        - [On The Variance Of The Adaptive Learning Rate And Beyond](https://arxiv.org/pdf/1908.03265v1.pdf)
    """

    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
                 epsilon=None, decay=0.0, weight_decay=0.0, **kwargs):
        super(RAdam, self).__init__(name='RAdam', **kwargs)
        with K.name_scope(self.__class__.__name__):
            self._lr = K.variable(lr, name='lr')
            self._iterations = K.variable(0, dtype='int64', name='iterations')
            self._beta_1 = K.variable(beta_1, name='beta_1')
            self._beta_2 = K.variable(beta_2, name='beta_2')
            self._decay = K.variable(decay, name='decay')
            self._weight_decay = K.variable(weight_decay, name='weight_decay')
        if epsilon is None:
            epsilon = K.epsilon()
        self.epsilon = epsilon
        self.initial_decay = decay
        self.initial_weight_decay = weight_decay

    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self._iterations, 1)]
        lr = self._lr
        if self.initial_decay > 0:
            lr = lr * (1. / (1. + self._decay * K.cast(self._iterations, K.dtype(self._decay))))

        t = K.cast(self._iterations, K.floatx()) + 1

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p), name='m_' + str(i)) for (i, p) in enumerate(params)]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p), name='v_' + str(i)) for (i, p) in enumerate(params)]

        self._weights = [self._iterations] + ms + vs

        beta_1_t = K.pow(self._beta_1, t)
        beta_2_t = K.pow(self._beta_2, t)

        sma_inf = 2.0 / (1.0 - self._beta_2) - 1.0
        sma_t = sma_inf - 2.0 * t * beta_2_t / (1.0 - beta_2_t)

        for p, g, m, v in zip(params, grads, ms, vs):
            m_t = (self._beta_1 * m) + (1. - self._beta_1) * g
            v_t = (self._beta_2 * v) + (1. - self._beta_2) * K.square(g)

            m_hat_t = m_t / (1.0 - beta_1_t)
            v_hat_t = K.sqrt(v_t / (1.0 - beta_2_t) + self.epsilon)

            r_t = K.sqrt((sma_t - 4.0) / (sma_inf - 4.0) *
                         (sma_t - 2.0) / (sma_inf - 2.0) *
                         sma_inf / sma_t + self.epsilon)

            # v_hat_t is already a square root (see above), so it is not rooted a second time here.
            p_t = K.switch(sma_t > 5, r_t * m_hat_t / (v_hat_t + self.epsilon), m_hat_t)

            if self.initial_weight_decay > 0:
                p_t += self._weight_decay * p

            p_t = p - lr * p_t

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates

    def get_config(self):
        config = {
            'lr': float(K.get_value(self._lr)),
            'beta_1': float(K.get_value(self._beta_1)),
            'beta_2': float(K.get_value(self._beta_2)),
            'decay': float(K.get_value(self._decay)),
            'weight_decay': float(K.get_value(self._weight_decay)),
            'epsilon': self.epsilon,
        }
        base_config = super(RAdam, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
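
For anyone wondering what the sma_t / r_t part of get_updates is doing, here is a rough standalone sketch in plain Python. It mirrors the formulas in the class above but leaves out the epsilon term, and the function name radam_rectification is just something made up for illustration, not part of keras-radam. It assumes a 1-based step count t and uses the same "sma_t > 5" threshold as the code above.

```python
import math


def radam_rectification(t, beta_2=0.999):
    """Return (sma_t, r_t) for a 1-based step t.

    r_t is None while sma_t <= 5, i.e. while the optimizer code above
    would fall back to the plain bias-corrected momentum update.
    (Hypothetical helper, epsilon omitted for readability.)
    """
    # Maximum length of the approximated simple moving average.
    sma_inf = 2.0 / (1.0 - beta_2) - 1.0
    beta_2_t = beta_2 ** t
    # Length of the approximated simple moving average at step t.
    sma_t = sma_inf - 2.0 * t * beta_2_t / (1.0 - beta_2_t)
    if sma_t <= 5.0:
        return sma_t, None
    # Variance rectification term applied to the adaptive step.
    r_t = math.sqrt((sma_t - 4.0) / (sma_inf - 4.0)
                    * (sma_t - 2.0) / (sma_inf - 2.0)
                    * sma_inf / sma_t)
    return sma_t, r_t


if __name__ == "__main__":
    for step in (1, 10, 100, 1000, 100000):
        sma, r = radam_rectification(step)
        print(step, round(sma, 3), r)
```

With the default beta_2, the first few steps return None for r_t (no adaptive scaling yet), and r_t then grows toward 1 as training goes on, which is the warmup-like behaviour RAdam is meant to give.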
