Comments (11)
Fixed optimiser pickling in 82fa093
TODO: add checkpointing
from disent.
Hi 👋 First of all thanks for the fantastic framework!
I started to work on this issue since I need model saving for my visualization project. Model saving seems to be very simple; however, loading a checkpoint appears to be a bit more complicated as far as I understand. I found that a saved checkpoint can be loaded, but the model parameter for e.g. BetaVae is an omegaconf.dictconfig.DictConfig instead of an actual AutoEncoder object if run with hydra. I think this is because in run.py the framework is created using hydra and maybe some magic is happening there. It works fine if I save and load the models with the standard Python API like in your example. Will investigate...
Ah, it seems like someone hit the same roadblock before: Lightning-AI/pytorch-lightning#6144
Hi @meffmadd, really glad you are finding it useful!
I managed to get away with wandb results a while back so I never got around to fixing this.
I'll investigate based on the information you provided and get back to you. Thank you for that!
You noted the object is an OmegaConf instance. I don't think it would break too much if we switch that over to a dictionary and recursively convert all the values. There is a built-in function for this.
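A minimal sketch of the recursive conversion described above, using only the standard library. The helper name `to_plain` is hypothetical and not part of disent or OmegaConf. For real OmegaConf objects, the built-in function referred to is most likely `OmegaConf.to_container(cfg, resolve=True)`; the generic version below only illustrates the idea on nested read-only mappings.

```python
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import Any


def to_plain(value: Any) -> Any:
    """Recursively convert mapping/sequence containers into plain dicts and lists.

    Hypothetical helper for illustration only.
    """
    if isinstance(value, Mapping):
        return {key: to_plain(item) for key, item in value.items()}
    # Strings are sequences too, so they must be excluded explicitly.
    if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
        return [to_plain(item) for item in value]
    return value


# A nested, non-dict config standing in for a DictConfig (assumption: illustrative values).
cfg = MappingProxyType({
    "model": MappingProxyType({"z_size": 6}),
    "layers": (32, 64),
})
plain_cfg = to_plain(cfg)
```

After the conversion, `plain_cfg` and every nested container are ordinary `dict` and `list` objects, so nothing config-specific ends up in the saved hyperparameters.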
(As for pytorch lightning, I have become a bit disillusioned towards it, as it has placed certain constraints on the framework that were never intended.)
To get to my question, how important is API stability for you right now?
API stability is not a major concern for me. I only planned on using the hydra configs, but maybe using the Python API is more useful for me (since it removes a layer of complexity). Yeah, frameworks are nice as long as their magic works and there is documentation 😅
For model saving when running with hydra, I think I found a simple workaround in the Vae class:
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    # Store the instantiated model instead of the hydra/OmegaConf config entry.
    checkpoint["hyper_parameters"]["model"] = self._model
    return super().on_save_checkpoint(checkpoint)
This manually sets the model in the checkpoint and also works for loading! I can implement simple model saving and make a pull request with this if you like so no API changes are necessary.
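To make the mechanism concrete, here is a framework-free sketch of why overriding the hook also fixes loading. `TinyModel` and `TinyFramework` are hypothetical stand-ins, not disent or PyTorch Lightning classes, and the sketch assumes (as Lightning appears to do) that loading re-creates the module by passing the saved `hyper_parameters` back to `__init__`.

```python
from typing import Any, Dict


class TinyModel:
    """Hypothetical stand-in for an instantiated AutoEncoder."""

    def __init__(self, z_size: int) -> None:
        self.z_size = z_size


class TinyFramework:
    """Hypothetical stand-in for a Lightning-style framework such as Vae."""

    def __init__(self, model: Any) -> None:
        # When launched from a config, `model` arrives as a config object
        # (mimicked here with a plain dict) rather than an instance.
        self.hparams: Dict[str, Any] = {"model": model}
        self._model = TinyModel(**model) if isinstance(model, dict) else model

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Replace the config entry with the instantiated model before saving.
        checkpoint["hyper_parameters"]["model"] = self._model

    def save_checkpoint(self) -> Dict[str, Any]:
        # Copy the hyperparameters so the hook does not mutate self.hparams.
        checkpoint = {"hyper_parameters": dict(self.hparams)}
        self.on_save_checkpoint(checkpoint)
        return checkpoint

    @classmethod
    def load_from_checkpoint(cls, checkpoint: Dict[str, Any]) -> "TinyFramework":
        # Assumption: loading calls __init__ with the saved hyperparameters.
        return cls(**checkpoint["hyper_parameters"])


framework = TinyFramework(model={"z_size": 6})
checkpoint = framework.save_checkpoint()
restored = TinyFramework.load_from_checkpoint(checkpoint)
```

Because the checkpoint now carries the model object itself, the restored framework receives a real model rather than a config. A trade-off of this approach is that the whole model is serialised with the hyperparameters, which ties the checkpoint to the model class definition.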
And a quick unrelated question:
Could you tell me a config for beta-VAE that works well with dSprites? I use bce loss, but this somehow creates NaN values in the encoder output. If I use the norm_conv64 framework it works well, but this seems not to be standard as per your warning.
Created pull request #37
Thank you so much for the PR! I left a few comments about tests. We just need to make sure to add the new keys to the configs and (possibly) update the tests to test the checkpointing.
As for your question. That should not be happening with the BCE loss. It may be due to largely unrelated things like the strength of the regularization term, or the learning rate too.
- I know dSprites is a binary dataset, but I would still recommend using MSE loss, as the BCE loss has led to prevalent errors in the VAE field. For example, moving from dsprites to dsprites-imagenet or even shapes3d, where BCE loss is no longer applicable (and then results also don't transfer as well, needing different hparams).
EDIT: on this note another reason for the BCE loss failing could be due to the dataset normalization. I am not sure if that possibly has a part to play, as the output is also normalized. There may be a logic/precision error there.
- You can try disabling this in the dataset config
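As a generic illustration of how a precision error can break BCE (not a diagnosis of the disent implementation), the sketch below compares a naive BCE computed from probabilities with the standard numerically stable form computed from logits. The function names are hypothetical. Note also that BCE assumes targets in [0, 1], which normalized data may violate.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def naive_bce(logit: float, target: float) -> float:
    """BCE from probabilities: breaks when the sigmoid saturates to exactly 0.0 or 1.0."""
    p = sigmoid(logit)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def stable_bce_with_logits(logit: float, target: float) -> float:
    """Equivalent BCE computed directly from the logit, avoiding log(0)."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# For moderate logits both versions agree.
moderate_naive = naive_bce(0.5, 1.0)
moderate_stable = stable_bce_with_logits(0.5, 1.0)

# For a large logit, sigmoid(40.0) rounds to exactly 1.0 in float64,
# so the naive version evaluates log(0.0) and fails.
try:
    naive_bce(40.0, 0.0)
    naive_failed = False
except ValueError:
    naive_failed = True

large_stable = stable_bce_with_logits(40.0, 0.0)
```

With Python's `math` module the naive version raises an error; in array libraries the same expression typically yields `inf` or NaN instead of raising, which then propagates through the gradients.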
I will fix the configs now!
Thanks for your answer! I will try it with MSE but with a higher learning rate because when I tested the beta-VAE with MSE it did not converge at all.
I think possibly a lower beta value then too.
Closing this with your changes from:
Thank you for contributing!
Now released under v0.7.0