Adam2SGD

This is a modified version of Keras' EarlyStopping callback that switches from the Adam optimizer to SGD, following arXiv 1712.07628:

Keskar, N. S., & Socher, R. (2017).
Improving generalization performance by switching from adam to sgd.
arXiv preprint arXiv:1712.07628.

  • The callback monitors the learning rate according to condition (4) of arXiv 1712.07628.
  • If condition (4) is satisfied, the callback stops training early and restarts it through a separate SWATS (Switching from Adam To SGD) function, which uses the SGD optimizer with the learning rate that satisfied (4).
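
As a rough illustration of the check described above, the following sketch assumes that condition (4) is the paper's switching criterion, |λ_k / (1 - β2^k) - γ_k| < ε, where γ_k is the SGD learning rate estimated from the current Adam step and λ_k is its exponential moving average. The function name `swats_step` and its arguments are hypothetical and are not part of this package; the code is plain Python and does not touch Keras.

```python
def swats_step(p, g, lam_prev, k, beta2=0.999, eps=1e-9):
    """Update the SGD learning-rate estimate for iteration k (1-based).

    p is the Adam parameter update, g the gradient, both flat sequences.
    Returns (lam, switch, sgd_lr); sgd_lr is None unless switch is True.
    """
    pg = sum(pi * gi for pi, gi in zip(p, g))
    if pg == 0.0:
        # The estimate is undefined for this step; keep the running average.
        return lam_prev, False, None
    # Per-step SGD learning-rate estimate (projection of the Adam step).
    gamma = sum(pi * pi for pi in p) / -pg
    # Exponential moving average of the estimate, then bias correction.
    lam = beta2 * lam_prev + (1.0 - beta2) * gamma
    corrected = lam / (1.0 - beta2 ** k)
    # Assumed form of condition (4): the corrected average has settled.
    switch = k > 1 and abs(corrected - gamma) < eps
    return lam, switch, corrected if switch else None


# Toy usage: an Adam step that is exactly -0.1 times the gradient.
g = [1.0, 2.0]
p = [-0.1, -0.2]
lam, switch, sgd_lr = swats_step(p, g, lam_prev=0.0, k=1)
lam, switch, sgd_lr = swats_step(p, g, lam_prev=lam, k=2)
```

In this toy run the per-step estimate is constant, so the bias-corrected average matches it at the second iteration and the switch fires with an SGD learning rate of about 0.1.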

Usage:

model = Sequential()
...
model.compile(...)    


AdamToSGD_ = AdamToSGD(after_training_with_Adam=SWATS(x=train_x,
                                                      y=train_y,
                                                      ...))
                                                       
                                                       
def SWATS(momentum=0.0,    # SGD optimizer arguments
          nesterov=False,
          ...,
          loss='mse',      # compile arguments
          ...,
          x=None,          # model.fit statements
          y=None,
          **kwargs):
    """
    This user-defined function restarts training if condition 4 from 
    1712.07628 is satisfied in the callback.
    Define optimizer, compile it, and fit model again in one function.
    """
    lr = float(K.get_value(model.optimizer.lr))
    bias_corrected_exponential_avg = lr / (1. - K.get_value(model.optimizer.beta_2))
 
    if abs(bias_corrected_exponential_avg - lr) < 1e-9:
        return
    else:
        SGD_optimizer = SGD(lr=bias_corrected_exponential_avg,
                            ...)
     
        model.compile(optimizer=SGD_optimizer,
                      ...)
                   
        print('\nNow switching to SGD...\n')
     
        model.fit(x=x,
                  y=y,
                  ...)
  
 
result = model.fit(train_x,
                   train_y,
                   callbacks=[AdamToSGD_, ...],
                   ...)

If condition (4) from arXiv 1712.07628 is satisfied, training ends early and restarts through the user-defined SWATS function, which uses the SGD optimizer with the last learning rate value that Adam reached before the condition was met.

Requires TensorFlow < 2.0 and Keras 2.3.1 or lower.

This callback is more suitable for training with image or text data for hundreds of epochs.

To install, run python setup.py install.

Update 2021-05-21

If you have difficulty running this callback or porting it to TensorFlow 2.0, train with the Adam optimizer and change the early stopping callback to monitor the learning rate, stopping at the learning rate value given by the paper. Then manually restart training with the SGD optimizer. This callback only automates that process.
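
The manual restart described in this update might look roughly like the sketch below. It is an outline under stated assumptions, not this package's code: `restart_with_sgd` and the `make_sgd` factory are hypothetical names, and under TF 2.x the factory could be something like `lambda lr: tf.keras.optimizers.SGD(learning_rate=lr)`. Recompiling a Keras model does not reset its weights, so the second `fit` call continues from where the Adam phase stopped.

```python
def restart_with_sgd(model, make_sgd, sgd_lr, x, y, loss="mse", **fit_kwargs):
    """Recompile `model` with SGD at `sgd_lr` and continue training.

    model:    any object with Keras-style compile() and fit() methods.
    make_sgd: hypothetical factory that builds an SGD optimizer from a
              learning rate (kept as a parameter so the sketch does not
              depend on a specific TensorFlow or Keras version).
    sgd_lr:   learning rate at which the Adam phase was stopped.
    """
    print("\nNow switching to SGD...\n")
    # compile() replaces the optimizer but keeps the trained weights.
    model.compile(optimizer=make_sgd(sgd_lr), loss=loss)
    # Remaining keyword arguments (epochs, batch_size, ...) go to fit().
    return model.fit(x=x, y=y, **fit_kwargs)
```

Passing the optimizer factory in as an argument is only a convenience for this sketch; in a real script the SGD optimizer can be constructed inline.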

MIT License
