
projected-gan's Introduction

For a quick start, try the Colab: Projected GAN Quickstart

This repository contains the code for our NeurIPS 2021 paper "Projected GANs Converge Faster"

by Axel Sauer, Kashyap Chitta, Jens Müller, and Andreas Geiger.

If you find our code or paper useful, please cite

@InProceedings{Sauer2021NEURIPS,
  author         = {Axel Sauer and Kashyap Chitta and Jens M{\"{u}}ller and Andreas Geiger},
  title          = {Projected GANs Converge Faster},
  booktitle      = {Advances in Neural Information Processing Systems (NeurIPS)},
  year           = {2021},
}

Requirements

  • 64-bit Python 3.8 and PyTorch 1.9.0 (or later). See https://pytorch.org for PyTorch install instructions.
  • Use the following commands with Miniconda3 to create and activate your PG Python environment:
    • conda env create -f environment.yml
    • conda activate pg
  • The StyleGAN2 generator relies on custom CUDA kernels, which are compiled on the fly. Hence you need:
    • CUDA toolkit 11.1 or later.
    • GCC 7 or later compilers. The recommended GCC version depends on the CUDA version; see, for example, the CUDA 11.4 system requirements.
    • If you run into problems when setting up the custom CUDA kernels, see the Troubleshooting docs of the original StyleGAN repo. When using the FastGAN generator, you do not need the custom kernels.

Data Preparation

For a quick start, you can use the few-shot datasets provided by the authors of FastGAN. You can download them here. To prepare a dataset at the desired resolution, run, for example:

python dataset_tool.py --source=./data/pokemon --dest=./data/pokemon256.zip \
  --resolution=256x256 --transform=center-crop
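`--transform=center-crop` crops each image to a centered square before resizing it to `--resolution`. The sketch below computes that crop box in plain Python, in the (left, top, right, bottom) order that Pillow's `Image.crop` expects. It assumes the behavior of the StyleGAN2-ADA `dataset_tool.py` (crop to the shorter side, then resize) and omits the resize step; the function name is illustrative.

```python
from typing import Tuple


def center_crop_box(width: int, height: int) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of the largest square centered in the image."""
    side = min(width, height)   # the shorter edge becomes the square's side
    left = (width - side) // 2  # split the excess width evenly
    top = (height - side) // 2  # split the excess height evenly
    return left, top, left + side, top + side


# A 640x480 image keeps its central 480x480 region.
print(center_crop_box(640, 480))  # (80, 0, 560, 480)
```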

You can get the datasets we used in our paper at their respective websites:

CLEVR, FFHQ, Cityscapes, LSUN, AFHQ, Landscape.

Training

Training your own PG on the Pokemon dataset prepared above, using 8 GPUs:

python train.py --outdir=./training-runs/ --cfg=fastgan --data=./data/pokemon256.zip \
  --gpus=8 --batch=64 --mirror=1 --snap=50 --batch-gpu=8 --kimg=10000

--batch specifies the overall batch size and --batch-gpu the batch size per GPU. If you use fewer GPUs, the training loop automatically accumulates gradients until the overall batch size is reached.
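As a rough illustration of how `--batch`, `--gpus`, and `--batch-gpu` interact, the sketch below counts the gradient-accumulation rounds per optimizer step. It assumes the StyleGAN2-ADA/StyleGAN3 convention that `--batch` must be divisible by `--gpus` times `--batch-gpu`; the function name is hypothetical.

```python
def accumulation_rounds(batch: int, gpus: int, batch_gpu: int) -> int:
    """Forward/backward passes each GPU runs before one optimizer step."""
    per_pass = gpus * batch_gpu  # images processed in one pass across all GPUs
    if batch % per_pass != 0:
        raise ValueError("--batch must be a multiple of --gpus times --batch-gpu")
    return batch // per_pass


print(accumulation_rounds(64, 8, 8))  # 1: eight GPUs cover the batch in one pass
print(accumulation_rounds(64, 1, 8))  # 8: a single GPU accumulates gradients 8 times
```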

If you want to use the StyleGAN2 generator, pass --cfg=stylegan2. We also added a lightweight version of FastGAN (--cfg=fastgan_lite). This backbone trains fast in terms of wall-clock time and yields better results on small datasets like Pokemon. Samples and metrics are saved in outdir. To monitor the training progress, you can inspect fid50k_full.json or run TensorBoard in training-runs.
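To pick the best checkpoint from the metric log, a small parser is enough. This sketch assumes a JSON-lines file with one record per evaluation, each holding a `results` dict keyed by metric name and a `snapshot_pkl` field, which is the format StyleGAN2-ADA writes to `metric-fid50k_full.jsonl`; check your own run directory for the exact file name and fields. Lower FID is better.

```python
import json
from typing import Iterable, Optional, Tuple


def best_snapshot(lines: Iterable[str], metric: str = "fid50k_full") -> Optional[Tuple[str, float]]:
    """Return (snapshot_pkl, value) with the lowest metric value, or None if none is found."""
    best: Optional[Tuple[str, float]] = None
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        value = record.get("results", {}).get(metric)
        if value is None:
            continue  # record for a different metric
        if best is None or value < best[1]:
            best = (record.get("snapshot_pkl", ""), float(value))
    return best


# Usage (path is illustrative):
# with open("training-runs/<run>/metric-fid50k_full.jsonl") as f:
#     print(best_snapshot(f))
```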

Generating Samples & Interpolations

To generate samples and interpolation videos, run

python gen_images.py --outdir=out --trunc=1.0 --seeds=10-15 \
  --network=PATH_TO_NETWORK_PKL

and

python gen_video.py --output=lerp.mp4 --trunc=1.0 --seeds=0-31 --grid=4x2 \
  --network=PATH_TO_NETWORK_PKL
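`--seeds` accepts ranges such as `10-15`, and `gen_video.py` spreads the seeds over the cells of `--grid`. Assuming both scripts follow the StyleGAN3 conventions (ranges are inclusive; the number of seeds must be a multiple of the grid size, and each group of W*H seeds forms one keyframe), the arithmetic looks like this. The helpers are illustrative, not the scripts' own code.

```python
import re
from typing import List


def parse_range(spec: str) -> List[int]:
    """Parse '10-15' or '1,3,5-7' into a list of ints (ranges are inclusive)."""
    values: List[int] = []
    for part in spec.split(","):
        m = re.fullmatch(r"(\d+)-(\d+)", part)
        if m:
            values.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            values.append(int(part))
    return values


def num_keyframes(seeds: List[int], grid: str) -> int:
    """Number of interpolation keyframes for a 'WxH' grid."""
    w, h = (int(x) for x in grid.split("x"))
    if len(seeds) % (w * h) != 0:
        raise ValueError("number of seeds must be divisible by grid width * height")
    return len(seeds) // (w * h)


print(len(parse_range("10-15")))                  # 6 images
print(num_keyframes(parse_range("0-31"), "4x2"))  # 4 keyframes
```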

We provide the following pretrained models (pass the URL as PATH_TO_NETWORK_PKL):

https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/art_painting.pkl
https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/church.pkl
https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/bedroom.pkl
https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/cityscapes.pkl
https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/clevr.pkl
https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/ffhq.pkl
https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/flowers.pkl
https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/landscape.pkl
https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/pokemon.pkl

Quality Metrics

By default, train.py tracks FID50k during training. To calculate metrics for a specific network snapshot, run

python calc_metrics.py --metrics=fid50k_full --network=PATH_TO_NETWORK_PKL

To see the available metrics, run

python calc_metrics.py --help
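For reference, FID fits Gaussians to Inception-v3 features of real (r) and generated (g) images and compares them, where μ and Σ are the feature means and covariances. In the StyleGAN2-ADA metric suite this codebase builds on, `fid50k_full` uses 50k generated images and the full training set as the real set. Lower is better:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```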

Using PG in your own project

Our implementation is modular, so it is straightforward to use PG in your own codebase. Simply copy the pg_modules folder to your project. Then, to get the projected multi-scale discriminator, run

from pg_modules.discriminator import ProjectedDiscriminator
D = ProjectedDiscriminator()

The only thing you still need to do is to make sure that the feature network is not trained, i.e., explicitly set

D.feature_network.requires_grad_(False)

in your training loop.

Acknowledgments

Our codebase builds on and extends the awesome StyleGAN2-ADA repo and StyleGAN3 repo, both by Karras et al.

Furthermore, we use parts of the code of FastGAN and MiDaS.

projected-gan's People

Contributors

adambielski, ak391, miskinis, xl-sr

projected-gan's Issues

Tips on Small Complex Datasets

Hi, I'm very impressed with the results of this paper and also the insightful approach to gain a significant boost in computational efficiency.

Right now I'm testing the model with a custom dataset of humans in various poses, families, and people in general. I noticed that the textures, the colors, and the images overall are really good compared with other models, and it trains in 1/10 of the time. However, the generated faces don't look as good as the other aspects of the image. Here is an example of a generated grid at kimg 200:

[Image: generated grid at kimg 200]

My question is: How can I improve the results, especially on the faces?

Currently, I'm using the FastGAN backbone because the dataset is small (around 2,100 images at 256x256), with 1 GPU, --mirror=1, and default values for the other parameters.

multi-label generation

How can I change the code so that it generates images conditioned on multiple labels? For example, the image must contain a bear, an apple, an orange, and so on.
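One possible approach, sketched under assumptions: if the dataset loader works like StyleGAN2-ADA's, each entry in `dataset.json` may be either an integer class (converted to a one-hot vector) or a list of floats (used as the conditioning vector `c` as-is). A multi-hot vector per image would then encode several labels at once. The tag vocabulary and file names below are hypothetical, and whether the conditional generator and discriminator learn well from multi-hot labels is untested here.

```python
import json
from typing import Dict, List, Sequence


def multi_hot(tags: Sequence[str], vocabulary: Sequence[str]) -> List[float]:
    """Encode a set of tags as a 0/1 float vector over a fixed vocabulary."""
    index = {name: i for i, name in enumerate(vocabulary)}
    vec = [0.0] * len(vocabulary)
    for tag in tags:
        vec[index[tag]] = 1.0  # raises KeyError for tags outside the vocabulary
    return vec


def make_dataset_json(image_tags: Dict[str, Sequence[str]], vocabulary: Sequence[str]) -> str:
    """Build a dataset.json body in the StyleGAN2-ADA {"labels": [[fname, label], ...]} layout."""
    labels = [[fname, multi_hot(tags, vocabulary)] for fname, tags in sorted(image_tags.items())]
    return json.dumps({"labels": labels}, indent=2)


vocab = ["bear", "apple", "orange"]
print(multi_hot(["bear", "orange"], vocab))  # [1.0, 0.0, 1.0]
```

At generation time, the same kind of vector would be passed as `c` instead of a one-hot label.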

Out of memory with batch at 1

Hi, I ran into an issue when training FastGAN on a custom dataset. I get the following error during training:

RuntimeError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 0; 8.00 GiB total capacity; 5.92 GiB already allocated; 0 bytes free; 6.06 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I reduced the batch size to 1, but I keep getting this error. The command:
python train.py --outdir=./output_models/ --cfg=fastgan --data=./mydata --gpus 1 --batch 1 --mirror 0 --snap 50 --batch-gpu=1 --kimg=50 --cond True

The dataset consists of 50k images at 256x256 px.

I'm on Windows 10 with an RTX 2070 (8 GB VRAM).

How can I prevent this?

Thanks in advance
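The error message itself suggests one knob: `PYTORCH_CUDA_ALLOC_CONF` with `max_split_size_mb`, which must be set before PyTorch's CUDA allocator is first used; setting it before `import torch` at the top of `train.py` is the simplest way to guarantee that. Note that in this report reserved memory (6.06 GiB) is close to allocated memory (5.92 GiB), so fragmentation may not be the main cause and this setting might not help. The value 128 is an arbitrary starting point, not a recommendation from the authors.

```python
import os

# Must run before PyTorch touches CUDA; placing it before `import torch` is safest.
# setdefault() keeps any value already exported in the shell.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")
```

On Windows `cmd`, the equivalent is `set PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128` before running `python train.py ...`.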

Creating video issue

When using the default setup, I am able to create a video without any problem. The issue occurs when I point to the .pkl file in my Google Drive. The Drive is connected, the training is being saved there, and I am able to resume training with no problem. However, when I use the path to the .pkl file, neither images nor a video can be created. Do you have any idea why this may be happening, and any solutions?

Thank you for any help you can give, I greatly appreciate it.
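A quick way to rule out path problems before calling `gen_images.py` or `gen_video.py` is to check the pickle path from Python first. In Colab, Drive is typically mounted under `/content/drive/MyDrive`, and paths containing spaces must be quoted on the command line. The helper below is hypothetical and only checks the path, not the pickle's contents.

```python
from pathlib import Path
from typing import List


def check_network_pkl(path: str) -> List[str]:
    """Return a list of problems with a network pickle path (empty list means it looks usable)."""
    problems: List[str] = []
    p = Path(path).expanduser()
    if " " in str(p):
        problems.append("path contains spaces; quote it on the command line")
    if p.suffix != ".pkl":
        problems.append(f"expected a .pkl file, got {p.name!r}")
    if not p.is_file():
        problems.append(f"file not found: {p} (is Google Drive mounted?)")
    return problems


# Usage (illustrative path):
# print(check_network_pkl("/content/drive/MyDrive/training-runs/<run>/network-snapshot.pkl"))
```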

ugly code

Your discriminator code hurts my eyes and my brain. How is anyone supposed to port that to TensorFlow?

Exporting trained model to onnx

Hi, I tried to implement an ONNX export of a trained model. Following the PyTorch docs, I wrote:

import torch

import dnnlib
import legacy

network_pkl = "model.pkl"

print('Loading networks from "%s"...' % network_pkl)
device = torch.device('cuda')
with dnnlib.util.open_url(network_pkl) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore
model = G

# Labels.
label = torch.zeros([1, G.c_dim], device=device)

# Export the model
torch.onnx.export(model,                     # model being run
                  label,                     # model input (or a tuple for multiple inputs)
                  "export.onnx",             # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=9,           # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['Latent'],    # the model's input names
                  output_names=['Image'])    # the model's output names

And I got:
TypeError: forward() missing 1 required positional argument: 'c'

I can't figure out where the issue is.

Thanks in advance

Note: the trained model is a conditional FastGAN-lite model.
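The `TypeError` is consistent with how `torch.onnx.export` treats its `args` parameter: a tuple is unpacked as positional arguments (`model(*args)`), while a single tensor is passed as the only argument (`model(args)`). Since the generator's `forward` takes both `z` and `c`, passing only `label` leaves `c` missing. The plain-Python sketch below mimics that calling convention with a hypothetical stand-in generator. In the real script, the corresponding change would be to pass a tuple such as `(z, label)` with `z = torch.randn([1, G.z_dim], device=device)` and to list two `input_names`. This is untested against the repo, and the custom ops or the chosen opset may still cause export errors.

```python
from typing import Any


class ToyGenerator:
    """Hypothetical stand-in whose call signature mirrors G.forward(z, c)."""

    def __call__(self, z: Any, c: Any) -> str:
        return f"image(z={z}, c={c})"


def call_like_exporter(model: Any, args: Any) -> Any:
    """Mimic how torch.onnx.export invokes the model with its `args` parameter."""
    if isinstance(args, tuple):
        return model(*args)  # tuple: unpacked as positional arguments
    return model(args)       # single value: passed as the only argument


G = ToyGenerator()
print(call_like_exporter(G, ("z", "label")))  # image(z=z, c=label)
# call_like_exporter(G, "label") raises TypeError because `c` is missing.
```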

FastGAN grid artifacts

I've been noticing quite a lot of griddy/repetitive patterns in the outputs when training at high resolution with FastGAN.

Will the change from today help address those? Or are these inherent to the skip-excitation layers? (the grids do seem to be ~32x32, which is what is skipped to the 512x512 block). Alternatively, would you happen to know ways that these patterns could be reduced?

Example training grid with repetitive grid patterns (5,000-image dataset after 919 kimg):

[Image: training grid with repetitive grid patterns]

Example training grid with repetitive grid patterns and mode collapse (4,000-image dataset after 680 kimg, fine-tuned from the above checkpoint):

[Image: training grid with repetitive grid patterns (and mode collapse)]

Installation without conda

Hi Axel,

Thank you for sharing the code for your article. I am trying to install and run the code on an HPC cluster where using conda is not advised.

Could you release a requirements.txt file, so that I can install the dependencies with pip? I am also happy to hear any other suggestions you might have.

Thanks in advance!

Feature Request & a few questions

Please add a command-line parameter so that a new pkl file is saved for each snapshot, as in regular StyleGAN.

What are the parameters --cbase, --cmax, and --map-depth for? I can enter arbitrary values and nothing changes.

I can train for as many kimg as I want with --batch-gpu=4, but when resuming I ALWAYS get a CUDA out-of-memory error. With --batch-gpu=2, resuming works. Why is that?

Is it possible to train Projected GAN on a TPU? If so, do you plan to implement it?

I noticed that although I specify the learning rate with --dlr=0.0025, training_options.json (and the output at the start of training) says "lr": 0.0002. Is this a bug?

Are these extreme differences normal compared to StyleGAN3?
Same dataset with 1 million images, same training progress in kimg: Projected GAN has about 70 million more parameters than the StyleGAN3 model, but its output looks odd.

I simply added up the generator and discriminator parameter counts displayed during training.

Projected GAN model: 119,351,119
StyleGAN3-T model: 47,321,532

Shouldn't Projected GAN be much better?

I'm not a programmer, maybe I'm just misunderstanding something.

Left: PG, right: SG3

[Video: PGvsSG3-GPSF-3x3-trunc07-seed0-125-lowquality_vid_H.264.mp4]

I also noticed that Projected GAN often just produces empty boxes (in the fake images) or really strange patterns compared to StyleGAN3.

For example, something like this:
[Image: wtf_empty]
[Image: pg muster]

Collecting environment information...
PyTorch version: 1.10.1
CUDA used to build PyTorch: 11.3
OS: Microsoft Windows 10 Pro

Python version: 3.8.12 (default, Oct 12 2021, 03:01:40) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19043-SP0
Is CUDA available: True
CUDA runtime version: 11.4.120
GPU models and configuration: GPU 0: RTX 2070 Super
Nvidia driver version: 511.09
cuDNN version: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\cudnn_ops_train64_8.dll

[Attachment: training_options.json.txt]

Question w.r.t SEblock.

Hello.

First, thanks for this very cool project.

I have a question w.r.t SEblock.

Do you have any specific reason why bias option is set to "False" in conv layer of SEblock?

I know that the FastGAN and MobileNet also have the same option.

Thank you.

Training on LSUN bedroom dataset.

First, thank you for sharing great work!

I trained Projected FastGAN on LSUN bedroom.

However, FID increases over the course of training.

The FID and training loss trends are as follows:

[Image: FID and training loss trends]

Is anyone else experiencing the same issue?

Thank you!

stylegan2 produces color splats

I'm trying to run the stylegan2 configuration, but it produces almost random color splats. What could cause this?
The same results appeared on Google Colab Pro (P100)
and on Paperspace Gradient (nvcr.io/nvidia/pytorch:21.10-py3 Docker image + Quadro M4000).

[Image: fakes000064]
[Image: fakes000068]

Mode Collapse on a New Dataset

Hi,

I'm getting mode collapse on a new dataset, although I see nice behavior on AFHQ. Which hyperparameters should I try to tune to get it working on my dataset?

Thanks in advance @xl-sr

Problem running StyleGAN2 in Colab

When trying to run StyleGAN2 using cfg='stylegan2', this error occurs:

in train(**kwargs)
76
77 # Launch.
---> 78 launch_training(c=c, desc=desc, outdir=opts.outdir)

in launch_training(c, desc, outdir, rank)
43 sync_device = torch.device('cuda', rank) if c.num_gpus > 1 else None
44 training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
---> 45 training_loop.training_loop(rank=rank, **c)

/content/projected_gan/training/training_loop.py in training_loop(run_dir, training_set_kwargs, data_loader_kwargs, G_kwargs, D_kwargs, G_opt_kwargs, D_opt_kwargs, loss_kwargs, metrics, random_seed, num_gpus, rank, batch_size, batch_gpu, ema_kimg, ema_rampup, G_reg_interval, D_reg_interval, total_kimg, kimg_per_tick, image_snapshot_ticks, network_snapshot_ticks, resume_pkl, resume_kimg, cudnn_benchmark, abort_fn, progress_fn, restart_every)
188 z = torch.empty([batch_gpu, G.z_dim], device=device)
189 c = torch.empty([batch_gpu, G.c_dim], device=device)
--> 190 img = misc.print_module_summary(G, [z, c])
191 misc.print_module_summary(D, [img, c])
192

/content/projected_gan/torch_utils/misc.py in print_module_summary(module, inputs, max_nesting, skip_redundant)
214
215 # Run module.
--> 216 outputs = module(*inputs)
217 for hook in hooks:
218 hook.remove()

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1118 input = bw_hook.setup_input_hook(input)
1119
-> 1120 result = forward_call(*input, **kwargs)
1121 if _global_forward_hooks or self._forward_hooks:
1122 for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):

/content/projected_gan/pg_modules/networks_stylegan2.py in forward(self, z, c, truncation_psi, truncation_cutoff, update_emas, **synthesis_kwargs)
533
534 def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
--> 535 ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
536 img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
537 return img

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1118 input = bw_hook.setup_input_hook(input)
1119
-> 1120 result = forward_call(*input, **kwargs)
1121 if _global_forward_hooks or self._forward_hooks:
1122 for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):

/content/projected_gan/pg_modules/networks_stylegan2.py in forward(self, z, c, truncation_psi, truncation_cutoff, update_emas)
236 for idx in range(self.num_layers):
237 layer = getattr(self, f'fc{idx}')
--> 238 x = layer(x)
239
240 # Update moving average of W.

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1118 input = bw_hook.setup_input_hook(input)
1119
-> 1120 result = forward_call(*input, **kwargs)
1121 if _global_forward_hooks or self._forward_hooks:
1122 for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):

/content/projected_gan/pg_modules/networks_stylegan2.py in forward(self, x)
116 else:
117 x = x.matmul(w.t())
--> 118 x = bias_act.bias_act(x, b, act=self.activation)
119 return x
120

/content/projected_gan/torch_utils/ops/bias_act.py in bias_act(x, b, dim, act, alpha, gain, clamp, impl)
82 assert isinstance(x, torch.Tensor)
83 assert impl in ['ref', 'cuda']
---> 84 if impl == 'cuda' and x.device.type == 'cuda' and _init():
85 return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
86 return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)

/content/projected_gan/torch_utils/ops/bias_act.py in _init()
44 headers=['bias_act.h'],
45 source_dir=os.path.dirname(__file__),
---> 46 extra_cuda_cflags=['--use_fast_math'],
47 )
48 return True

/content/projected_gan/torch_utils/custom_ops.py in get_plugin(module_name, sources, headers, source_dir, **build_kwargs)
135 cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
136 torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
--> 137 verbose=verbose_build, sources=cached_sources, **build_kwargs)
138 else:
139 torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)

/usr/local/lib/python3.7/dist-packages/torch/utils/cpp_extension.py in load(name, sources, extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths, build_directory, verbose, with_cuda, is_python_module, is_standalone, keep_intermediates)
1134 is_python_module,
1135 is_standalone,
-> 1136 keep_intermediates=keep_intermediates)
1137
1138

/usr/local/lib/python3.7/dist-packages/torch/utils/cpp_extension.py in _jit_compile(name, sources, extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths, build_directory, verbose, with_cuda, is_python_module, is_standalone, keep_intermediates)
1345 verbose=verbose,
1346 with_cuda=with_cuda,
-> 1347 is_standalone=is_standalone)
1348 finally:
1349 baton.release()

/usr/local/lib/python3.7/dist-packages/torch/utils/cpp_extension.py in _write_ninja_file_and_build_library(name, sources, extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths, build_directory, verbose, with_cuda, is_standalone)
1416 with_cuda: Optional[bool],
1417 is_standalone: bool = False) -> None:
-> 1418 verify_ninja_availability()
1419 if IS_WINDOWS:
1420 compiler = os.environ.get('CXX', 'cl')

/usr/local/lib/python3.7/dist-packages/torch/utils/cpp_extension.py in verify_ninja_availability()
1472 '''
1473 if not is_ninja_available():
-> 1474 raise RuntimeError("Ninja is required to load C++ extensions")
1475
1476

RuntimeError: Ninja is required to load C++ extensions
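The last line points at a missing `ninja` build tool, which PyTorch's JIT extension loader uses to compile the custom StyleGAN2 CUDA kernels (the FastGAN generator does not need them, per the Requirements above). Installing it into the notebook's Python environment, e.g. with `pip install ninja`, is the usual fix. A quick pre-flight check, approximating what PyTorch tests for:

```python
import shutil
import sys


def ninja_available() -> bool:
    """True if a `ninja` executable is on PATH (PyTorch itself runs `ninja --version`)."""
    return shutil.which("ninja") is not None


if __name__ == "__main__":
    if ninja_available():
        print("ninja found:", shutil.which("ninja"))
    else:
        print(f"ninja not found; try: {sys.executable} -m pip install ninja")
```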

Confused about the image blur strategy used in the discriminator

This is a very nice paper!
I have some naive questions:

  1. What is the purpose of the image blur, and does it have to be added? If it is optional, how much does it affect the result?
  2. What is the function of "self.interp224"? Do other resolutions (e.g., 512) need to be interpolated to 224?

Looking forward to your reply!
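On question 1, assuming the loss follows the StyleGAN3 scheme (parameters `blur_init_sigma` and `blur_fade_kimg`), the blur is a warm-up that fades out linearly, so after the fade period images reach the discriminator unblurred. The sketch below reproduces that schedule; the example values are illustrative, not the repo's defaults.

```python
def blur_sigma(cur_nimg: int, blur_init_sigma: float, blur_fade_kimg: float) -> float:
    """Gaussian blur sigma for discriminator inputs after cur_nimg training images."""
    if blur_fade_kimg <= 0:
        return 0.0  # blur disabled
    remaining = max(1.0 - cur_nimg / (blur_fade_kimg * 1e3), 0.0)  # fraction of the fade left
    return remaining * blur_init_sigma


for nimg in (0, 100_000, 200_000, 300_000):
    print(nimg, blur_sigma(nimg, blur_init_sigma=2.0, blur_fade_kimg=200))
```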

environment.yml won't build with Anaconda

(base) PS C:\projected_gan> conda env create -f environment.yml
Collecting package metadata (repodata.json): done
Solving environment: failed

ResolvePackageNotFound:
  - cudnn==8.2.1.32=h86fa8c9_0
  - brotli==1.0.9=he6710b0_2
  - libgomp==11.2.0=h1d223b6_11
  - libpng==1.6.37=hbc83047_0
  - numpy-base==1.21.2=py39h79a1101_0
  - libblas==3.9.0=12_linux64_mkl
  - _libgcc_mutex==0.1=conda_forge
  - expat==2.4.1=h2531618_2
  - openjpeg==2.4.0=h3ad879b_0
  - libtiff==4.2.0=h85742a9_0
  - zstd==1.4.9=haebb681_0
  - libuv==1.40.0=h7b6447c_0
  - libgfortran4==7.5.0=ha8ba4b0_17
  - icu==58.2=he6710b0_3
  - psutil==5.8.0=py39h3811e60_1
  - protobuf==3.18.0=py39he80948d_0
  - tk==8.6.11=h1ccaba5_0
  - torchvision==0.10.1=py39cuda111hcd06603_0_cuda
  - pip==21.2.4=py39h06a4308_0
  - libwebp-base==1.2.0=h27cfd23_0
  - pyqt==5.9.2=py39h2531618_6
  - pcre==8.45=h295c915_0
  - mkl-service==2.4.0=py39h7f8727e_0
  - freetype==2.11.0=h70c0345_0
  - libxml2==2.9.12=h03d6c58_0
  - setuptools==58.0.4=py39h06a4308_0
  - python==3.9.7=h12debd9_1
  - cryptography==35.0.0=py39hd23ed53_0
  - fontconfig==2.13.1=h6c09931_0
  - ninja==1.10.2=py39hd09550d_3
  - cycler==0.10.0=py39h06a4308_0
  - importlib-metadata==4.8.2=py39hf3d152e_0
  - libuuid==1.0.3=h7f8727e_2
  - magma==2.5.4=ha9b7cf9_2
  - brotlipy==0.7.0=py39h27cfd23_1003
  - aiohttp==3.7.0=py39h07f9747_0
  - nccl==2.11.4.1=h97a9cb7_0
  - intel-openmp==2021.4.0=h06a4308_3561
  - scipy==1.7.1=py39h292c36d_2
  - _openmp_mutex==4.5=1_gnu
  - libffi==3.3=he6710b0_2
  - sleef==3.5.1=h9b69904_2
  - multidict==5.2.0=py39h3811e60_1
  - certifi==2021.10.8=py39hf3d152e_1
  - zlib==1.2.11=h7b6447c_3
  - numpy==1.21.2=py39h20f2e39_0
  - ca-certificates==2021.10.8=ha878542_0
  - chardet==3.0.4=py39h079e4ff_1008
  - libprotobuf==3.18.0=h780b84a_1
  - mkl_random==1.2.2=py39h51133e4_0
  - gst-plugins-base==1.14.0=h8213a91_2
  - future==0.18.2=py39hf3d152e_4
  - glib==2.69.1=h5202010_0
  - libstdcxx-ng==11.2.0=he4da1e4_11
  - pysocks==1.7.1=py39h06a4308_0
  - liblapack==3.9.0=12_linux64_mkl
  - tensorboard-data-server==0.6.0=py39h95dcef6_1
  - readline==8.1=h27cfd23_0
  - grpcio==1.38.1=py39hff7568b_0
  - libxcb==1.14=h7b6447c_0
  - kiwisolver==1.3.1=py39h2531618_0
  - qt==5.9.7=h5867ecd_1
  - yarl==1.7.2=py39h3811e60_1
  - ld_impl_linux-64==2.35.1=h7274673_9
  - pytorch-gpu==1.9.1=cuda111py39h788eb59_3
  - mkl==2021.4.0=h06a4308_640
  - matplotlib==3.4.2=py39h06a4308_0
  - pytorch==1.9.1=cuda111py39hb4a4491_3
  - matplotlib-base==3.4.2=py39hab158f2_0
  - libgfortran-ng==7.5.0=ha8ba4b0_17
  - cudatoolkit==11.1.74=h6bb024c_0
  - lcms2==2.12=h3be6417_0
  - c-ares==1.18.1=h7f98852_0
  - dbus==1.13.18=hb2f20db_0
  - sip==4.19.13=py39h2531618_0
  - jpeg==9d=h7f8727e_0
  - openssl==1.1.1l=h7f98852_0
  - libgcc-ng==11.2.0=h1d223b6_11
  - cffi==1.14.6=py39h400218f_0
  - gstreamer==1.14.0=h28cd5cc_2
  - mkl_fft==1.3.1=py39hd3c417c_0
  - sqlite==3.36.0=hc218d9a_0
  - xz==5.2.5=h7b6447c_0
  - tornado==6.1=py39h27cfd23_0
  - lz4-c==1.9.3=h295c915_1
  - ncurses==6.3=h7f8727e_2
  - pillow==8.3.1=py39h2c7a002_0

This is using Anaconda 4.10, which I believe is the latest version. There shouldn't be any difference between Anaconda and Miniconda, should there?

Maybe these packages don't exist for Windows.
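The pinned specs in this log carry Linux build strings (e.g. `12_linux64_mkl`), and several packages are Linux-only (`ld_impl_linux-64`, `libgcc-ng`, `libstdcxx-ng`, `_libgcc_mutex`), so the file likely cannot be solved on Windows as-is. One common workaround, sketched here, is to strip the build strings so conda can pick platform-appropriate builds. The Linux-only entries would still need to be removed by hand, and the remaining versions are not guaranteed to resolve.

```python
import re
from typing import List

# Matches "  - name=1.2.3=build" and "  - name==1.2.3=build" (the form conda prints).
_PINNED = re.compile(r"^(\s*-\s*)([A-Za-z0-9_.\-]+)(==?)([^=\s]+)=\S+\s*$")


def strip_build_strings(lines: List[str]) -> List[str]:
    """Drop the trailing '=build' from pinned conda specs; leave all other lines unchanged."""
    out: List[str] = []
    for line in lines:
        m = _PINNED.match(line)
        out.append(f"{m.group(1)}{m.group(2)}{m.group(3)}{m.group(4)}" if m else line.rstrip("\n"))
    return out


# Usage (writes a new file next to the original):
# with open("environment.yml") as f:
#     cleaned = strip_build_strings(f.readlines())
# with open("environment-nobuilds.yml", "w") as f:
#     f.write("\n".join(cleaned) + "\n")
```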

gen_images.py & gen_video.py issues

Hi, I'm getting the following errors when trying to generate images or videos. Does anyone have any ideas?

python gen_images.py --outdir=out --trunc=1.0 --seeds=4096-4105 --network=training-runs/00000-fastgan-faces_512-gpus1-batch64/network-snapshot.pkl

Results in:

Loading networks from "training-runs/00000-fastgan-faces_512-gpus1-batch64/network-snapshot.pkl"...
Generating image for seed 4096 (0/10) ...
Traceback (most recent call last):
  File "/home/nerdy/github/projected_gan/gen_images.py", line 143, in <module>
    generate_images() # pylint: disable=no-value-for-parameter
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/click/core.py", line 1128, in __call__
    return self.main(*args, **kwargs)
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/click/core.py", line 1053, in main
    rv = self.invoke(ctx)
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/click/core.py", line 1395, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/click/core.py", line 754, in invoke
    return __callback(*args, **kwargs)
  File "/home/nerdy/github/projected_gan/gen_images.py", line 135, in generate_images
    img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nerdy/github/projected_gan/pg_modules/networks_fastgan.py", line 173, in forward
    img = self.synthesis(w, c)
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nerdy/github/projected_gan/pg_modules/networks_fastgan.py", line 60, in forward
    feat_4 = self.init(input)
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nerdy/github/projected_gan/pg_modules/blocks.py", line 65, in forward
    return self.init(noise)
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1071, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 916, in forward
    return F.conv_transpose2d(
RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same
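The message says the input tensor is float64 (`DoubleTensor`) while the weights are float32. A likely cause, assuming `gen_images.py` builds `z` the way StyleGAN3's script does (`torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim))`): NumPy's `randn` returns float64, and StyleGAN's mapping network casts `z` to float32 internally, whereas the traceback shows the FastGAN generator passing its input on to a transposed convolution. Casting the latent first should avoid the mismatch; in the script that would be `torch.from_numpy(...).float()` or `.to(torch.float32)`.

```python
import numpy as np


def make_latent(seed: int, z_dim: int) -> np.ndarray:
    """Seeded latent vector of shape (1, z_dim) in float32."""
    z = np.random.RandomState(seed).randn(1, z_dim)  # float64 by default
    return z.astype(np.float32)  # match the generator's float32 weights


z = make_latent(4096, z_dim=256)  # z_dim=256 is illustrative; use G.z_dim
print(z.dtype, z.shape)
```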

With gen_video:

python gen_video.py --output=vid.mp4 --seeds=4096-4105 --network=training-runs/00000-fastgan-faces_512-gpus1-batch64/network-snapshot.pkl

Traceback (most recent call last):
  File "/home/nerdy/github/projected_gan/gen_video.py", line 190, in <module>
    generate_images() # pylint: disable=no-value-for-parameter
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/click/core.py", line 1128, in __call__
    return self.main(*args, **kwargs)
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/click/core.py", line 1053, in main
    rv = self.invoke(ctx)
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/click/core.py", line 1395, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/click/core.py", line 754, in invoke
    return __callback(*args, **kwargs)
  File "/home/nerdy/github/projected_gan/gen_video.py", line 185, in generate_images
    gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi, class_idx=class_idx)
  File "/home/nerdy/github/projected_gan/gen_video.py", line 74, in gen_interp_video
    ws = G.mapping(z=zs, c=label, truncation_psi=psi)
  File "/home/nerdy/anaconda3/envs/pg/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'truncation_psi'
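Here `gen_video.py` passes `truncation_psi` to `G.mapping`, which the FastGAN mapping evidently does not accept. The simplest local fix is to drop that keyword when using FastGAN, which means truncation is not applied. A more generic sketch filters keywords by signature. It is a hypothetical helper, and for an `nn.Module` you would inspect its `forward` method (e.g. `G.mapping.forward`), since the module object itself exposes only `*input, **kwargs`.

```python
import inspect
from typing import Any, Callable


def call_with_supported_kwargs(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call fn, dropping keyword arguments that its signature does not accept."""
    params = list(inspect.signature(fn).parameters.values())
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return fn(*args, **kwargs)  # fn takes **kwargs, so pass everything through
    accepted = {p.name for p in params}
    return fn(*args, **{k: v for k, v in kwargs.items() if k in accepted})


# Two hypothetical mapping signatures:
def mapping_fastgan(z, c):
    return ("fastgan", z, c)


def mapping_stylegan2(z, c, truncation_psi=1.0):
    return ("stylegan2", z, c, truncation_psi)


print(call_with_supported_kwargs(mapping_fastgan, z=1, c=0, truncation_psi=0.7))    # ('fastgan', 1, 0)
print(call_with_supported_kwargs(mapping_stylegan2, z=1, c=0, truncation_psi=0.7))  # ('stylegan2', 1, 0, 0.7)
```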

CSM uses 1x1 conv, or 3x3?

Hi, thanks for sharing your great work!

I found that the current code uses a 1x1 conv in FeatureFusionBlock (which is used for the CSM), which is a bit confusing to me, as the paper states that the CSM uses a 3x3 conv. Could you clarify this?

Network snapshot saving in Colab not working

Hey guys!

Last night I got to try out your colab version of the project and started training with my own dataset. The results are amazing and by far outperform my results with stylegan2 (at a fraction of training time) so kudos to you!
However, while the network properly saves sample images to my Google Drive folder every 10 ticks, it does not save any network pickles, and I can't figure out why.

Do you have any hints for me?

Colab not working on

Hey,
I'm trying to reproduce your results right now and am experimenting with your notebook. You use conda to manage the dependencies, but the notebook doesn't reflect this; installing timm and dill alone doesn't seem to be enough.
Colab fails with missing imports.

Are there any plans to fully support Google Colab (or a way to use the environment.yml to install the dependencies)?

network pkls should be versioned

It took me a while to figure out that the network pkl is overwritten. I think it should be versioned (one checkpoint per snapshot), as in other implementations. Especially since GANs are notoriously hard to train, you never know for sure which checkpoint is the best one.
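Until versioned checkpoints exist, one workaround is to copy the snapshot to a kimg-stamped name after each save, following StyleGAN's `network-snapshot-{kimg:06d}.pkl` naming. The sketch assumes a single `network-snapshot.pkl` in the run directory that gets overwritten, as described above. How you trigger it (a watcher script, or a call added after the training loop's save step) is up to you, and the function names are hypothetical.

```python
import shutil
from pathlib import Path


def versioned_name(name: str, kimg: int) -> str:
    """'network-snapshot.pkl', 200 -> 'network-snapshot-000200.pkl'."""
    p = Path(name)
    return f"{p.stem}-{kimg:06d}{p.suffix}"


def archive_snapshot(run_dir: str, kimg: int, name: str = "network-snapshot.pkl") -> Path:
    """Copy the current (overwritten) snapshot to a kimg-stamped file in the same directory."""
    src = Path(run_dir) / name
    dst = src.with_name(versioned_name(name, kimg))
    shutil.copy2(src, dst)  # copy2 also preserves the file's modification time
    return dst
```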

pretrained models

Hi,

I am trying to generate samples with this method and reproduce the FID numbers reported in the paper.
I was wondering if you could provide the pre-trained models that you use in the paper?

Regards,
Gaurav

Huggingface Spaces

Hello, would you be interested in sharing a web demo on Huggingface Spaces for Projected GAN?

It would make this model more accessible as it would allow people to try out the model directly from the browser. Some other recent machine learning model repos have set up Spaces for easy access:

github: https://github.com/salesforce/BLIP
Spaces: https://huggingface.co/spaces/akhaliq/BLIP

github: https://github.com/facebookresearch/omnivore
Spaces: https://huggingface.co/spaces/akhaliq/omnivore

Spaces is completely free, and I can help set up a Gradio Space. Here are some getting-started instructions if you'd prefer to do it yourself: https://huggingface.co/blog/gradio-spaces

Config for 11GB GPU

Hi. This is one of the projects I've been most looking forward to recently, and thanks for finally open-sourcing it. It seems the current training config is made for 16GB GPUs, so I ran into OOM errors on 11GB ones. Could you please suggest a config that works on smaller GPUs (11GB in my case)? Thanks a lot!

Issue replicating published numbers

Hi,

I'm having some trouble replicating the published numbers on the Pokemon set. My runs converge to an FID of about 29. While some variance from published numbers is expected, this is quite a large difference. I'm not sure why this is happening, as I've not changed the code or training parameters. Could you kindly provide some insights?

Thanks

MultiGPU errors

Thanks for the great repo, and congrats on the amazing paper!
Training works fine for me with 1 GPU, but I'm getting mysterious errors with more than 1 GPU, and I'm not sure how to fix them. I'm using your recommended conda environment with CUDA 11.3. When running the Pokemon experiment with 2 GPUs, I get the following stdout:

Distributing across 2 GPUs...
Setting up training phases...
Exporting sample images...
Traceback (most recent call last):
  File "/home/ed/Documents/repos/projected_gan/train.py", line 266, in <module>
    main() # pylint: disable=no-value-for-parameter
  File "/home/ed/.conda/envs/pg/lib/python3.9/site-packages/click/core.py", line 1137, in __call__
    return self.main(*args, **kwargs)
  File "/home/ed/.conda/envs/pg/lib/python3.9/site-packages/click/core.py", line 1062, in main
    rv = self.invoke(ctx)
  File "/home/ed/.conda/envs/pg/lib/python3.9/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/ed/.conda/envs/pg/lib/python3.9/site-packages/click/core.py", line 763, in invoke
    return __callback(*args, **kwargs)
  File "/home/ed/Documents/repos/projected_gan/train.py", line 252, in main
    launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run)
  File "/home/ed/Documents/repos/projected_gan/train.py", line 103, in launch_training
    torch.multiprocessing.spawn(fn=subprocess_fn, args=(c, temp_dir), nprocs=c.num_gpus)
  File "/home/ed/.conda/envs/pg/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/ed/.conda/envs/pg/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/home/ed/.conda/envs/pg/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 130, in join
    raise ProcessExitedException(
torch.multiprocessing.spawn.ProcessExitedException: process 1 terminated with signal SIGABRT

And the following stderr:

terminate called after throwing an instance of 'c10::CUDAError'
  what():  CUDA error: the launch timed out and was terminated
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Exception raised from create_event_internal at ../c10/cuda/CUDACachingAllocator.cpp:1211 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7fe48196fd62 in /home/ed/.conda/envs/pg/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x1c4d3 (0x7fe481bd24d3 in /home/ed/.conda/envs/pg/lib/python3.9/site-packages/torch/lib/libc10_cuda.so)
frame #2: c10::cuda::CUDACachingAllocator::raw_delete(void*) + 0x1a2 (0x7fe481bd2ee2 in /home/ed/.conda/envs/pg/lib/python3.9/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10::TensorImpl::release_resources() + 0xa4 (0x7fe481959314 in /home/ed/.conda/envs/pg/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #4: std::vector<at::Tensor, std::allocator<at::Tensor> >::~vector() + 0x4a (0x7fe4de77a80a in /home/ed/.conda/envs/pg/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #5: c10d::ProcessGroupNCCL::barrier(c10d::BarrierOptions const&) + 0x9c6 (0x7fe4850be296 in /home/ed/.conda/envs/pg/lib/python3.9/site-packages/torch/lib/libtorch_cuda.so)
frame #6: <unknown function> + 0xe86443 (0x7fe4def1a443 in /home/ed/.conda/envs/pg/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #7: <unknown function> + 0x2a544b (0x7fe4de33944b in /home/ed/.conda/envs/pg/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x15ff85 (0x55cf0df03f85 in /home/ed/.conda/envs/pg/bin/python)
frame #9: _PyObject_MakeTpCall + 0x316 (0x55cf0deea986 in /home/ed/.conda/envs/pg/bin/python)
frame #10: <unknown function> + 0x1a225a (0x55cf0df4625a in /home/ed/.conda/envs/pg/bin/python)
frame #11: _PyEval_EvalFrameDefault + 0x11e1 (0x55cf0df85141 in /home/ed/.conda/envs/pg/bin/python)
frame #12: <unknown function> + 0x259205 (0x55cf0dffd205 in /home/ed/.conda/envs/pg/bin/python)
frame #13: _PyEval_EvalFrameDefault + 0x4c92 (0x55cf0df88bf2 in /home/ed/.conda/envs/pg/bin/python)
frame #14: <unknown function> + 0x139c70 (0x55cf0deddc70 in /home/ed/.conda/envs/pg/bin/python)
frame #15: _PyFunction_Vectorcall + 0x336 (0x55cf0df45596 in /home/ed/.conda/envs/pg/bin/python)
frame #16: _PyObject_Call + 0xb5 (0x55cf0df04335 in /home/ed/.conda/envs/pg/bin/python)
frame #17: _PyEval_EvalFrameDefault + 0x2d69 (0x55cf0df86cc9 in /home/ed/.conda/envs/pg/bin/python)
frame #18: _PyFunction_Vectorcall + 0x19a (0x55cf0df453fa in /home/ed/.conda/envs/pg/bin/python)
frame #19: _PyObject_Call + 0x10b (0x55cf0df0438b in /home/ed/.conda/envs/pg/bin/python)
frame #20: _PyEval_EvalFrameDefault + 0x2d69 (0x55cf0df86cc9 in /home/ed/.conda/envs/pg/bin/python)
frame #21: _PyFunction_Vectorcall + 0x19a (0x55cf0df453fa in /home/ed/.conda/envs/pg/bin/python)
frame #22: _PyObject_Call + 0x10b (0x55cf0df0438b in /home/ed/.conda/envs/pg/bin/python)
frame #23: _PyEval_EvalFrameDefault + 0x2d69 (0x55cf0df86cc9 in /home/ed/.conda/envs/pg/bin/python)
frame #24: _PyFunction_Vectorcall + 0x19a (0x55cf0df453fa in /home/ed/.conda/envs/pg/bin/python)
frame #25: _PyEval_EvalFrameDefault + 0x609 (0x55cf0df84569 in /home/ed/.conda/envs/pg/bin/python)
frame #26: <unknown function> + 0x139430 (0x55cf0dedd430 in /home/ed/.conda/envs/pg/bin/python)
frame #27: _PyFunction_Vectorcall + 0x336 (0x55cf0df45596 in /home/ed/.conda/envs/pg/bin/python)
frame #28: _PyEval_EvalFrameDefault + 0x609 (0x55cf0df84569 in /home/ed/.conda/envs/pg/bin/python)
frame #29: _PyFunction_Vectorcall + 0x19a (0x55cf0df453fa in /home/ed/.conda/envs/pg/bin/python)
frame #30: _PyEval_EvalFrameDefault + 0x3bc (0x55cf0df8431c in /home/ed/.conda/envs/pg/bin/python)
frame #31: <unknown function> + 0x139430 (0x55cf0dedd430 in /home/ed/.conda/envs/pg/bin/python)
frame #32: _PyFunction_Vectorcall + 0x336 (0x55cf0df45596 in /home/ed/.conda/envs/pg/bin/python)
frame #33: _PyEval_EvalFrameDefault + 0x11e1 (0x55cf0df85141 in /home/ed/.conda/envs/pg/bin/python)
frame #34: <unknown function> + 0x139430 (0x55cf0dedd430 in /home/ed/.conda/envs/pg/bin/python)
frame #35: _PyEval_EvalCodeWithName + 0x47 (0x55cf0dfc3517 in /home/ed/.conda/envs/pg/bin/python)
frame #36: PyEval_EvalCodeEx + 0x39 (0x55cf0dfc3559 in /home/ed/.conda/envs/pg/bin/python)
frame #37: PyEval_EvalCode + 0x1b (0x55cf0dfc357b in /home/ed/.conda/envs/pg/bin/python)
frame #38: <unknown function> + 0x251ec9 (0x55cf0dff5ec9 in /home/ed/.conda/envs/pg/bin/python)
frame #39: <unknown function> + 0x28cc04 (0x55cf0e030c04 in /home/ed/.conda/envs/pg/bin/python)
frame #40: PyRun_StringFlags + 0x9d (0x55cf0e03702d in /home/ed/.conda/envs/pg/bin/python)
frame #41: PyRun_SimpleStringFlags + 0x3d (0x55cf0e03708d in /home/ed/.conda/envs/pg/bin/python)
frame #42: Py_RunMain + 0x25c (0x55cf0e03730c in /home/ed/.conda/envs/pg/bin/python)
frame #43: Py_BytesMain + 0x39 (0x55cf0e037599 in /home/ed/.conda/envs/pg/bin/python)
frame #44: __libc_start_main + 0xd5 (0x7fe4e0fceb25 in /usr/lib/libc.so.6)
frame #45: <unknown function> + 0x20c6b1 (0x55cf0dfb06b1 in /home/ed/.conda/envs/pg/bin/python)

/home/ed/.conda/envs/pg/lib/python3.9/multiprocessing/resource_tracker.py:216: UserWarning: resource_tracker: There appear to be 34 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

Any ideas on how to fix this?
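The error text itself recommends CUDA_LAUNCH_BLOCKING=1 to get an accurate stack trace, and NCCL_DEBUG=INFO is NCCL's standard switch for logging what its collectives (here the barrier inside ProcessGroupNCCL::barrier) are doing. A minimal sketch that relaunches a run with both variables set; the command line and its values are illustrative, and the helper names are hypothetical.

```python
import os
import shlex
import subprocess


def debug_env(base=None):
    """Copy of the environment with extra CUDA/NCCL diagnostics switched on."""
    env = dict(os.environ if base is None else base)
    env["CUDA_LAUNCH_BLOCKING"] = "1"  # synchronous kernel launches, as the error message suggests
    env["NCCL_DEBUG"] = "INFO"         # NCCL logs its setup and collectives to stderr
    return env


# Illustrative command; adjust paths and values to the actual run.
CMD = [
    "python", "train.py", "--outdir=./training-runs/", "--cfg=fastgan",
    "--data=./data/pokemon256.zip", "--gpus=2", "--batch=64",
]


def launch(cmd=CMD):
    """Run training with the debug variables set; spawned workers inherit the env."""
    return subprocess.run(cmd, env=debug_env(), check=True)


if __name__ == "__main__":
    print(shlex.join(CMD))  # shows the command that launch() would run
```

Setting the variables in the parent's environment, rather than inside train.py, ensures they are in place before CUDA is initialized in any worker process.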

Got "axes don't match array" after saving the model

Hi,
I get the following issue while training the model on my dataset (just after the first trained model has been saved):

ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/content/projected_gan/training/dataset.py", line 102, in __getitem__
    image = self._load_raw_image(self._raw_idx[idx])
  File "/content/projected_gan/training/dataset.py", line 235, in _load_raw_image
    image = image.transpose(2, 0, 1) # HWC => CHW
ValueError: axes don't match array

I first zipped my JPEG images with the dataset tool:

python dataset_tool.py --source=xxx --dest=yyyy.zip --resolution=256x256

Note that I'm using macOS, so this may affect the way the images are zipped...

I'm using the following command line:

train(outdir='training-runs', cfg='fastgan', data='xxxx', gpus=1, batch=64, cond=False, mirror=1, batch_gpu=8, cbase=32768, cmax=512, glr=None, dlr=0.002, desc='', metrics=['fid50k_full'], kimg=10000, tick=4, snap=1, seed=0, workers=0)

Thanks!
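`transpose(2, 0, 1)` only works on arrays with exactly three axes, so this ValueError means at least one decoded image was not H×W×C. For example, a grayscale or palette image decoded with `np.array(PIL.Image.open(...))` comes back as a 2-D H×W array. A defensive sketch of the HWC => CHW step follows; `to_chw` is a hypothetical helper, not the repo's loader, which may already handle some of these cases.

```python
import numpy as np


def to_chw(image: np.ndarray) -> np.ndarray:
    """Convert an HW or HWC image array to CHW, failing with a readable message."""
    if image.ndim == 2:
        image = image[:, :, np.newaxis]  # HW => HWC (single channel)
    if image.ndim != 3:
        raise ValueError(f"expected an HW or HWC image, got shape {image.shape}")
    return image.transpose(2, 0, 1)  # HWC => CHW


if __name__ == "__main__":
    print(to_chw(np.zeros((256, 256), dtype=np.uint8)).shape)     # (1, 256, 256)
    print(to_chw(np.zeros((256, 256, 3), dtype=np.uint8)).shape)  # (3, 256, 256)
```

Running a check like this over every image in the dataset zip would point to the file(s) with an unexpected shape.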

When will the pretrained models be published?

Hi, great job from you guys!

I noticed that Projected GAN achieves Global Rank #1 on most datasets, but most of those datasets are 256×256. I wonder whether you will provide pretrained models on high-resolution datasets such as FFHQ 1024×1024, in both the FastGAN and StyleGAN2 versions.
It would be very useful for people like me with limited compute and data. :)

By the way, I also noticed that the code for saving the best model only handles the fid50k_full metric. 0.0

# save best fid ckpt
snapshot_pkl = os.path.join(run_dir, f'best_model.pkl')
cur_nimg_txt = os.path.join(run_dir, f'best_nimg.txt')
if rank == 0:
    if 'fid50k_full' in stats_metrics and stats_metrics['fid50k_full'] < best_fid:
        best_fid = stats_metrics['fid50k_full']

        with open(snapshot_pkl, 'wb') as f:
            dill.dump(snapshot_data, f)
        # save curr iteration number (directly saving it to pkl leads to problems with multi GPU)
        with open(cur_nimg_txt, 'w') as f:
            f.write(str(cur_nimg))
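The excerpt above hard-codes fid50k_full. One way to generalize it, sketched with hypothetical helper names and an assumed set of lower-is-better metrics, is to choose the metric by name and compare in the right direction; the `improved` branch would then wrap the same `dill.dump` and `best_nimg.txt` writes.

```python
import math

# Assumption: distance-style metrics such as FID and KID are "lower is better";
# any metric not listed here is treated as "higher is better".
LOWER_IS_BETTER = {"fid50k_full", "kid50k_full"}


def initial_best(metric: str) -> float:
    """Starting value so the first evaluation always counts as an improvement."""
    return math.inf if metric in LOWER_IS_BETTER else -math.inf


def update_best(stats_metrics: dict, best: float, metric: str = "fid50k_full"):
    """Return (new_best, improved) for the chosen metric (hypothetical helper)."""
    if metric not in stats_metrics:
        return best, False  # metric was not computed at this tick
    value = stats_metrics[metric]
    improved = value < best if metric in LOWER_IS_BETTER else value > best
    return (value if improved else best), improved
```

Keeping the direction in one table avoids silently saving the worst checkpoint when switching to a metric where larger values are better.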

Questions about "ngf"

Hello!

It was mentioned in a previous post that one could decrease memory use by:

"[reducing] the size of the FastGAN generator, e.g., set ngf=64"

I understand this decreases the size of the network layers, but will it have a noticeable effect on mid-resolution datasets (512p)? I am trying it regardless and will update with results; I was just wondering whether anyone else has tried.
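As rough arithmetic for what ngf controls: if every layer's channel count is proportional to ngf, halving ngf roughly quarters the convolution weights (they scale with c_in × c_out) and roughly halves per-layer activation memory (which scales with the channel count). The multiplier table below is an illustrative assumption, not copied from networks_fastgan.py, and whether the reduced capacity visibly hurts quality at 512×512 is an empirical question this arithmetic cannot answer.

```python
def conv_weights(c_in: int, c_out: int, k: int = 3) -> int:
    """Weight count of a k x k convolution (bias ignored)."""
    return c_in * c_out * k * k


# Hypothetical FastGAN-style multipliers (channels = multiplier * ngf) per resolution;
# check networks_fastgan.py for the real table.
MULT = {4: 16, 8: 8, 16: 4, 32: 2, 64: 2, 128: 1, 256: 0.5, 512: 0.25}


def channels(ngf: int) -> dict:
    """Channel count at each resolution for a given ngf."""
    return {res: int(m * ngf) for res, m in MULT.items()}


def total_conv_weights(ngf: int) -> int:
    """Weights of one same-width 3x3 conv per resolution, as a crude size proxy."""
    return sum(conv_weights(c, c) for c in channels(ngf).values())


if __name__ == "__main__":
    for ngf in (128, 64):
        print(ngf, channels(ngf)[512], total_conv_weights(ngf))
```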

local variable 'feat_last' referenced before assignment

Both locally and in Paperspace, training fails after "Constructing networks..." with this error:

Traceback (most recent call last):
  File "/storage/projected_gan/train.py", line 266, in <module>
    main() # pylint: disable=no-value-for-parameter
  File "/opt/conda/envs/pg/lib/python3.9/site-packages/click/core.py", line 1128, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/envs/pg/lib/python3.9/site-packages/click/core.py", line 1053, in main
    rv = self.invoke(ctx)
  File "/opt/conda/envs/pg/lib/python3.9/site-packages/click/core.py", line 1395, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/envs/pg/lib/python3.9/site-packages/click/core.py", line 754, in invoke
    return __callback(*args, **kwargs)
  File "/storage/projected_gan/train.py", line 252, in main
    launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run)
  File "/storage/projected_gan/train.py", line 101, in launch_training
    subprocess_fn(rank=0, c=c, temp_dir=temp_dir)
  File "/storage/projected_gan/train.py", line 47, in subprocess_fn
    training_loop.training_loop(rank=rank, **c)
  File "/storage/projected_gan/training/training_loop.py", line 188, in training_loop
    img = misc.print_module_summary(G, [z, c])
  File "/storage/projected_gan/torch_utils/misc.py", line 216, in print_module_summary
    outputs = module(*inputs)
  File "/opt/conda/envs/pg/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1071, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/storage/projected_gan/pg_modules/networks_fastgan.py", line 177, in forward
    img = self.synthesis(w, c)
  File "/opt/conda/envs/pg/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1071, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/storage/projected_gan/pg_modules/networks_fastgan.py", line 81, in forward
    return self.to_big(feat_last)
UnboundLocalError: local variable 'feat_last' referenced before assignment

These are my params:

--outdir=./training-runs/ --cfg=fastgan --data=pokemon.zip --batch=16 --mirror=1 --snap=50 --kimg=10000 --gpus 1
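The traceback ends in the FastGAN synthesis forward pass at `return self.to_big(feat_last)`, which suggests feat_last is only assigned inside resolution-dependent branches. A minimal reproduction of that Python pattern follows, assuming (not verified against networks_fastgan.py) that the branches start at 128 px; it shows why an unexpected dataset resolution would surface as UnboundLocalError rather than a clear message. Rebuilding the zip with an explicit `--resolution=256x256`, as in the README's dataset_tool.py example, is one thing worth ruling out. Both function names are hypothetical.

```python
# Hypothetical threshold: the real networks_fastgan.py may branch differently.
MIN_RESOLUTION = 128


def last_feature_unguarded(img_resolution: int) -> str:
    # Same shape as the failing code: feat_last is bound only inside the branches.
    if img_resolution >= 128:
        feat_last = "feat_128"
    if img_resolution >= 256:
        feat_last = "feat_256"
    return feat_last  # UnboundLocalError if neither branch ran


def last_feature_guarded(img_resolution: int) -> str:
    # Fail early with an actionable message instead of UnboundLocalError.
    if img_resolution < MIN_RESOLUTION:
        raise ValueError(
            f"img_resolution={img_resolution} is below the supported minimum "
            f"of {MIN_RESOLUTION}; rebuild the dataset at a supported resolution")
    return last_feature_unguarded(img_resolution)
```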

How to reproduce the numbers in the paper

Thanks for your nice work!

After running the official implementation on FFHQ256-full (cfg stylegan2, a total batch size of 64, 10000 kimg, EfficientNet-Lite0), the best FID I get is only 6.2, far from both the original StyleGAN2 result and the number in your paper. Is there something important that I missed?

Looking forward to your reply.

Using a pretrained model

Hi,
First, thanks for this very cool project.

I wanted to use a pretrained model (.pkl) that I got from awesome-pretrained-stylegan2,
and I got:
ImportError: cannot import name 'networks_stylegan2' from 'training' (unknown location)

Is there a solution, or should I train it myself?

Thanks in advance
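Network pickles record module paths for their classes, and Python's `pickle.Unpickler.find_class` hook is the standard place to redirect them; StyleGAN-family legacy.py loaders override the same hook. A generic sketch follows, with hypothetical class and function names. It only helps if the target module really defines compatible classes, and it may not be enough when the pickle embeds its own source code that performs the failing import.

```python
import pickle


class RemappingUnpickler(pickle.Unpickler):
    """Unpickler that redirects module paths recorded in a pickle."""

    def __init__(self, file, renames):
        super().__init__(file)
        self.renames = dict(renames)  # e.g. {"old.module": "new.module"}

    def find_class(self, module, name):
        # Called for every class/function reference stored in the pickle.
        module = self.renames.get(module, module)
        return super().find_class(module, name)


def load_remapped(file, renames):
    """Load a pickle from an open binary file, applying module renames."""
    return RemappingUnpickler(file, renames).load()
```

Usage would look like `with open("model.pkl", "rb") as f: data = load_remapped(f, {"training.networks_stylegan2": ...})`, where the right-hand module name is left open because it depends on which code base actually provides those classes.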

Something strange in training

Hi Axel,
Thanks for your beautiful repo.
I tried your project and added some modifications to it.
But I found something weird. With FastGAN, I can train the model with batch=6 on an 8GB 2080, but when I use Projected GAN from your code, I cannot run even with batch=1.
Could you give me some advice on reducing memory usage during training? (I am using my own model.)

Thank you so much!
Best regards,
Wanglong

KeyError: 'G' when doing a prediction

Hi,

I got an error when trying to generate an output.
Training seems fine, but when I try a prediction with python gen_images.py --outdir=out --trunc=1.0 --seeds=0-31 --network=my_model.pkl, I get:

Traceback (most recent call last):
  File "G:\Python\projected_gan\gen_images.py", line 143, in <module>
    generate_images() # pylint: disable=no-value-for-parameter
  File "G:\Anaconda3\envs\fast_gan\lib\site-packages\click\core.py", line 1128, in __call__
    return self.main(*args, **kwargs)
  File "G:\Anaconda3\envs\fast_gan\lib\site-packages\click\core.py", line 1053, in main
    rv = self.invoke(ctx)
  File "G:\Anaconda3\envs\fast_gan\lib\site-packages\click\core.py", line 1395, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "G:\Anaconda3\envs\fast_gan\lib\site-packages\click\core.py", line 754, in invoke
    return __callback(*args, **kwargs)
  File "G:\Python\projected_gan\gen_images.py", line 108, in generate_images
    G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
  File "G:\Python\projected_gan\legacy.py", line 40, in load_network_pkl
    assert isinstance(data['G'], torch.nn.Module)
KeyError: 'G'

I can't figure out where the issue is.
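The traceback shows legacy.load_network_pkl asserting on `data['G']`, and gen_images.py then reads `data['G_ema']`, so the pickle passed via `--network` has to be a dict containing at least those keys. A small diagnostic sketch with hypothetical helper names; it assumes it is run from the repo root so the classes stored in the pickle can be imported, and a checkpoint written with `dill` may additionally need `dill` installed.

```python
import pickle
import sys

REQUIRED_KEYS = ("G", "G_ema")  # the keys the traceback shows being accessed


def missing_network_keys(data):
    """Return the required keys that a loaded network pickle lacks."""
    if not isinstance(data, dict):
        return list(REQUIRED_KEYS)
    return [key for key in REQUIRED_KEYS if key not in data]


def describe_pickle(path):
    """Load a pickle and report its top-level type plus any missing keys."""
    with open(path, "rb") as f:
        data = pickle.load(f)  # only unpickle files you trust
    return type(data).__name__, missing_network_keys(data)


if __name__ == "__main__" and len(sys.argv) > 1:
    print(describe_pickle(sys.argv[1]))
```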

Conditional GANs

Hello,
Do you see any limitations in applying this to conditional GANs?
I am looking into style transfer using GANs. When this is released, please make it simple to apply to other areas. From skimming the paper, my understanding is that most of the power is in the discriminator; is that correct?

No improvement in speed of convergence

Hi,

I have been experimenting with SG2-ADA and SG3 on my floral datasets. I have trained reasonable models on them with out-of-the-box hyperparameter settings, reaching FIDs of 10-15. However, convergence isn't state of the art (possibly because I haven't trained for long, typically 1000-3000 kimg depending on whether I used transfer learning).

I came across the Projected GAN paper and find it very interesting. However, I haven't seen the convergence improvements described in the paper. Surprisingly, when training from scratch for about 1000-1500 kimg, convergence looks more or less the same as with SG2-ADA. I am wondering whether something about my dataset fundamentally makes the pretrained feature network in the discriminator ineffective. Is it possible that the feature space of these pretrained models is not optimal for my datasets? Is there a way to verify that?

I have also observed comparatively high memory usage with PG training: I could run a per-GPU batch size of 32 with SG2-ADA, but only 8 with PG.

More info
Dataset: different types of florals
Size: 4.5k-6k
Dimensions: 256x256
Infra: V100, Google Colab
Config: FASTGAN
Floral type 1: FID 17 after 14 hours of training (1400 kimg)
Floral type 2: FID 36 after 11 hours of training (1000 kimg)

For both datasets, I got similar FIDs with SG2-ADA trained for a similar number of kimg. Overall, there is probably a slight improvement in convergence speed (maybe 1.5x), but nothing significant.

I would be interested to know your thoughts on these observations.

Thanks.
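One way to make "converges faster" concrete for runs like these, sketched with hypothetical helper names, is to compare the kimg each configuration needs to reach the same FID, and to track throughput separately because PG and SG2-ADA do different amounts of work per image. From the numbers above, floral type 1 averaged 1400 kimg / 14 h ≈ 27.8 img/s and floral type 2 averaged 1000 kimg / 11 h ≈ 25.3 img/s.

```python
def images_per_second(kimg: float, hours: float) -> float:
    """Average throughput: training images (kimg x 1000) per wall-clock second."""
    return kimg * 1000 / (hours * 3600)


def kimg_to_reach(curve, target_fid):
    """First kimg at which FID <= target_fid.

    `curve` is a list of (kimg, fid) pairs in training order, e.g. parsed from
    the metric logs; returns None if the target was never reached.
    """
    for kimg, fid in curve:
        if fid <= target_fid:
            return kimg
    return None


def convergence_speedup(baseline_curve, new_curve, target_fid):
    """Ratio of kimg needed by the baseline vs. the new run (>1 means faster)."""
    base = kimg_to_reach(baseline_curve, target_fid)
    new = kimg_to_reach(new_curve, target_fid)
    if base is None or new is None:
        return None
    return base / new
```

Measured per kimg, a speedup can look smaller in wall-clock time if each image costs more to process, so both views are worth reporting.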

Questions about the discriminator

Hi. Thanks for your amazing work!

Why did you choose a simple convolutional architecture for the discriminators? Is it for convergence speed? Would the model perform better with stronger discriminators?

How to use SingleDiscCond?

I noticed that there is a SingleDiscCond class in your code. Can it be used as a discriminator for classification, and if so, how should a loss be added to it?
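SingleDiscCond's internals are not shown in this thread. As a generic sketch of projection-style conditioning, one common way a conditional discriminator combines its features with a class embedding (whether SingleDiscCond works this way is an assumption to check in the source), the output is one realness logit per sample rather than a vector of class probabilities. The function name and shapes are hypothetical.

```python
import numpy as np


def projection_logit(features, labels_onehot, embed_matrix):
    """Class-conditional realness logit via an inner product with a class embedding.

    features:      (N, D) discriminator features per sample
    labels_onehot: (N, C) one-hot class labels
    embed_matrix:  (C, D) learned class embeddings
    """
    cmap = labels_onehot @ embed_matrix  # (N, D) embedding of each sample's class
    d = features.shape[1]
    # One scalar per sample, normalized by sqrt(D) to keep the scale stable.
    return (features * cmap).sum(axis=1, keepdims=True) / np.sqrt(d)
```

Because the label is an input rather than a prediction, using such a discriminator for classification would need a separate classifier head with its own loss, which is a different design from the one sketched here.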

Information regarding FastGAN-Lite

Do you plan on releasing any supplementary material discussing the efficacy of the FastGAN-Lite configuration? In particular, I'm curious below what dataset size (e.g., <5k, <10k images) FastGAN-Lite outperforms FastGAN, and at which batch sizes FastGAN-Lite performs best. Thank you!

Conditional FastGAN

Has the conditional version of FastGAN been tested, and should it give good results? I don't think the paper included any conditional experiments.
In my initial experiments, it fails to learn anything meaningful. To debug, I also tried the conditional architecture with just 1 class (so it should perform similarly to the unconditional one), but so far that also seems much worse.

With conditional training, I usually get
RuntimeWarning: overflow encountered in multiply
img = (img - lo) * (255 / (hi - lo))

which might suggest some large values in the generator's output.
I'm now trying to apply tanh to the output in the conditional case (which is what BigGAN does; BigGAN also uses conditional batch norm). So far, the 1-class learning curves look more similar to the unconditional case, but I have only just started the run.
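The warning fires when `(img - lo) * (255 / (hi - lo))` produces values beyond the float range, which points to generator outputs that are huge or already non-finite. A sketch of a safer export step, assuming images are meant to lie in [-1, 1] (an assumption about the range the export code maps from): it rejects non-finite values with a clear error and can optionally squash with tanh, as the poster is trying. The function name is hypothetical.

```python
import numpy as np


def to_uint8(img, lo=-1.0, hi=1.0, bound_with_tanh=False):
    """Map generator output from [lo, hi] to uint8 pixels, guarding against overflow."""
    img = np.asarray(img, dtype=np.float32)
    if bound_with_tanh:
        img = np.tanh(img)  # squashes any finite or infinite value into [-1, 1]
    if not np.all(np.isfinite(img)):
        raise ValueError("generator output contains inf/NaN values")
    img = (img - lo) * (255 / (hi - lo))
    return np.rint(img).clip(0, 255).astype(np.uint8)
```

Raising on non-finite values surfaces a diverging generator immediately, instead of a RuntimeWarning followed by a silently saturated image grid.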
