feuermagier / beyond_deep_ensembles
Code for the paper "Beyond Deep Ensembles: A Large-Scale Evaluation of Bayesian Deep Learning under Distribution Shift"
License: MIT License
Hi, I'm trying to reload my pre-trained network to continue training in the CIFAR-10 experiments (`cifar.py`), but the loss does not converge after reloading (it still converges when the model is initialized from scratch without reloading). I suspect the issue is that `__base_optimizer` is stored as part of the optimizer state, so when I run `optimizer.load_state_dict(ckpt['optimizer_state_dict'])`, the live base optimizer is directly replaced by whatever is in `optimizer_state_dict`.
I solved the problem by making `base_optimizer` a class member of each BNN algorithm's optimizer, i.e. `self.__base_optimizer = base_optimizer` instead of `self.state["__base_optimizer"] = base_optimizer`. With this change, I load the state dicts of the optimizer and its base optimizer separately, and the training loss converges again after reloading.
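For context, here is a minimal standalone sketch (my own illustration, not code from the repo) of why stashing an object in `self.state` does not survive reloading: `torch.optim.Optimizer.load_state_dict` rebuilds `self.state` entirely from the loaded dict, so any extra entry stored there is dropped or overwritten.

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=0.1)

# Stash a sentinel in the optimizer state, mimicking
# self.state["__base_optimizer"] = base_optimizer.
opt.state["__base_optimizer"] = "sentinel"

# Loading any state dict rebuilds self.state from scratch ...
fresh = torch.optim.SGD(params, lr=0.1)
opt.load_state_dict(fresh.state_dict())

# ... so the stashed entry is gone.
print("__base_optimizer" in opt.state)  # False
```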
Here is my code snippet for reloading:
```python
import glob
import os

import torch


def load_model(model_idx, model, scaler, optimizer, out_path, config, log):
    ckpt = None
    start_epoch = 0
    # Load checkpoint and scaler if available
    if config.get("use_checkpoint", None):
        try:
            ckpt_paths = glob.glob(out_path + f"{config['model']}_chkpt_{model_idx}_*.pth")
            ckpt_paths.sort(key=os.path.getmtime)  # resume from the most recent checkpoint
            ckpt = torch.load(ckpt_paths[-1])
            model.load_state_dict(ckpt['model_state_dict'])
            start_epoch = ckpt["epoch"] + 1
            scaler.load_state_dict(ckpt["scaler_state_dict"])
            log.info(f"Loaded checkpoint for model {model_idx} at epoch {start_epoch}")
        except Exception:
            log.info(f"Failed to load checkpoint for model {model_idx}")
    optimizer.init_grad_scaler(scaler)
    # Load optimizer state if available.
    # The base optimizer state is loaded separately if available.
    if ckpt is not None:
        try:
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            if ckpt.get("base_optimizer") is not None:
                optimizer.get_base_optimizer().load_state_dict(ckpt["base_optimizer"])
                log.info(f"Loaded base optimizer state for model {model_idx}")
        except Exception:
            log.info(f"Failed to load optimizer state for model {model_idx}")
    # Load scheduler state if available (wilson_scheduler comes from the repo's training code)
    if config["lr_schedule"]:
        scheduler = wilson_scheduler(optimizer.get_base_optimizer(), config["epochs"], config["lr"], None)
        if ckpt is not None:
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
            log.info(f"Loaded scheduler state for model {model_idx}")
    else:
        scheduler = None
    return start_epoch, model, optimizer, scaler, scheduler
```
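For reference, a hypothetical call site for `load_model` (the surrounding names are assumptions from a typical training script, not repo code):

```python
# Restore everything needed to resume, then continue from start_epoch.
start_epoch, model, optimizer, scaler, scheduler = load_model(
    model_idx, model, scaler, optimizer, out_path, config, log
)
```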
And here is how I save the model during training:
```python
def save_model(model, optimizer, scheduler, scaler, out_path, config, model_idx, epoch):
    state_dict = {
        'epoch': epoch,
        'model_idx': model_idx,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        # Store None directly rather than the string 'None'
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None
    }
    # Save the base optimizer's state under its own key so it can be
    # restored separately in load_model.
    if hasattr(optimizer, "get_base_optimizer"):
        state_dict['base_optimizer'] = optimizer.get_base_optimizer().state_dict()
    torch.save(state_dict, out_path + f"{config['model']}_chkpt_{model_idx}_{epoch}.pth")
```
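For completeness, a hypothetical per-epoch loop that ties the two together (`train_one_epoch` is an assumed placeholder, not a function from the repo):

```python
for epoch in range(start_epoch, config["epochs"]):
    train_one_epoch(model, optimizer, scaler)  # assumed training helper
    if scheduler is not None:
        scheduler.step()
    # Checkpoint every epoch so the latest file is picked up by load_model.
    save_model(model, optimizer, scheduler, scaler, out_path, config, model_idx, epoch)
```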
An example of the optimizer change is:
```python
class MAPOptimizer(BayesianOptimizer):
    '''
    Maximum A Posteriori
    This simply optimizes a point estimate of the parameters with the given base_optimizer.
    '''
    def __init__(self, params, base_optimizer):
        super().__init__(params, {})
        # self.state["__base_optimizer"] = base_optimizer
        self.__base_optimizer = base_optimizer
```
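An alternative I considered (a rough sketch against plain `torch.optim.Optimizer`, not code from this repo) is to override `state_dict()`/`load_state_dict()` on the wrapper so the base optimizer's state is nested inside a single checkpoint entry:

```python
import torch

class WrappedOptimizer(torch.optim.Optimizer):
    """Hypothetical wrapper: the base optimizer lives in a plain attribute
    and its state is nested inside this wrapper's state dict."""

    def __init__(self, params, base_optimizer):
        super().__init__(params, {})
        self._base_optimizer = base_optimizer

    def step(self, closure=None):
        # Delegate the actual update to the base optimizer.
        return self._base_optimizer.step(closure)

    def state_dict(self):
        return {
            "wrapper": super().state_dict(),
            "base_optimizer": self._base_optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict["wrapper"])
        self._base_optimizer.load_state_dict(state_dict["base_optimizer"])
```

This would remove the need for the separate `base_optimizer` key in `save_model`/`load_model`, at the cost of changing the checkpoint format.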
Since I'm still looking into the other optimizers, it would be a great help if you could let me know about any potential problems with this approach. Thank you very much!
Hello,
Thanks for the great work. While trying to replicate your experiments, I noticed that the "update_interval" parameter is missing for CIFAR-10's SWAG optimizer.
There are some other minor typos in the same cifar.yaml file, e.g. an unnecessary "mean_samples" key, and "swag_config" should be replaced by "swag". Could you share a more recent version of cifar.yaml? If not, could you tell me what the "update_interval" parameter should be?
Best,
Emir