
feuermagier / beyond_deep_ensembles

16 stars · 2 watchers · 2 forks · 2.71 MB

Code for the paper "Beyond Deep Ensembles: A Large-Scale Evaluation of Bayesian Deep Learning under Distribution Shift"

License: MIT License

Languages: Jupyter Notebook 64.68%, Python 35.28%, Shell 0.04%
Topics: bayes-by-backprop, bayesian-deep-learning, bayesian-inference, deep-ensembles, natural-gradients, out-of-distribution, pytorch, stein-variational-gradient-descent, svgd, swag

beyond_deep_ensembles's People

Contributors: feuermagier


beyond_deep_ensembles's Issues

Training loss does not converge after reloading checkpoint to continue training

Hi, I'm trying to reload a pre-trained network to continue training in the CIFAR-10 experiments (cifar.py), but the loss does not converge after reloading (it still converges when the model is initialized and trained without reloading). My guess is that the problem comes from storing __base_optimizer as part of the optimizer state: when I run optimizer.load_state_dict(ckpt['optimizer_state_dict']), the live base optimizer is directly replaced by whatever is stored in optimizer_state_dict.
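For illustration, here is a minimal, self-contained sketch of the suspected failure mode. It is only a sketch: it assumes the base optimizer is stashed in self.state under the key "__base_optimizer" and that the stock torch.optim.Optimizer state_dict()/load_state_dict() behaviour applies; the repository's BayesianOptimizer classes may differ.

import copy

import torch

params = [torch.nn.Parameter(torch.zeros(3))]
opt = torch.optim.SGD(params, lr=0.1)      # stand-in for a BNN optimizer
base = torch.optim.Adam(params, lr=0.01)   # stand-in for its base optimizer
opt.state["__base_optimizer"] = base       # stash the live base optimizer in the state

# Mimic a checkpoint round trip: what comes back from disk is a copy, not the live object
ckpt = copy.deepcopy(opt.state_dict())
opt.load_state_dict(ckpt)                  # rebuilds self.state from the checkpoint

restored = opt.state["__base_optimizer"]
print(restored is base)                                    # False: the live base optimizer was replaced
print(restored.param_groups[0]["params"][0] is params[0])  # False: it now points at copied parameters

If that matches what happens in the repository, the restored base optimizer's statistics are attached to stale copies of the weights rather than to the live model parameters, which could explain why training fails to converge after reloading.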

I have now worked around the problem by making the base optimizer a class member of each BNN optimizer, i.e. self.__base_optimizer = base_optimizer instead of self.state["__base_optimizer"] = base_optimizer. With this change I load the state_dict of the optimizer and of its base optimizer separately, and the training loss converges again after reloading.

The following is my code snippet for reloading:

import glob
import os

import torch


def load_model(model_idx, model, scaler, optimizer, out_path, config, log):
    ckpt = None
    start_epoch = 0

    # Load checkpoint and scaler if available
    if config.get("use_checkpoint", None):
        try:
            # Pick the most recently written checkpoint for this model index
            ckpt_paths = glob.glob(out_path + f"{config['model']}_chkpt_{model_idx}_*.pth")
            ckpt_paths.sort(key=os.path.getmtime)
            ckpt = torch.load(ckpt_paths[-1])
            model.load_state_dict(ckpt['model_state_dict'])
            start_epoch = ckpt["epoch"] + 1
            scaler.load_state_dict(ckpt["scaler_state_dict"])
            log.info(f"Loaded checkpoint for model {model_idx} at epoch {start_epoch}")
        except Exception:
            log.info(f"Failed to load checkpoint for model {model_idx}")

    optimizer.init_grad_scaler(scaler)

    # Load optimizer state if available
    # Base optimizer state is loaded separately if available
    if ckpt is not None:
        try:
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            if ckpt.get("base_optimizer") is not None:
                optimizer.get_base_optimizer().load_state_dict(ckpt["base_optimizer"])
            log.info(f"Loaded base optimizer state for model {model_idx}")
        except Exception:
            log.info(f"Failed to load optimizer state for model {model_idx}")

    # Load scheduler state if available (wilson_scheduler comes from the repository's training code)
    if config["lr_schedule"]:
        scheduler = wilson_scheduler(optimizer.get_base_optimizer(), config["epochs"], config["lr"], None)
        if ckpt is not None:
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
            log.info(f"Loaded scheduler state for model {model_idx}")
    else:
        scheduler = None

    return start_epoch, model, optimizer, scaler, scheduler

And this is how I save the model during training:

def save_model(model, optimizer, scheduler, scaler, out_path, config, model_idx, epoch):
    state_dict = {
        'epoch': epoch,
        'model_idx': model_idx,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        # Store None (rather than the string 'None') when no scheduler is used
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None
    }
    # Save the base optimizer's state separately so that it can be restored explicitly
    if hasattr(optimizer, "get_base_optimizer"):
        state_dict['base_optimizer'] = optimizer.get_base_optimizer().state_dict()
    torch.save(state_dict, out_path + f"{config['model']}_chkpt_{model_idx}_{epoch}.pth")

An example of the optimizer change is:

class MAPOptimizer(BayesianOptimizer):
    '''
        Maximum A Posteriori

        This simply optimizes a point estimate of the parameters with the given base_optimizer.
    '''

    def __init__(self, params, base_optimizer):
        super().__init__(params, {})
        # self.state["__base_optimizer"] = base_optimizer
        self.__base_optimizer = base_optimizer
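One caveat with the double leading underscore: Python name-mangles it per class (the attribute above becomes _MAPOptimizer__base_optimizer), so an accessor such as get_base_optimizer() has to be defined in the same class that sets the attribute; otherwise a single leading underscore is the safer choice. A minimal, self-contained sketch of the changed member together with a matching accessor, using torch.optim.Optimizer as a stand-in for the repository's BayesianOptimizer base class:

import torch

class MAPOptimizer(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer):
        super().__init__(params, {})
        # Plain attribute instead of self.state["__base_optimizer"], so load_state_dict()
        # no longer overwrites it. Name mangling turns this into _MAPOptimizer__base_optimizer.
        self.__base_optimizer = base_optimizer

    def get_base_optimizer(self):
        # Defined in the same class as the mangled attribute on purpose
        return self.__base_optimizer

params = [torch.nn.Parameter(torch.zeros(3))]
opt = MAPOptimizer(params, torch.optim.SGD(params, lr=0.1))
opt.load_state_dict(opt.state_dict())   # the round trip no longer touches the base optimizer
assert opt.get_base_optimizer().param_groups[0]["params"][0] is params[0]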

Since I'm still looking into the other optimizers, it would be a great help if you could point out any potential problems with this approach. Thank you very much!

Missing/wrong hparams in cifar.yaml

Hello,

Thanks for the great work. While trying to replicate your experiments, I noticed that the "update_interval" parameter is missing for the CIFAR-10 SWAG optimizer.

There are also some other minor issues in the same cifar.yaml file, e.g. an unnecessary "mean_samples" key, and "swag_config" should be renamed to "swag". Could you share a more recent version of cifar.yaml? If not, could you tell me what the "update_interval" parameter should be?
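For reference, a small check of which SWAG-related keys the shipped config actually contains could look like the snippet below; the file path and the assumption that the SWAG options sit under a top-level "swag"/"swag_config" key are guesses rather than something taken from the repository:

import yaml  # PyYAML

with open("cifar.yaml") as f:
    config = yaml.safe_load(f)

swag_cfg = config.get("swag", config.get("swag_config", {}))
print("update_interval" in swag_cfg)   # False would confirm the missing hyperparameter
print("mean_samples" in swag_cfg)      # the key reported as unnecessary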

Best,
Emir
