nmichlo commented on June 12, 2024

Hi @dgm2, thank you!

My apologies for the delayed response. Sorry, you are right: most of them are independent.

However, you can have a look at the helper code inside the PyTorch Lightning utilities. One of the callback classes, VaeLatentCycleLoggingCallback, is specifically for generating latent traversals:

```python
# excerpt from VaeLatentCycleLoggingCallback; relies on disent internals and is not standalone
@classmethod
def generate_visualisations(
    cls,
    trainer_or_dataset: Union[pl.Trainer, DisentDataset],
    pl_module: pl.LightningModule,
    seed: Optional[int] = 7777,
    num_frames: int = 17,
    mode: str = 'fitted_gaussian_cycle',
    num_stats_samples: int = 64,
    # recon_min & recon_max
    recon_min: MinMaxHint = None,
    recon_max: MinMaxHint = None,
    recon_mean: MeanStdHint = None,
    recon_std: MeanStdHint = None,
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor, torch.Tensor]]:
    # normalize
    recon_min, recon_max = get_vis_min_max(
        recon_min=recon_min,
        recon_max=recon_max,
        recon_mean=recon_mean,
        recon_std=recon_std,
    )
    # get dataset and vae framework from trainer and module
    dataset, vae = _get_dataset_and_ae_like(trainer_or_dataset, pl_module, unwrap_groundtruth=True)
    # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ #
    # generate traversal
    # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ #
    # get random sample of z_means and z_logvars for computing the range of values for the latent_cycle
    with TempNumpySeed(seed):
        batch, indices = dataset.dataset_sample_batch(num_stats_samples, mode='input', replace=True, return_indices=True)  # replace just in case the dataset is tiny
        batch = batch.to(vae.device)
    # get representations
    if isinstance(vae, Vae):
        # variational auto-encoder
        ds_posterior, ds_prior = vae.encode_dists(batch)
        zs_mean, zs_logvar = ds_posterior.mean, torch.log(ds_posterior.variance)
    elif isinstance(vae, Ae):
        # auto-encoder
        zs_mean = vae.encode(batch)
        zs_logvar = torch.ones_like(zs_mean)
    else:
        log.warning(f'cannot run {cls.__name__}, unsupported type: {type(vae)}, must be {Ae.__name__} or {Vae.__name__}')
        return
    # get min and max if auto
    if (recon_min is None) or (recon_max is None):
        if recon_min is None: recon_min = float(torch.amin(batch).cpu())
        if recon_max is None: recon_max = float(torch.amax(batch).cpu())
        log.info(f'auto visualisation min: {recon_min} and max: {recon_max} obtained from {len(batch)} samples')
    # produce latent cycle still images & convert them to images
    stills = make_decoded_latent_cycles(vae.decode, zs_mean, zs_logvar, mode=mode, num_animations=1, num_frames=num_frames, decoder_device=vae.device)[0]
    stills = torch_to_images(stills, in_dims='CHW', out_dims='HWC', in_min=recon_min, in_max=recon_max, always_rgb=True, to_numpy=True)
    # generate the video frames and image grid from the stills
    # - TODO: this needs to be fixed to not use logvar, but rather the representations or distributions themselves
    # - TODO: should this not use `visualize_dataset_traversal`?
    frames = make_animated_image_grid(stills, pad=4, border=True, bg_color=None)
    image = make_image_grid(stills.reshape(-1, *stills.shape[2:]), num_cols=stills.shape[1], pad=4, border=True, bg_color=None)
    # done
    return stills, frames, image
```

This method uses helper functions from disent.util.visualize.vis_img to convert tensors to images, disent.util.visualize.vis_latents to generate latent sequences, and disent.util.visualize.vis_util to combine the images together into a grid or make sequential frames.
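As a rough illustration of the tensor-to-image step (a sketch, not disent's torch_to_images itself), the numpy code below rescales a channels-first float array from an assumed display range [in_min, in_max] to [0, 255], clips outliers, moves channels last, and repeats grayscale to RGB. The helper name chw_to_hwc_rgb is hypothetical.

```python
import numpy as np


def chw_to_hwc_rgb(x: np.ndarray, in_min: float, in_max: float) -> np.ndarray:
    """Hypothetical helper: (..., C, H, W) floats -> (..., H, W, 3) uint8 images.

    Values are rescaled linearly from [in_min, in_max] to [0, 255]; values outside
    that range are clipped.
    """
    if in_max <= in_min:
        raise ValueError("in_max must be greater than in_min")
    # rescale to [0, 1] using the display range, clipping outliers
    x = np.clip((x - in_min) / (in_max - in_min), 0.0, 1.0)
    # move channels last: (..., C, H, W) -> (..., H, W, C)
    x = np.moveaxis(x, -3, -1)
    # repeat single-channel (grayscale) images to 3 channels
    if x.shape[-1] == 1:
        x = np.repeat(x, 3, axis=-1)
    return np.round(x * 255).astype(np.uint8)


imgs = chw_to_hwc_rgb(np.zeros((2, 1, 4, 4)), in_min=-1.0, in_max=1.0)
print(imgs.shape, imgs[0, 0, 0])  # (2, 4, 4, 3) [128 128 128]
```

In generate_visualisations, recon_min and recon_max play the role of in_min and in_max; when they are not given, the method takes them from the minimum and maximum of the sampled batch.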

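The code is more complicated than it needs to be for most cases because of some additional handling and quirks (for example, supporting both Vae and Ae frameworks and estimating the display range from a sampled batch). Maybe we can add a dedicated docs example for latent traversals.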
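Until such a docs example exists, here is a minimal, framework-free sketch of the core idea, under stated assumptions: take the latent means of a sampled batch, then for each latent dimension sweep that dimension across a range while holding the others at a reference point, and decode every step. The range used here (batch mean plus or minus 2 standard deviations, swept linearly) is an assumption and may not match disent's 'fitted_gaussian_cycle' mode; latent_traversal and toy_decode are hypothetical names, and the sketch ignores zs_logvar.

```python
from typing import Callable

import numpy as np


def latent_traversal(
    decode: Callable[[np.ndarray], np.ndarray],
    zs_mean: np.ndarray,
    num_frames: int = 17,
    num_std: float = 2.0,
) -> np.ndarray:
    """Hypothetical helper returning one traversal per latent dimension.

    zs_mean: (N, D) latent means of N sampled observations.
    decode:  maps a (B, D) array of latents to a (B, ...) array of outputs.
    returns: (D, num_frames, ...) decoded outputs.
    """
    # reference point and per-dimension spread, fitted on the sampled batch
    mu = zs_mean.mean(axis=0)
    std = zs_mean.std(axis=0)
    stills = []
    for i in range(zs_mean.shape[1]):
        # every frame starts at the reference point: (num_frames, D)
        zs = np.tile(mu, (num_frames, 1))
        # sweep only dimension i linearly across mu[i] +/- num_std * std[i]
        zs[:, i] = np.linspace(mu[i] - num_std * std[i], mu[i] + num_std * std[i], num_frames)
        stills.append(decode(zs))
    return np.stack(stills)


# toy usage: a random linear "decoder" mapping 3 latents to 1x8x8 images
rng = np.random.default_rng(0)
weights = rng.normal(size=(3, 8 * 8))


def toy_decode(z: np.ndarray) -> np.ndarray:
    return (z @ weights).reshape(-1, 1, 8, 8)


zs_mean = rng.normal(size=(64, 3))
stills = latent_traversal(toy_decode, zs_mean, num_frames=5)
print(stills.shape)  # (3, 5, 1, 8, 8)
```

With an odd num_frames, the middle frame of every traversal decodes the reference point itself, which makes it easy to compare how each dimension changes the output.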

dgm2 commented on June 12, 2024

Thanks!
The callback method returns stills, frames and image.
How should I pass these into plot_dataset_traversals or visualize_dataset_traversal? What are the corresponding values there? For example, does stills correspond to the grid argument of plt_subplots_imshow?

Does this example make sense?
Many thanks!

```python
trainer = pl.Trainer(
    max_steps=2048,
    gpus=1 if torch.cuda.is_available() else None,
    logger=False,
    checkpoint_callback=False,
    max_epochs=1
)
trainer.fit(module, dataloader)
# trainer.save_checkpoint("trained.ckpt")

viz = VaeLatentCycleLoggingCallback()
stills, frames_, image_ = viz.generate_visualisations(trainer_or_dataset=trainer, pl_module=trainer.lightning_module,
                                                      num_frames=4, num_stats_samples=15)

plt_scale = 4.5
offset = 0.75
factors, frames, _, _, c = stills.shape

plt_subplots_imshow(grid=stills, title=None, row_labels=None, subplot_padding=None,
                    figsize=(offset + (1 / 2.54) * frames * plt_scale, (1 / 2.54) * (factors + 0.45) * plt_scale),
                    show=False)
```

nmichlo commented on June 12, 2024

Your example makes sense, but admittedly it has been a while since I last touched the code (I realize the current system is not optimal for these custom scripts, so this will need to be fixed in the future).

  • stills should be an array of shape (num_latents, num_frames, 64, 64, 3) for 64x64 inputs, containing the individual latent traversals.
  • frames is a concatenated version of stills intended for creating videos: for each frame, the stills across the latents dimension are combined into a single image grid. The final array is approximately of shape (num_frames, ~(64 * grid_h), ~(64 * grid_w), 3); the exact size depends on the padding and border.
  • image is a single image that you can plot, with all the latent traversals merged into one grid. The grid is built with num_cols equal to num_frames, so each row corresponds to one latent and each column to one frame, giving a shape of approximately (~(64 * num_latents), ~(64 * num_frames), 3).
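To make these shapes concrete, the numpy sketch below builds image and frames arrays from a stills array. It mirrors the call make_image_grid(stills.reshape(-1, *stills.shape[2:]), num_cols=stills.shape[1], ...) in the method above, but image_grid, stills_to_image and stills_to_frames are hypothetical helpers, and the layout rule (pad pixels between cells, no outer border, row-major fill) is an assumption. That is also why the sizes from disent are only approximately multiples of 64.

```python
import numpy as np


def image_grid(images: np.ndarray, num_cols: int, pad: int = 4, bg: int = 255) -> np.ndarray:
    """Hypothetical helper: tile (N, H, W, C) images into a row-major grid with pad pixels between cells."""
    n, h, w, c = images.shape
    num_rows = -(-n // num_cols)  # ceiling division
    grid = np.full(
        (num_rows * h + (num_rows - 1) * pad, num_cols * w + (num_cols - 1) * pad, c),
        bg,
        dtype=images.dtype,
    )
    for idx, img in enumerate(images):
        row, col = divmod(idx, num_cols)
        y, x = row * (h + pad), col * (w + pad)
        grid[y:y + h, x:x + w] = img
    return grid


def stills_to_image(stills: np.ndarray, pad: int = 4) -> np.ndarray:
    """One row per latent, one column per frame: (num_latents, num_frames, H, W, C) -> (grid_H, grid_W, C)."""
    num_frames = stills.shape[1]
    return image_grid(stills.reshape(-1, *stills.shape[2:]), num_cols=num_frames, pad=pad)


def stills_to_frames(stills: np.ndarray, num_cols: int, pad: int = 4) -> np.ndarray:
    """For each frame, tile every latent's image into one grid: -> (num_frames, grid_H, grid_W, C)."""
    return np.stack([image_grid(stills[:, f], num_cols=num_cols, pad=pad) for f in range(stills.shape[1])])


stills = np.zeros((10, 17, 64, 64, 3), dtype=np.uint8)
print(stills_to_image(stills).shape)               # (676, 1152, 3)
print(stills_to_frames(stills, num_cols=5).shape)  # (17, 132, 336, 3)
```

With 10 latents and 17 frames, the image is 10 * 64 + 9 * 4 = 676 pixels tall and 17 * 64 + 16 * 4 = 1152 pixels wide under this padding rule.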

You can try plotting the image directly with plt.imshow(image), or create your own visualization or animation from frames or stills.
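A minimal matplotlib sketch of both options, assuming image and frames are uint8 RGB arrays shaped as described above: show_image plots the merged image, and animate_frames steps through frames with matplotlib.animation.FuncAnimation. Both function names are hypothetical, and the figure size and frame interval are arbitrary choices.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; omit this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation


def show_image(image: np.ndarray) -> plt.Figure:
    """Plot the merged traversal grid, an (H, W, 3) array, as one figure."""
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(image)
    ax.set_axis_off()
    return fig


def animate_frames(frames: np.ndarray, interval_ms: int = 100) -> FuncAnimation:
    """Animate a (num_frames, H, W, 3) array, showing one grid per step."""
    fig, ax = plt.subplots()
    ax.set_axis_off()
    artist = ax.imshow(frames[0])

    def update(i: int):
        # swap in the pixels of frame i
        artist.set_data(frames[i])
        return (artist,)

    return FuncAnimation(fig, update, frames=len(frames), interval=interval_ms)
```

Keep a reference to the returned animation until it is shown or saved; for example, animate_frames(frames).save("traversal.gif", writer="pillow") writes a GIF using matplotlib's Pillow-based writer.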
