Comments (4)
Hi there, sorry for the delayed response!
Unfortunately this is something that I will need to add to the roadmap. I am just not entirely sure myself how to approach this problem when it comes to samplers/metrics that require ground-truth datasets.
As soon as you randomly split the dataset into train/test/validation portions, it is no longer a ground-truth dataset because of the way the factors are stored. Some of the samplers may then stop working, because they assume that all of the ground-truth factors are still available. This is also where the metrics fail: they require full access to the original dataset.
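To make this failure mode concrete, here is a toy sketch in plain Python. It assumes, for illustration only, that a ground-truth dataset contains every combination of factor values exactly once and decodes an item's factors from its flat index. The names `FACTOR_SIZES`, `idx_to_pos`, `pos_to_idx`, `partner_of` and `missing_partners` are hypothetical and are not disent's API. On the full dataset a sampler can always compute the index of "the same item with one factor changed", but after a random split that partner often lives in a different split:

```python
import random

# Toy stand-in for a ground-truth dataset (an assumption, NOT disent's code):
# every combination of factor values appears exactly once, so factors are not
# stored per item but decoded from the flat index.
FACTOR_SIZES = (4, 3, 5)  # hypothetical factor sizes
NUM_ITEMS = 4 * 3 * 5     # 60 items, one per factor combination


def idx_to_pos(idx):
    """Flat index -> tuple of factor values (row-major, mixed-radix decoding)."""
    pos = []
    for size in reversed(FACTOR_SIZES):
        idx, value = divmod(idx, size)
        pos.append(value)
    return tuple(reversed(pos))


def pos_to_idx(pos):
    """Tuple of factor values -> flat index (inverse of idx_to_pos)."""
    idx = 0
    for value, size in zip(pos, FACTOR_SIZES):
        idx = idx * size + value
    return idx


def partner_of(idx, factor=0):
    """Index of the item identical to `idx` except that one factor is cycled to
    its next value: the kind of pair a weakly supervised sampler asks for."""
    pos = list(idx_to_pos(idx))
    pos[factor] = (pos[factor] + 1) % FACTOR_SIZES[factor]
    return pos_to_idx(pos)


def missing_partners(subset):
    """Partners of items in `subset` that are not themselves in `subset`."""
    members = set(subset)
    return [p for p in map(partner_of, subset) if p not in members]


# a random 60% "train" split of the flat indices
shuffled = list(range(NUM_ITEMS))
random.Random(0).shuffle(shuffled)
train = shuffled[:36]

print(len(missing_partners(range(NUM_ITEMS))))  # 0: the full dataset has every partner
print(len(missing_partners(train)))             # some partners fell into val/test
```

On the full dataset every requested partner exists, while on a random 60% subset about 40% of the requested partners fall into the validation or test portion on average, so a sampler built on the "every combination exists" assumption cannot serve them.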
There may be a future workaround if new samplers are created that can handle this problem, although these samplers might be less flexible, especially for weakly or strongly supervised methods. However, I am still not sure how this problem would be solved for the metrics themselves.
I realise the PyTorch Lightning frameworks do not yet implement the validation step. I'll add this to the roadmap.
A workaround for your current code may be:
```python
import math

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import Cars3dData
from disent.dataset.transform import ToImgTensorF32
from disent.frameworks.ae import Ae
from disent.metrics import metric_dci
from disent.metrics import metric_mig
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64
from disent.model.ae import EncoderConv64

# normalise the data
vis_mean = [0.8976676149976628, 0.8891658020067508, 0.885147515814868]
vis_std = [0.22503195531503034, 0.2399461278981261, 0.24792106319684404]
data = Cars3dData(transform=ToImgTensorF32(size=64, mean=vis_mean, std=vis_std))

# SOLUTION:
# -- split the data using built-in functions (these are no longer ground-truth datasets, just subsets)
# -- the test split takes whatever is left, so the three lengths always sum to len(data)
n_train = math.floor(len(data) * 0.6)
n_val = math.floor(len(data) * 0.2)
n_test = len(data) - n_train - n_val
data_train, data_val, data_test = torch.utils.data.random_split(
    data,
    [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(42),  # fixed seed so the split is reproducible
)

# -- create multiple disent datasets
dataset_train = DisentDataset(data_train)
dataset_val = DisentDataset(data_val)
dataset_test = DisentDataset(data_test)

# -- create dataloaders (only the training data needs shuffling)
dataloader_train = DataLoader(dataset=dataset_train, batch_size=4, shuffle=True, num_workers=0)
dataloader_val = DataLoader(dataset=dataset_val, batch_size=4, shuffle=False, num_workers=0)
dataloader_test = DataLoader(dataset=dataset_test, batch_size=4, shuffle=False, num_workers=0)

# create the pytorch lightning system
module: pl.LightningModule = Ae(
    model=AutoEncoder(
        encoder=EncoderConv64(x_shape=(3, 64, 64), z_size=6),
        decoder=DecoderConv64(x_shape=(3, 64, 64), z_size=6),
    ),
    cfg=Ae.cfg(
        optimizer="adam", optimizer_kwargs=dict(lr=1e-3)
    ),
)

# PROBLEM: unfortunately the framework does not yet implement the pytorch-lightning validation step.
#          I'll add this to the roadmap and it should work in future.
trainer = pl.Trainer(max_steps=10000, checkpoint_callback=False, gpus=1 if torch.cuda.is_available() else None)
trainer.fit(module, dataloader_train, dataloader_val)

# PROBLEM: unfortunately the metrics will no longer work with the subsets
#          of data. You could instead pass the original full dataset to the
#          metrics, but this may be considered an information leak?
# -- This will crash!
get_repr = lambda x: module.encode(x.to(module.device))
metrics = {
    **metric_dci(dataset_test, get_repr, num_train=1000, num_test=500, show_progress=True),
    **metric_mig(dataset_test, get_repr, num_train=2000),
}
print("metrics:", metrics)
```
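Computing the split lengths with `floor` for one portion and `ceil` for the others only adds up to `len(data)` for some dataset sizes. For 11 items it gives 6 + 3 + 3 = 12, and `random_split` rejects lengths that do not sum to the dataset length. A small helper based on the largest-remainder method always produces lengths that sum to exactly `n`. `split_lengths` is a hypothetical helper, not part of disent or PyTorch:

```python
import math


def split_lengths(n, fractions):
    """Turn fractions that sum to 1 into integer lengths that sum to exactly n,
    using the largest-remainder method (hypothetical helper)."""
    if not math.isclose(sum(fractions), 1.0):
        raise ValueError("fractions must sum to 1")
    raw = [n * f for f in fractions]
    lengths = [math.floor(r) for r in raw]
    leftover = n - sum(lengths)
    # hand the leftover items to the portions with the largest fractional parts
    order = sorted(range(len(raw)), key=lambda i: raw[i] - lengths[i], reverse=True)
    for i in order[:leftover]:
        lengths[i] += 1
    return lengths


# floor/ceil overshoots for 11 items: 6 + 3 + 3 = 12
print([math.floor(11 * 0.6), math.ceil(11 * 0.2), math.ceil(11 * 0.2)])
print(split_lengths(11, [0.6, 0.2, 0.2]))  # [7, 2, 2]
```

Recent PyTorch releases also accept fractions that sum to 1 directly in `random_split`, which handles the rounding internally.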
This has been fixed in commit 5695747 (release v0.3.4).
Frameworks now support basic validation and testing by reusing the code from the training step. However, schedules might break if these are used.
A new example has been added to the docs: https://github.com/nmichlo/disent/blob/5695747c1e94420c024f1505d9b8a4b3c81ad610/docs/examples/overview_framework_train_val.py
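The warning about schedules can be illustrated with a toy sketch. It assumes, purely for illustration (this is not taken from disent's source), that a schedule such as a beta warm-up is driven by a step counter which the shared step increments. If validation and test batches go through the same code, they advance the schedule too. `ToyFramework`, `_shared_step` and `advance_on_eval` are hypothetical names:

```python
class ToyFramework:
    """Hypothetical framework whose validation and test steps reuse the training step.

    `beta` follows a linear warm-up schedule driven by `schedule_step`.
    """

    def __init__(self, warmup_steps=100, advance_on_eval=False):
        self.warmup_steps = warmup_steps
        self.advance_on_eval = advance_on_eval  # True reproduces the hazard
        self.schedule_step = 0

    @property
    def beta(self):
        return min(1.0, self.schedule_step / self.warmup_steps)

    def _shared_step(self, batch, stage):
        # placeholder loss standing in for e.g. reconstruction + beta * regulariser
        loss = self.beta * sum(batch)
        # guarded version: only training batches move the schedule forward
        if stage == "train" or self.advance_on_eval:
            self.schedule_step += 1
        return {f"{stage}_loss": loss}

    def training_step(self, batch):
        return self._shared_step(batch, "train")

    def validation_step(self, batch):
        return self._shared_step(batch, "val")

    def test_step(self, batch):
        return self._shared_step(batch, "test")


def run(framework, train_batches=10, val_batches=5):
    for _ in range(train_batches):
        framework.training_step([1.0])
    for _ in range(val_batches):
        framework.validation_step([1.0])
    return framework.schedule_step, framework.beta


print(run(ToyFramework(advance_on_eval=True)))   # (15, 0.15)
print(run(ToyFramework(advance_on_eval=False)))  # (10, 0.1)
```

With 10 training and 5 validation batches, the unguarded version has advanced the schedule 15 steps instead of 10. Restricting schedule updates to the training stage is one possible guard, shown here only as an assumption.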
It seems that it is not possible to use a train/val/test split for AdaVAE training. Is there any way around this?
@ema-marconato There are several sampling strategies from the original paper, each suited to different cases.
Unfortunately, only the fully random sampling strategies work with training and validation splits:
```python
from disent.dataset.sampling import RandomSampler
```
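A fully random sampler survives a split because it never looks at the factors; it only needs the number of items in whatever dataset it is given. A minimal sketch of that idea follows. `sample_random_tuple` is a hypothetical function, not disent's `RandomSampler`, whose exact behaviour (for example whether it samples with replacement) may differ:

```python
import random


def sample_random_tuple(anchor_idx, dataset_len, num_samples=3, rng=None):
    """Hypothetical fully random sampler: the anchor plus (num_samples - 1)
    indices drawn uniformly, with replacement, from range(dataset_len).

    Only the dataset length is needed, so it behaves the same on a full
    dataset or on a Subset produced by torch.utils.data.random_split.
    """
    if rng is None:
        rng = random.Random()
    if not 0 <= anchor_idx < dataset_len:
        raise IndexError(f"anchor index {anchor_idx} out of range")
    others = [rng.randrange(dataset_len) for _ in range(num_samples - 1)]
    return (anchor_idx, *others)


# usage: a triplet for item 5 of a 10-item validation subset
print(sample_random_tuple(5, dataset_len=10, num_samples=3, rng=random.Random(42)))
```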
The other strategies need more information: they use the ground-truth factors to enforce certain characteristics on the sampled observations:
```python
from disent.dataset.sampling import GroundTruthPairOrigSampler
from disent.dataset.sampling import GroundTruthPairSampler
```
It would be possible to write a random sampler that tries to enforce the constraints imposed by these ground-truth samplers. Unfortunately, no such sampler is implemented yet.
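Such a subset-aware sampler could look roughly like the sketch below. It is hypothetical and deliberately simple: it assumes that the factor values of each subset item were recorded from the full dataset before splitting, and it enforces only one constraint (anchor and partner differ in at least 1 and at most `max_changed` factors). It is not a port of `GroundTruthPairSampler` or `GroundTruthPairOrigSampler`:

```python
import random
from itertools import product


def make_subset_pair_sampler(subset_factors, max_changed=1, seed=None):
    """Hypothetical pair sampler for a subset split off a ground-truth dataset.

    subset_factors[i] holds the factor values of the i-th item in the subset,
    assumed to have been recorded from the full dataset before splitting.
    The returned function samples (anchor, partner) positions whose factor
    values differ in at least 1 and at most `max_changed` factors.
    """
    rng = random.Random(seed)

    def num_differences(a, b):
        return sum(x != y for x, y in zip(a, b))

    def sample():
        anchors = list(range(len(subset_factors)))
        rng.shuffle(anchors)  # try anchors in random order
        for anchor in anchors:
            fa = subset_factors[anchor]
            # linear scan for clarity; a real sampler would index items by factors
            candidates = [
                j for j, fj in enumerate(subset_factors)
                if 1 <= num_differences(fa, fj) <= max_changed
            ]
            if candidates:
                return anchor, rng.choice(candidates)
        # every partner this subset would need ended up in another split
        raise RuntimeError("no valid pair exists in this subset")

    return sample


# usage: a full 3 x 3 x 2 factor grid (18 items), keep a random 11-item "train" split
all_factors = list(product(range(3), range(3), range(2)))
train_factors = random.Random(0).sample(all_factors, 11)
sample_pair = make_subset_pair_sampler(train_factors, max_changed=1, seed=1)
anchor, partner = sample_pair()
print(train_factors[anchor], train_factors[partner])
```

The loss of flexibility shows up directly: anchors whose partners landed in another split are skipped, and a small enough subset may contain no valid pair at all, in which case the sketch raises an error.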
Related Issues
- [BUG]: "AssertionError: Torch not compiled with CUDA enabled"
- Windows, Macos & GPU Tests
- Standardise Loss Reduction Modes
- Adopt a Lightning-Flash Style API for Frameworks
- [Q]: Command to run with defaults?
- [FEATURE]: Experiment Config Override Support & Custom Code Registry
- [FEATURE]: Allow Metrics To Directly Accept Frameworks & Datasets
- [DOCS]: Add Experiment Examples
- [BUG]: Investigate Schedule Interaction With Validation & Test Data
- [FEATURE]: Model Saving and Checkpointing
- [FEATURE]: Add to Registry from Configs
- [FEATURE]: Simplifying your configs using hydra-zen
- [BUG]: Automatic downloads not working for MPI3D and dSprites
- Averaging in AdaGVAE
- [FEATURE]: Add DMS metric
- [FEATURE]: Add IRS Metric
- [FEATURE]: wandb model checkpoint support
- visualisation with trained model
- [BUG]: Frameworks not loading as LightningModule when using trainer.fit