This implementation is adapted from the stylegan2 codebase by Matthias Wright.
Specifically, the features we've added allow for better scaling of StyleGAN2 training on TPUs:
- Enable data-parallel training on TPU pods (tested on TPU v2 to v4 generations)
- Google Cloud Storage (GCS) integration/dataset sharding between workers
- Quality-of-life improvements (e.g. improved W&B logging)
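Data-parallel training keeps a copy of the model on every device, feeds each device a different slice of the global batch, and averages the gradients before the update. The NumPy sketch below uses a toy linear model with a mean-squared-error loss (not the StyleGAN2 losses) to check the property that makes this valid: with equal-sized shards, the mean of the per-device gradients equals the full-batch gradient. In JAX this averaging step is typically written with `jax.lax.pmean` inside a `jax.pmap`-ed train step; whether this repository uses exactly that mechanism is an assumption here.

```python
import numpy as np

def mse_grad(w, x, y):
    """Gradient of 0.5 * mean((x @ w - y) ** 2) with respect to w (toy model)."""
    residual = x @ w - y
    return x.T @ residual / len(y)

def data_parallel_grad(w, x, y, num_devices):
    """Split the global batch evenly across devices, compute one gradient per
    shard, then average them (the all-reduce step of data-parallel training)."""
    x_shards = np.split(x, num_devices)  # requires len(x) % num_devices == 0
    y_shards = np.split(y, num_devices)
    per_device = [mse_grad(w, xs, ys) for xs, ys in zip(x_shards, y_shards)]
    return np.mean(per_device, axis=0)

rng = np.random.default_rng(0)
x = rng.normal(size=(32, 4))  # global batch of 32 examples, 4 features
y = rng.normal(size=32)
w = np.zeros(4)

full = mse_grad(w, x, y)
dp = data_parallel_grad(w, x, y, num_devices=8)  # e.g. 8 cores, 4 examples each
print(np.allclose(full, dp))
```

Averaging (rather than summing) keeps the update independent of the number of devices, so the same learning rate works whether the global batch is processed on one core or on many.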
This food does not exist! 🍪🍰🍣🍹
- Clone the repository:
git clone https://github.com/nyx-ai/stylegan2-flax-tpu.git
- Go into the directory:
cd stylegan2-flax-tpu
- Install JAX for your platform (CPU, GPU or TPU) following the official JAX installation instructions.
- Install requirements:
pip install -r requirements.txt
We released four 256x256 pretrained models: cookie, cheesecake, sushi and cocktail. Download them from the latest release, then generate images, for example with the cookie model:
python generate_images.py \
--checkpoint checkpoints/cookie-256.pkl \
--seeds 0 42 420 666 \
--truncation_psi 0.7 \
--out_path generated_images
Check the Colab notebook for more examples.
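The `--truncation_psi` flag refers to the StyleGAN truncation trick, w' = w_avg + psi * (w - w_avg): each intermediate latent w is pulled toward the average latent w_avg, trading sample diversity for fidelity. The NumPy sketch below only illustrates that formula; the seed-to-latent helper, the 512-dimensional latent size and the stand-in w_avg are assumptions for illustration, not this repository's code.

```python
import numpy as np

LATENT_DIM = 512  # assumption: the usual StyleGAN2 latent size

def latent_from_seed(seed, dim=LATENT_DIM):
    """Hypothetical helper: a reproducible Gaussian vector per integer seed.
    The repository may derive its latents differently (e.g. with jax.random)."""
    return np.random.RandomState(seed).randn(dim)

def truncate(w, w_avg, psi):
    """Truncation trick: interpolate from the average latent w_avg toward w.
    psi=1.0 leaves w unchanged; psi=0.0 collapses every sample to w_avg."""
    return w_avg + psi * (w - w_avg)

# Stand-ins: in StyleGAN2, w is the mapping network's output for z, and w_avg
# is a running average of mapping outputs tracked during training.
w_avg = np.full(LATENT_DIM, 0.1)
w = latent_from_seed(42)
w_trunc = truncate(w, w_avg, psi=0.7)

# The distance to the average latent shrinks by exactly the factor psi.
ratio = np.linalg.norm(w_trunc - w_avg) / np.linalg.norm(w - w_avg)
print(round(ratio, 6))
```

Lower values of psi give cleaner but more similar-looking samples; psi=1.0 disables truncation entirely.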
Put your images into a folder, e.g. /path/to/image_dir:
/path/to/image_dir/
0.jpg
1.jpg
2.jpg
4.jpg
...
and create a TFRecord dataset:
python dataset_utils/images_to_tfrecords.py --image_dir /path/to/image_dir/ --data_dir /path/to/tfrecord
For more detailed instructions, please refer to this README.
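Before running the conversion script, it can help to check what the image folder actually contains. A minimal standard-library sketch; `scan_image_dir` and the extension set are hypothetical, since the formats `images_to_tfrecords.py` accepts are not listed here.

```python
from pathlib import Path

# Assumption: adjust this set to the formats images_to_tfrecords.py accepts.
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}

def scan_image_dir(image_dir):
    """Split the files directly inside image_dir into likely images and
    everything else, both sorted by file name. Sub-directories are ignored."""
    images, others = [], []
    for path in sorted(Path(image_dir).iterdir()):
        if not path.is_file():
            continue  # skip sub-directories in this sketch
        if path.suffix.lower() in IMAGE_EXTS:
            images.append(path.name)
        else:
            others.append(path.name)
    return images, others
```

For example, `images, others = scan_image_dir("/path/to/image_dir")` followed by `print(len(images), others)` shows how many images will be converted and which stray files to remove first.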
The following command trains with the default settings, 128x128 resolution and a batch size of 8:
python main.py --data_dir /path/to/tfrecord
Read more about suitable training parameters here.
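On a multi-host TPU pod each worker should read a distinct part of the data (the dataset sharding feature listed at the top). A minimal plain-Python sketch of round-robin sharding follows; in a real run the worker count and index would typically come from `jax.process_count()` and `jax.process_index()`, and with tf.data the same rule is `dataset.shard(num_shards, index)`. How `main.py` actually splits the data is not described here, so `shard_for_worker` and the record names are hypothetical.

```python
def shard_for_worker(records, num_workers, worker_index):
    """Round-robin sharding: keep record i only if i % num_workers == worker_index.
    This is the rule tf.data.Dataset.shard(num_shards, index) applies."""
    if not 0 <= worker_index < num_workers:
        raise ValueError("worker_index must satisfy 0 <= worker_index < num_workers")
    return records[worker_index::num_workers]

# Hypothetical example: 10 records split across 4 workers.
records = [f"record-{i}" for i in range(10)]
shards = [shard_for_worker(records, 4, k) for k in range(4)]
for k, shard in enumerate(shards):
    print(k, shard)
```

Every record lands on exactly one worker, and shard sizes differ by at most one, so no worker reads duplicate data.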
- This work is based on Matthias Wright's stylegan2 implementation.
- The project received generous support from Google's TPU Research Cloud (TRC).
- The image datasets were built using the LAION-5B index.
- We are grateful to Weights & Biases for preserving our sanity.