

nmichlo commented on June 12, 2024

Fixed optimiser pickling in 82fa093

TODO: add checkpointing

from disent.

meffmadd commented on June 12, 2024

Hi 👋 First of all, thanks for the fantastic framework!

I started working on this issue since I need model saving for my visualization project. Saving a model seems to be very simple; loading a checkpoint, however, appears to be a bit more complicated, as far as I understand. A saved checkpoint can be loaded, but when run with hydra, the model parameter of e.g. BetaVae is an omegaconf.dictconfig.DictConfig instead of an actual AutoEncoder object. I think this is because in run.py the framework is created using hydra, and maybe some magic is happening there. It works fine if I save and load the models with the standard Python API, like in your example. Will investigate...


meffmadd commented on June 12, 2024

Ah, it seems like someone hit the same roadblock before: Lightning-AI/pytorch-lightning#6144


nmichlo commented on June 12, 2024

Hi @meffmadd, really glad you are finding it useful!

I managed to get by with wandb results a while back, so I never got around to fixing this.

I'll investigate based on the information you provided and get back to you. Thank you for that!


nmichlo commented on June 12, 2024

You noted that the object is an OmegaConf instance. I don't think it would break too much if we switched that over to a dictionary and recursively converted all the values. There is a built-in function for this.
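A minimal stdlib-only sketch of the recursive conversion described above, assuming the config behaves like a collections.abc Mapping/Sequence (OmegaConf's DictConfig and ListConfig subclass MutableMapping and MutableSequence). The built-in referred to is presumably OmegaConf.to_container(cfg, resolve=True), which also resolves interpolations; this sketch does not special-case them. The config contents are hypothetical, with MappingProxyType standing in for a read-only config object.

```python
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import Any


def to_plain(obj: Any) -> Any:
    """Recursively convert mapping/sequence containers into plain dicts/lists."""
    if isinstance(obj, Mapping):
        return {key: to_plain(value) for key, value in obj.items()}
    # str/bytes are sequences too, but they should be kept as scalar values
    if isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)):
        return [to_plain(item) for item in obj]
    return obj


# hypothetical nested, read-only config standing in for a DictConfig
cfg = MappingProxyType({
    "model": MappingProxyType({"z_size": 9, "layers": (32, 64)}),
    "beta": 0.003,
})
plain = to_plain(cfg)
print(plain)
```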

(As for pytorch lightning, I have become a bit disillusioned towards it, as it has placed certain constraints on the framework that were never intended.)

To get to my question, how important is API stability for you right now?


meffmadd commented on June 12, 2024

API stability is not a major concern for me. I only planned on using the hydra configs, but maybe using the Python API is more useful for me (since it removes a layer of complexity). Yeah, frameworks are nice as long as their magic works and there is documentation 😅

For model saving when running with hydra, I think I found a simple workaround in the Vae class:

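    # method of the Vae framework class; requires: from typing import Any, Dict
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # store the live model object instead of the config recorded under hydra
        checkpoint["hyper_parameters"]["model"] = self._model
        return super().on_save_checkpoint(checkpoint)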

This manually sets the model in the checkpoint, and it also works for loading! I can implement simple model saving and make a pull request with this if you like, so that no API changes are necessary.
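To see why the hook fixes loading, here is a toy, stdlib-only simulation of the save/load round trip. It assumes, as a simplification of pytorch-lightning's behaviour, that saving copies the recorded hyper-parameters into the checkpoint before calling on_save_checkpoint, and that loading re-instantiates the class with cls(**checkpoint["hyper_parameters"]). ToyVae, AutoEncoder, ModelConfig, toy_save_checkpoint and toy_load_checkpoint are hypothetical stand-ins, not disent or Lightning APIs.

```python
from typing import Any, Dict


class AutoEncoder:
    """Hypothetical stand-in for a real disent AutoEncoder."""


class ModelConfig(dict):
    """Hypothetical stand-in for the DictConfig that ends up in hparams under hydra."""


class ToyVae:
    """Toy framework class; mimics the hook only, NOT Lightning's real API."""

    def __init__(self, model: Any) -> None:
        self._model = model
        self.hparams: Dict[str, Any] = {"model": model}

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # the workaround: replace the recorded hparam with the live model object
        checkpoint["hyper_parameters"]["model"] = self._model


def toy_save_checkpoint(module: ToyVae) -> Dict[str, Any]:
    # simplified save: copy hparams first, then let the hook edit the checkpoint
    checkpoint = {"hyper_parameters": dict(module.hparams)}
    module.on_save_checkpoint(checkpoint)
    return checkpoint


def toy_load_checkpoint(cls: type, checkpoint: Dict[str, Any]) -> ToyVae:
    # simplified load: re-create the module from the saved hyper-parameters
    return cls(**checkpoint["hyper_parameters"])


vae = ToyVae(AutoEncoder())
# mimic the reported hydra behaviour: hparams hold a config, not the model
vae.hparams["model"] = ModelConfig(_target_="hypothetical.AutoEncoder")

restored = toy_load_checkpoint(ToyVae, toy_save_checkpoint(vae))
print(type(restored._model).__name__)  # prints: AutoEncoder
```

A real Lightning checkpoint also carries weights and optimiser state; this sketch only models the hyper-parameter path that the hook changes.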

And a quick unrelated question:
Could you tell me a config for beta-VAE that works well with dSprites? I use BCE loss, but this somehow creates NaN values in the encoder output. If I use the norm_conv64 framework it works well, but this does not seem to be standard, as per your warning.


meffmadd commented on June 12, 2024

Created pull request #37


nmichlo commented on June 12, 2024

Thank you so much for the PR! I left a few comments about tests. We just need to make sure to add the new keys to the configs and (possibly) update the tests to test the checkpointing.


As for your question, that should not be happening with the BCE loss. It may also be due to largely unrelated things, like the strength of the regularization term or the learning rate.

  • I know dSprites is a binary dataset, but I would still recommend using MSE loss, as the BCE loss has led to prevalent errors in the VAE field. For example, when moving from dsprites to dsprites-imagenet or even shapes3d, BCE loss is no longer applicable (and results also don't transfer as well, needing different hparams).

EDIT: On this note, another reason for the BCE loss failing could be the dataset normalization. I am not sure whether it plays a part, as the output is also normalized; there may be a logic/precision error there.

  • You can try disabling this in the dataset config.
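A minimal stdlib-only sketch of one plausible way normalization could break BCE, assuming the reconstruction loss is BCE computed on decoder logits (the common "BCE with logits" formulation) and that normalization can push pixel targets outside [0, 1]. The target value -1.0 is hypothetical. For targets in [0, 1], BCE is bounded below by zero; for a target outside that range it has no minimum, so the optimiser can keep pushing logits toward -inf, which may eventually surface as inf/NaN values. MSE has no such range requirement.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Numerically stable binary cross-entropy for a single logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def mse(pred: float, target: float) -> float:
    """Squared error; well-defined for any target range."""
    return (pred - target) ** 2


# targets inside [0, 1]: BCE never drops below zero
for logit in (-10.0, 0.0, 10.0):
    print("target=0.5", logit, round(bce_with_logits(logit, 0.5), 4))

# a hypothetical normalized target outside [0, 1]:
# the loss keeps decreasing as the logit goes to -inf, so there is no minimum
for logit in (-10.0, -100.0, -1000.0):
    print("target=-1.0", logit, round(bce_with_logits(logit, -1.0), 4))

# MSE is minimized (zero) exactly at the target, whatever its range
print("mse", mse(-1.0, -1.0))
```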


meffmadd commented on June 12, 2024

I will fix the configs now!

Thanks for your answer! I will try it with MSE, but with a higher learning rate, because when I tested the beta-VAE with MSE it did not converge at all.


nmichlo commented on June 12, 2024

I think possibly a lower beta value too, then.
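For reference, the standard beta-VAE objective adds beta times the KL divergence between a diagonal Gaussian posterior and a unit Gaussian prior to the reconstruction loss, so beta directly trades the regulariser off against the reconstruction term. A minimal stdlib-only sketch of that formula, assuming per-sample scalar inputs; disent's own reductions and scaling may differ, and the numbers are hypothetical. Because MSE and BCE produce reconstruction terms of different magnitude, a beta tuned for one loss may be too strong for the other.

```python
import math
from typing import Sequence


def gaussian_kl(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def beta_vae_loss(recon_loss: float, mu: Sequence[float],
                  logvar: Sequence[float], beta: float) -> float:
    """Reconstruction term plus the beta-weighted KL regulariser."""
    return recon_loss + beta * gaussian_kl(mu, logvar)


# hypothetical values for one sample: same posterior, two beta values
mu, logvar = [0.5, -1.0], [-0.2, 0.1]
for beta in (4.0, 1.0):
    print(beta, round(beta_vae_loss(10.0, mu, logvar, beta), 4))
```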


nmichlo commented on June 12, 2024

Closing this with your changes from:

Thank you for contributing!

Now released in v0.7.0

