nmichlo / disent

🧶 Modular VAE disentanglement framework for python built with PyTorch Lightning ▸ Including metrics and datasets ▸ With strongly supervised, weakly supervised and unsupervised methods ▸ Easily configured and run with Hydra config ▸ Inspired by disentanglement_lib

Home Page: https://disent.michlo.dev

License: MIT License

Python 100.00%
autoencoders configurable datasets disentangled-representations disentanglement metric-learning metrics python python3 pytorch pytorch-lightning representation-learning vae

disent's People

Contributors

meffmadd, nmichlo


disent's Issues

Adopt a Lightning-Flash Style API for Frameworks

The current instantiation of Frameworks is terrible, requiring two callables: one that returns a new optimizer instance and one that returns a new model instance. This is bad for tracking hyper-parameters and hurts overall usability.

Instead opt for a Lightning-Flash style API:

# from: https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/classification/model.py
def __init__(
    self,
    ...
    backbone: Union[str, Tuple[nn.Module, int]] = "resnet18",
    backbone_kwargs: Optional[Dict] = None,
    ...
    optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    ...
    learning_rate: float = 1e-3,
    ...
):
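Applied to disent, a constructor along these lines could replace the two callables. The snippet below is only a rough sketch of the idea (the class name and attributes are illustrative, not the actual disent API):

from typing import Any, Dict, Optional, Type

import pytorch_lightning as pl
import torch


class ExampleFramework(pl.LightningModule):  # illustrative name, not a real disent class
    def __init__(
        self,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        learning_rate: float = 1e-3,
    ):
        super().__init__()
        # plain values instead of factory callables, so hyper-parameters are easy to track and log
        self._optimizer_cls = optimizer
        self._optimizer_kwargs = dict(optimizer_kwargs or {})
        self._optimizer_kwargs.setdefault("lr", learning_rate)

    def configure_optimizers(self):
        return self._optimizer_cls(self.parameters(), **self._optimizer_kwargs)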

[FEATURE]: Model Saving and Checkpointing

Is your feature request related to a problem? Please describe.
Model saving and checkpointing is currently disabled for experiment/run.py.
This was due to old pickling errors and the extensive use of wandb for logging. Actual saved models were not needed at the time.

Describe the solution you'd like
Re-enable model checkpointing, and allow training to be continued (see the sketch after the list below).

  • Schedules will need to be pickled
  • Add config options to enable saving and continuing training
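A minimal sketch of what re-enabling this could look like, using the standard PyTorch Lightning ModelCheckpoint callback (the directory and settings here are illustrative, not the current disent config):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",  # illustrative output directory
    save_last=True,          # always keep the most recent checkpoint for resuming
    save_top_k=1,
    monitor=None,            # no metric monitoring, just keep saving
)

trainer = pl.Trainer(callbacks=[checkpoint_callback])
# resuming a previous run:
# trainer.fit(module, dataloader, ckpt_path="checkpoints/last.ckpt")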

Verify InfoVAE

The InfoVAE implementation is probably not correct.

  • loss scaling might not be correctly implemented
  • sane defaults for config
  • the IMQ kernel was removed

About metrics

Hi, great package!

How should I interpret the num_train and batch_size arguments of metric_dci or metric_mig? In addition, are there any examples of using the factor metric?
Thank you!
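For reference, the usage of these metrics elsewhere on this page looks like the sketch below: roughly, num_train is the number of encoded samples used to fit or estimate the metric, num_test (DCI only) the number used for evaluation, and batch_size how many observations are encoded per forward pass. This is a hedged sketch based on the quickstart examples further down; the metric_factor_vae name is an assumption about the package layout:

from disent.metrics import metric_dci, metric_mig
# from disent.metrics import metric_factor_vae  # assumed name for the "factor" metric

# `module` is a trained framework, `dataset` a DisentDataset
get_repr = lambda x: module.encode(x.to(module.device))

results = {
    **metric_dci(dataset, get_repr, num_train=1000, num_test=500, batch_size=64, show_progress=True),
    **metric_mig(dataset, get_repr, num_train=2000, batch_size=64),
}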

Thank you!

Hi! I just wanted to leave a note to say that I'm doing some work on disentanglement at the moment and this repository has been super helpful (not to mention that I think you provide a really good example of how to use frameworks like hydra and pytorch-lightning).

Not sure if you're planning on taking contributions right now, since you did mention that this work was for your Master's thesis and that's already hard enough as it is, but if you want, I can throw a couple of PRs your way.

Improve Documentation

Documentation is missing key framework features

  • augmentations
  • schedules
  • creating your own framework
  • creating your own models
  • creating your own datasets
  • visualisations

[Q]: Command to run with defaults?

Hi there,

What's the simplest way to train a VAE on cars3d with all the default settings? Is there a single command that does this, using all the yaml config files? E.g. cars3d needs to be downsampled to 64x64, and this is specified in its yaml file. I'm not very familiar with Hydra, so I'm not sure how to automatically load the yaml config files.

Thanks in advance, and for providing this great framework!

[BUG]: Automatic downloads not working for MPI3D and dSprites

Describe the bug
The downloads for MPI3D and dSprites do not work automatically

To Reproduce

from disent.dataset.data import Mpi3dData
data = Mpi3dData(in_memory=True)

Leads to

FileNotFoundError: [Errno 2] No such file or directory: 'C:\\Users\\[...]\\data\\dataset\\mpi3d_realistic\\mpi3d_realistic.npz'

Expected behaviour
It would be nice if the datasets were downloaded automatically, as they are for XYSquares.

Other than that thanks a lot for the library, very helpful!

[DOCS]: Add Experiment Examples

Is your feature request related to a problem? Please describe.
The current examples are very limited and only show how to use disent.

Describe the solution you'd like
Add examples for experiment/run.py and experiment/config that cover the new changes in milestone v0.4.0, including the experiment plugin system.

  • Experiment plugins & registry
  • Config overriding

[REWRITE]: Data and Dataset Rewrite

Data and Datasets currently have multiple levels of inheritance

  • eg. XYBlocksData > GroundTruthData > StateSpace
  • eg. GroundTruthDatasetTriples > GroundTruthDataset > (DisentDataset | GroundTruthData > StateSpace)

Data:
Data could be replaced with dependency injection of the StateSpace into the data classes. Data then counts as ground-truth data if it provides the property states (or maybe factors).

Datasets:
DisentDataset could be converted to a general class that inherits from torch Dataset, and accepts two initialising arguments: data and sampler. The various existing DisentDataset subclasses should be converted to Sampler classes that return a list of indices for each input index.
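A rough sketch of the proposed structure (this illustrates the suggested design, not the existing disent implementation; names and signatures are illustrative):

from typing import List, Optional, Sequence

from torch.utils.data import Dataset


class Sampler:
    # maps one requested index to the list of underlying indices that form a single item
    def __call__(self, idx: int) -> List[int]:
        raise NotImplementedError

class SingleSampler(Sampler):
    def __call__(self, idx: int) -> List[int]:
        return [idx]

class DisentDataset(Dataset):
    def __init__(self, data: Sequence, sampler: Optional[Sampler] = None):
        self._data = data
        self._sampler = sampler if sampler is not None else SingleSampler()

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, idx: int):
        # a triplet sampler, for example, would return three indices here
        return [self._data[i] for i in self._sampler(idx)]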

Update Configs to Hydra 1.1

Describe the bug
Many features missing in hydra 1.0 were manually implemented in the experiments; these features have since been added in hydra 1.1 (a small example of recursive instantiation is sketched after the list below).

Features added in hydra 1.1 include:

  • recursive defaults
  • recursive instantiation
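A small generic example of recursive instantiation (this is plain Hydra/PyTorch usage to illustrate the hydra >= 1.1 behaviour, not disent-specific config):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# nested _target_ nodes are now instantiated recursively by default
cfg = OmegaConf.create({
    "_target_": "torch.nn.Sequential",
    "_args_": [
        {"_target_": "torch.nn.Linear", "in_features": 8, "out_features": 4},
        {"_target_": "torch.nn.ReLU"},
    ],
})
model = instantiate(cfg)  # Sequential(Linear(8, 4), ReLU())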

[BUG]: Verify Models

The models in disent.model.ae may have diverged over time from the original implementations.

  • These need to be verified.

Windows, Macos & GPU Tests

The framework is currently only tested on Linux CPU. Support should be added for testing the framework on Windows and macOS, as well as basic GPU tests.

Verify BetaTCVAE

The betatcvae implementation is definitely not correct.

  • loss scaling is not implemented
  • sane defaults for config

[FEATURE]: Allow Metrics To Directly Accept Frameworks & Datasets

Is your feature request related to a problem? Please describe.
The current metrics require that you provide a representation function. This is inconvenient and has to be repeated every time. Metrics also always require that the dataset be a DisentDataset.

dataset = DisentDataset(data)
# we always need to produce this same function
get_repr = lambda x: module.encode(x.to(module.device))
# to use with metrics
results = metric_mig(dataset, get_repr)

Describe the solution you'd like
Allow the metrics to directly accept the frameworks instead, and automatically wrap datasets with DisentDataset

# directly use the framework and raw data instead!
results = metric_mig(data, module)

Describe alternatives you've considered

  • Detect if isinstance(obj, DisentFramework) then handle everything automatically!
  • Or add a function to frameworks called get_representation
  • Or detect if an object has the property encode and use that instead! (A small sketch of this option follows below.)
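A minimal sketch of the last option, assuming only that frameworks expose encode and device (the helper name is illustrative, not an existing disent function):

def _as_representation_fn(obj):
    # frameworks / modules that expose `encode` can be wrapped automatically
    if hasattr(obj, "encode"):
        return lambda x: obj.encode(x.to(obj.device))
    # otherwise assume the caller already passed a representation function
    return obj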

[FEATURE]: Add to Registry from Configs

Is your feature request related to a problem? Please describe.
Currently, items need to be manually added to disent.registry from code.

Describe the solution you'd like
Add support for experiment/run.py such that it can read from the config to add items to the registry for use.

Standardise Loss Reduction Modes

Removed reduction="sum" in #12, but considerable effort is still required to maintain and verify models (#3, #4, #5)

Two options exist (the difference between the modes is sketched after this list):

  1. Allow models to specify their allowed modes: "mean" or "mean_sum"
  2. Only use "mean_sum"
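For clarity, the difference between the two modes is sketched below. This is a generic illustration of how such reduction modes are commonly defined, not a quote of the disent implementation:

import torch

def reduce_loss(unreduced: torch.Tensor, mode: str) -> torch.Tensor:
    # `unreduced` has shape (B, C, H, W), e.g. from F.mse_loss(x_recon, x_targ, reduction="none")
    if mode == "mean":
        return unreduced.mean()  # mean over the batch and all elements
    if mode == "mean_sum":
        return unreduced.flatten(start_dim=1).sum(dim=-1).mean()  # sum per sample, then mean over the batch
    raise KeyError(f"unknown reduction mode: {mode}")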

[Q]: What's the native way to split datasets into train, validation and test?

I'm trying to train a VAE on Cars3dData and I was wondering how to split an instance of DisentDataset. Is there a dedicated sampler that does this?

Here is the backbone of what I am trying to run:

import pytorch_lightning as pl
import torch

from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import Cars3dData
from disent.frameworks.ae import Ae
from disent.metrics import metric_dci, metric_mig
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.dataset.transform import ToImgTensorF32
from disent.util import is_test_run  # you can ignore and remove this


# prepare the data
data = Cars3dData()
size = 64
vis_mean = [0.8976676149976628, 0.8891658020067508, 0.885147515814868]
vis_std = [0.22503195531503034, 0.2399461278981261, 0.24792106319684404]
dataset_train = DisentDataset(data, transform=ToImgTensorF32(size=64, mean=vis_mean, std=vis_std))
# dataset_val = ?
# dataset_test = ?

dataloader_train = DataLoader(
    dataset=dataset_train,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)

# create the pytorch lightning system
module: pl.LightningModule = Ae(
    model=AutoEncoder(
        encoder=EncoderConv64(x_shape=(3, 64, 64), z_size=6),
        decoder=DecoderConv64(x_shape=(3, 64, 64), z_size=6),
    ),
    cfg=Ae.cfg(
        optimizer="adam", optimizer_kwargs=dict(lr=1e-3), loss_reduction="mean_sum"
    ),
)

# train the model
trainer = pl.Trainer(
    max_steps=10,
    checkpoint_callback=False,
    fast_dev_run=is_test_run(),
    gpus=1 if torch.cuda.is_available() else None,
)
trainer.fit(module, dataloader_train)

# compute disentanglement metrics
# - we cannot guarantee which device the representation is on
# - this will take a while to run
get_repr = lambda x: module.encode(x.to(module.device))

metrics = {
    **metric_dci(
        dataset_train, get_repr, num_train=1000, num_test=500, show_progress=True
    ),
    **metric_mig(dataset_train, get_repr, num_train=2000),
}

# evaluate
print("metrics:", metrics)

Any hints are highly appreciated. Thank you for providing this package!

Best regards
Armin

Remove Framework Callbacks

PyTorch Lightning modules are supposed to be self-contained units that manage all of their own hyper-parameters.

The frameworks currently take in two arguments that break this requirement:

  • make_optimiser_fn
  • make_model_fn

ARCHITECTURE.md

Add an ARCHITECTURE.md file to help new contributors get started.

[BUG]: Frameworks not loading as LightningModule when using trainer.fit

Describe the bug
When running the QuickStart example for frameworks, I encountered a problem loading the Ae module correctly. Initially, the callback line was giving me trouble, but even after removing it I still could not load it as a LightningModule.

After running this:

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.frameworks.ae import Ae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.dataset.transform import ToImgTensorF32
from disent.util import is_test_run  # you can ignore and remove this

# prepare the data
data = XYObjectData()
dataset = DisentDataset(data, transform=ToImgTensorF32())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# create the pytorch lightning system
module: pl.LightningModule = Ae(
    model=AutoEncoder(
        encoder=EncoderConv64(x_shape=data.x_shape, z_size=6),
        decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
    ),
    cfg=Ae.cfg(optimizer='adam', optimizer_kwargs=dict(lr=1e-3), loss_reduction='mean_sum')
)


# train the model
trainer = pl.Trainer(logger=False, fast_dev_run=is_test_run())
trainer.fit(module, dataloader)

I got a TypeError. It happens with any framework I try.

	"name": "TypeError",
	"message": "`model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `Ae`",
	"stack": "---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
c:\\Users\\yuviu\\Desktop\\Uni Work\\Thesis\\disent\\experiment\\examples\\example.ipynb Cell 4 line 2
     <a href='vscode-notebook-cell:/c%3A/Users/yuviu/Desktop/Uni%20Work/Thesis/disent/experiment/examples/example.ipynb#W2sZmlsZQ%3D%3D?line=19'>20</a> # train the model
     <a href='vscode-notebook-cell:/c%3A/Users/yuviu/Desktop/Uni%20Work/Thesis/disent/experiment/examples/example.ipynb#W2sZmlsZQ%3D%3D?line=20'>21</a> trainer = pl.Trainer(logger=False, fast_dev_run=is_test_run())
---> <a href='vscode-notebook-cell:/c%3A/Users/yuviu/Desktop/Uni%20Work/Thesis/disent/experiment/examples/example.ipynb#W2sZmlsZQ%3D%3D?line=21'>22</a> trainer.fit(module, dataloader)

File c:\\Users\\yuviu\\anaconda3\\envs\\disent_env\\lib\\site-packages\\pytorch_lightning\\trainer\\trainer.py:529, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    495 def fit(
    496     self,
    497     model: \"pl.LightningModule\",
   (...)
    501     ckpt_path: Optional[str] = None,
    502 ) -> None:
    503     r\"\"\"Runs the full optimization routine.
    504 
    505     Args:
   (...)
    527 
    528     \"\"\"
--> 529     model = _maybe_unwrap_optimized(model)
    530     self.strategy._lightning_module = model
    531     _verify_strategy_supports_compile(model, self.strategy)

File c:\\Users\\yuviu\\anaconda3\\envs\\disent_env\\lib\\site-packages\\pytorch_lightning\\utilities\\compile.py:126, in _maybe_unwrap_optimized(model)
    124 if isinstance(model, pl.LightningModule):
    125     return model
--> 126 raise TypeError(
    127     f\"`model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `{type(model).__qualname__}`\"
    128 )

TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `Ae`"
}

To Reproduce
Just running the quick start examples in the documentation

Expected behaviour
The model should start training.

Additional context
I installed the package from source in a conda env. I also tried a new env with v2.0.0 of lightning, but I am still facing this issue.

[BUG]: "AssertionError: Torch not compiled with CUDA enabled"

Hello,

First, thank you for creating this package, it will make my life so much easier. One problem I've run into though is that since installing the package, it only allows me to run my PyTorch models on the CPU, even though when I run torch.device(0), it gives me "cuda". But when I run torch.cuda.is_available(), I get False. I'm on a Windows machine. In another virtual environment, where I don't have disent installed, I can load my models on the GPU just fine. Any idea why this is happening? I'm going to keep playing around with installation of the different packages, and if I figure it out, I'll reply to this post with the solution.

# imports added for completeness (paths may differ slightly between disent versions);
# `data` is assumed to be a disent ground-truth dataset such as XYObjectData, since it was not shown in the original snippet
import torch
from torch.optim import Adam

from disent.dataset.data import XYObjectData
from disent.frameworks.vae import BetaVae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64

data = XYObjectData()

def make_vae(beta):
    return BetaVae(
        make_optimizer_fn=lambda params: Adam(params, lr=5e-3),
        make_model_fn=lambda: AutoEncoder(
            encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2),
            decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
        ),
        cfg=BetaVae.cfg(beta=beta),
    )

DEVICE = torch.device(0)
model = make_vae(beta=4).to(DEVICE)  # Results in AssertionError

[FEATURE]: Experiment Config Override Support & Custom Code Registry

The current experiment configs are intended as examples and defaults.

  • It should be easy to override these in your own projects without actually modifying the source code of disent or forking the project

The current disent.registry is largely unused and inflexible.

  • It should be easy for an end-user to write their own experiments, custom code and frameworks without modifying disent itself.
  • A basic plugin system should be enabled that allows this functionality and experiment running.

[BUG]: Investigate Schedule Interaction With Validation & Test Data

Describe the bug
Schedules may unintentionally be affected by the use of validation and test data, which was enabled in release v0.3.4 (5695747)

  • These steps might increment trainer.global_step

Conversely, validation and test steps will use the current value of the schedule.

  • This should probably instead be the original value in the config

This issue is in response to the fixes introduced in #22

Expected behaviour
I am not sure what the expected behaviour should be in some cases. Further discussion may be required.

  • The framework may need to check whether a schedule wants its current value or a default value to be used for the validation and test steps; this functionality should possibly be configurable.

Cleanup Code & Experiments

Much code is left over from my MSc research. This should be cleaned up or removed for the first milestone.

Modular Losses

Losses currently have to be defined at a framework level through the use of overrides.

It would be nice if losses could be built directly, for example:

beta_vae_loss = 4 * KlRegLoss() + MseRecLoss()
# or even
beta_vae_loss = Param(cfg, 'beta') * KlRegLoss() + MseRecLoss()

# computing the loss for a training step as
loss, logs = beta_vae_loss.compute_loss(ds_posterior, ds_prior, zs_sampled, xs_partial_recon, xs_targ)

dfc_vae_loss = Param(cfg, 'beta') * KlRegLoss() + 0.5 * MseRecLoss() + 0.5 * DfcRecLoss()

# computing the loss for a training step as
loss, logs = dfc_vae_loss.compute_loss(ds_posterior, ds_prior, zs_sampled, xs_partial_recon, xs_targ)

Alternative data structures should be discussed.

This would require a substantial rewrite; however, existing frameworks could define these loss components in their constructors.

[FEATURE]: Simplifying your configs using hydra-zen

Hello! I just came across your project - it looks really nice!

I noticed that you are using Hydra's yamls to make your experiments configurable. I figured I might put hydra-zen on your radar.

hydra-zen provides simple functions for dynamically creating dataclass-based configs (a.k.a. structured configs), and it is especially useful for large-scale projects like yours.

There are a few big "wins" to this:

  • your configs are in Python, so you (and users!) can actually import them, inherit from them, etc. Plus you don't have to hand-write yaml files.
  • your configs are dynamically generated, so if you change your library's API, then your configs will automatically reflect this.
  • hydra-zen will validate configs against the signatures of the classes that it configures, so you can easily include config-validation in your test suite (which is really useful for a project the size of disent).

For example, this is all you need to generate configs for encoder and decoder models:

import disent.model.ae
from hydra_zen import builds

EncoderLinearConfig = builds(disent.model.ae.EncoderLinear, populate_full_signature=True)
DecoderLinearConfig = builds(disent.model.ae.DecoderLinear, populate_full_signature=True)
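For completeness, such a config could then be built with hydra-zen as well. This is a hedged sketch; the exact EncoderLinear parameters (x_shape, z_size) are assumptions taken from the other model examples on this page:

from hydra_zen import instantiate

# overrides can be supplied at instantiation time
encoder = instantiate(EncoderLinearConfig, x_shape=(3, 64, 64), z_size=6)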

You could then use simple utility functions to inject globally-specified settings (e.g. x_shape='${dataset.meta.x_shape}') into instances of these configs. This really helps keep you from having to repeat yourself.

I hope this is somewhat helpful (or at least interesting). Just wanted to bring this to your attention - feel free to close this issue 😄

[FEATURE]: wandb model checkpoint support

Is your feature request related to a problem? Please describe.
We don't just want to create checkpoints locally, but also to upload them to W&B.

Describe the solution you'd like
Extend the current checkpointing system so that checkpoints are also uploaded to wandb (see the sketch below).
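A minimal sketch using the standard PyTorch Lightning WandbLogger (generic Lightning usage, not disent-specific code; the project name is illustrative):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

# log_model="all" uploads every checkpoint saved by ModelCheckpoint to W&B as artifacts
logger = WandbLogger(project="disent", log_model="all")
trainer = pl.Trainer(logger=logger, callbacks=[ModelCheckpoint(save_last=True)])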

Describe alternatives you've considered
N/A

Additional context
Extending #28

visualisation with trained model

Hi, great package!

I am looking at the examples in the plotting_examples folder.
These seem to work independently of a trained torch model?
What would be a minimal way / example to use them with a trained model?
E.g. how would I visualise the latent traversals of a trained model?

Best regards

Verify DIP-VAE

The DipVAE implementation is probably not correct.

  • loss scaling may not be correctly implemented
  • sane defaults for config

Improve Tests

Tests are currently lacking across disent (a minimal sketch of a basic data test follows the list below):

  • data
  • datasets
  • schedule
  • metrics
  • transformations
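A minimal sketch of what a basic data test could look like (XYObjectData and x_shape usage are taken from the examples above; the assertions are assumptions about minimal expected behaviour):

from disent.dataset.data import XYObjectData

def test_xyobject_data_basics():
    data = XYObjectData()
    # the data should be sized, indexable, and expose its observation shape
    assert len(data) > 0
    assert len(data.x_shape) == 3  # (C, H, W), as used in the quickstart examples
    assert data[0] is not None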
