danielroich / PTI
Official Implementation for "Pivotal Tuning for Latent-based editing of Real Images" (ACM TOG 2022) https://arxiv.org/abs/2106.05744
License: MIT License
Is this achievable? I want to take the result and feed it into https://github.com/l4rz/stylegan2-clip-approach
Or is this misguided, and would it destroy the heart and soul of this repo? When I run the projections from the PyTorch stylegan2-ada repo, the inversions are poor quality, not as good as this repo's...
amazing work
It would be amazing if you could add a script to interpolate between two faces.
Thank you for your great work @danielroich
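Not an official script, but here is a rough sketch of what such an interpolation could look like, assuming two pivot latents w_a and w_b (e.g. loaded from the saved embeddings) and a single tuned generator G; reusing one tuned generator for both faces is an approximation:

import torch

def interpolate_faces(G, w_a, w_b, steps=8):
    # Linearly interpolate between two pivot latents and render each frame
    # with the (tuned) generator.
    frames = []
    for t in torch.linspace(0.0, 1.0, steps):
        w = (1 - t) * w_a + t * w_b
        img = G.synthesis(w, noise_mode='const', force_fp32=True)  # values in [-1, 1]
        frames.append(img)
    return frames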
This repo is an implementation of StyleCLIP's global-directions text-based manipulation using Pivotal Tuning Inversion.
Hi, thanks for the impressive work, which has raised a great impact.
In the original paper, I wondered about the difference between SG, SG w+, and the first step of PTI.
Both SG and PTI optimize in the original W space, while SG w+ uses the W+ space.
SG and SG w+ take more steps to optimize, and all three methods employ noise regularization.
Is there any other difference that I missed?
Running the colab notebook throws an unpickling error at "Use PTI with e4e backbone for StyleCLIP"
---------------------------------------------------------------------------
UnpicklingError Traceback (most recent call last)
<ipython-input-38-29c3e7342ea3> in <module>()
1 hyperparameters.first_inv_type = 'w+'
2 os.chdir('/content/PTI')
----> 3 model_id = run_PTI(use_wandb=False, use_multi_id_training=False)
5 frames
/usr/local/lib/python3.7/dist-packages/torch/serialization.py in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
775 "functionality.")
776
--> 777 magic_number = pickle_module.load(f, **pickle_load_args)
778 if magic_number != MAGIC_NUMBER:
779 raise RuntimeError("Invalid magic number; corrupt file?")
UnpicklingError: invalid load key, '<'.
Is there any pretrained StyleGAN-ADA model related to car images?
When I download the car-images pretrained StyleGAN2 model from https://github.com/NVlabs/stylegan2 and run PTI, I get the error "ModuleNotFoundError: No module named 'dnnlib.tflib'".
How can I solve it?
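Not an official answer, but that error usually means the checkpoint is the TensorFlow pickle from NVlabs/stylegan2, while PTI expects the stylegan2-ada-pytorch format. The stylegan2-ada-pytorch repo ships legacy.py, which can read TF pickles; a sketch of converting and re-saving one (the file names are placeholders, and it must be run from inside that repo so legacy and dnnlib import):

# Sketch: convert a TensorFlow StyleGAN2 pickle into the PyTorch format PTI loads.
import pickle
import legacy  # from the stylegan2-ada-pytorch repo

src = 'stylegan2-car-config-f.pkl'   # hypothetical TF checkpoint
dst = 'stylegan2-car-pytorch.pkl'    # output for PTI's paths_config

with open(src, 'rb') as f:
    data = legacy.load_network_pkl(f)   # converts the TF networks into torch modules

with open(dst, 'wb') as f:
    pickle.dump(data, f)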
Hello,
I have a question. When we fine-tune the generator, do we need to save a corresponding model parameter for each image?
Hi Daniel. Thank you for the great work.
I have trained my StyleGAN2-ada model using sketch data, which generates sketches quite well.
After that, to manipulate real images, I tested PTI, but the quality was not good.
When I remove the LPIPS loss (using only the L2 loss) while running the pivotal tuning, the reconstruction goes well.
However, the manipulation is still not working very well.
Could you please provide any tips on this?
Should I train LPIPS on a different dataset, or is there any other loss function you can recommend?
Thanks for your great work. As for the inversion step in the paper, I want to know whether there is any initialization for W when searching for the pivot, like the W average or just a random vector?
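For reference (not the author's answer): in the stock w_projector, and in the modified copy pasted later on this page, the search starts from the mean W computed over w_avg_samples random draws unless an explicit initial_w is supplied; a condensed excerpt:

import numpy as np
import torch

# Condensed from the projector: estimate the mean W and use it as the starting point.
z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)          # [N, L, C]
w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)             # [N, 1, C]
w_avg = np.mean(w_samples, axis=0, keepdims=True)                            # [1, 1, C]
start_w = initial_w if initial_w is not None else w_avg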
Hello,
I find that the option lpips_type = 'alex' tends to affect my inversion results:
with lpips_type = 'alex', it introduces undesirable checkerboard/noise-like artifacts, which gives me a sense of overfitting (input left, reconstruction right);
with lpips_type = 'vgg', results are smoother, which gives me a sense of underfitting.
You may need to zoom in a bit to tell the difference.
Any suggestions in this case? Do I need to tune options like LPIPS_value_threshold = 0.06 to find a sweet spot in this trade-off?
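Not an authoritative recipe, but the knobs discussed above live in configs/hyperparameters.py; a sketch of the kind of settings one might experiment with (names other than the ones quoted in this thread, and all values, are illustrative assumptions, not recommendations):

# configs/hyperparameters.py (sketch; values are illustrative)
lpips_type = 'vgg'              # or 'alex'; trades the artifacts described above against smoothness
pt_l2_lambda = 1                # weight of the L2 term during pivotal tuning
pt_lpips_lambda = 1             # weight of the LPIPS term during pivotal tuning
LPIPS_value_threshold = 0.06    # stop tuning an image once its LPIPS drops below this
max_pti_steps = 350             # fewer steps can reduce overfitting-style artifacts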
Is it possible to do inference on CPU in Colab?
Hi, I have run the file run_PTI but get the error KeyError: 'FullyConnectedLayer'.
Details:
KeyError Traceback (most recent call last)
in
----> 1 model_id = run_PTI(use_wandb=False, use_multi_id_training=use_multi_id_training)
~/ali_repo/PTI/scripts/run_pti.py in run_PTI(run_name, use_wandb, use_multi_id_training)
41 coach = MultiIDCoach(dataloader, use_wandb)
42 else:
---> 43 coach = SingleIDCoach(dataloader, use_wandb)
44
45 coach.train()
~/ali_repo/PTI/training/coaches/single_id_coach.py in __init__(self, data_loader, use_wandb)
10
11 def __init__(self, data_loader, use_wandb):
---> 12 super().__init__(data_loader, use_wandb)
13
14 def train(self):
~/ali_repo/PTI/training/coaches/base_coach.py in __init__(self, data_loader, use_wandb)
37 self.lpips_loss = LPIPS(net=hyperparameters.lpips_type).to(global_config.device).eval()
38
---> 39 self.restart_training()
40
41 # Initialize checkpoint dir
~/ali_repo/PTI/training/coaches/base_coach.py in restart_training(self)
46
47 # Initialize networks
---> 48 self.G = load_old_G()
49 toogle_grad(self.G, True)
50
~/ali_repo/PTI/utils/models_utils.py in load_old_G()
21 def load_old_G():
22 with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
---> 23 old_G = pickle.load(f)['G_ema'].to(global_config.device).eval()
24 old_G = old_G.float()
25 return old_G
~/ali_repo/PTI/torch_utils/persistence.py in _reconstruct_persistent_obj(meta)
191
192 assert meta.type == 'class'
--> 193 orig_class = module.__dict__[meta.class_name]
194 decorator_class = persistent_class(orig_class)
195 obj = decorator_class.__new__(decorator_class)
KeyError: 'FullyConnectedLayer'
Hey, so Hyperstyle saves weights when executed, and afterwards I tune the inversion using PTI and then use StyleGAN for output. The main issue I am facing is that StyleGAN loads the saved weights from Hyperstyle, so the editing is done on the Hyperstyle inversion and not the Hyperstyle + PTI tuned inversion. Is there a way to use global directions only and save the weights after the tuning through PTI has been performed?
Pretrained models downloaded with the get_download_model_command() function (align.dat, ffhq.pkl, etc.) can silently fail, instead downloading a file that contains a rate-limit error message (see below) and causing opaque errors later in execution. From googling, this rate limit seems to be related to the number of downloads from the host.
Some potential resolutions:
I've managed to get through the PTI step at least, and the inversion quality is quite incredible. 🤯
(Contents of the silently downloaded file: a Google Drive "Quota exceeded" HTML page reading "Sorry, you can't view or download this file at this time. Too many users have viewed or downloaded this file recently. Please try accessing the file again later. If the file you are trying to access is particularly large or is shared with many people, it may take up to 24 hours to be able to view or download the file.")
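Not part of the repo, but a small sketch for detecting this failure mode early: a quota-limited download saves an HTML page instead of the model, which later surfaces as "UnpicklingError: invalid load key, '<'" or as dlib deserialization errors (file paths below are assumptions):

# Sketch: sanity-check downloaded model files before using them.
def looks_like_html(path):
    with open(path, 'rb') as f:
        head = f.read(64).lstrip().lower()
    return head.startswith(b'<!doctype html') or head.startswith(b'<html')

for name in ['pretrained_models/ffhq.pkl', 'pretrained_models/align.dat']:
    if looks_like_html(name):
        print(f'{name} looks like an HTML error page; re-download it later or fetch it manually.')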
Hi there 👋
Thanks a lot for the project. I am trying to use run_pti
with an image, but I get this error:
(stylegan3) ➜ PTI git:(main) ✗ python scripts/run_pti.py
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
/home/zuppif/miniconda3/envs/stylegan3/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
/home/zuppif/miniconda3/envs/stylegan3/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
Loading model from: /home/zuppif/miniconda3/envs/stylegan3/lib/python3.9/site-packages/lpips/weights/v0.1/alex.pth
0%| | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/home/zuppif/Documents/DragGAN/PTI/scripts/run_pti.py", line 51, in <module>
run_PTI(run_name='', use_wandb=False, use_multi_id_training=False)
File "/home/zuppif/Documents/DragGAN/PTI/scripts/run_pti.py", line 45, in run_PTI
coach.train()
File "/home/zuppif/Documents/DragGAN/PTI/training/coaches/single_id_coach.py", line 39, in train
w_pivot = self.calc_inversions(image, image_name)
File "/home/zuppif/Documents/DragGAN/PTI/training/coaches/base_coach.py", line 93, in calc_inversions
w = w_projector.project(self.G, id_image, device=torch.device(global_config.device), w_avg_samples=600,
File "/home/zuppif/Documents/DragGAN/PTI/training/projectors/w_projector.py", line 41, in project
assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
AssertionError
I first tried resizing the image to 1024x512,
then to 1024x1024,
but the error persists.
Thanks a lot
Fra
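Not a confirmed diagnosis, but the failing assertion checks that the target is a [G.img_channels, G.img_resolution, G.img_resolution] tensor, so the input must be a square RGB image at exactly the generator's resolution (which may be 512 rather than 1024 for some models). A sketch of preparing such a tensor (the function name is made up for illustration):

from PIL import Image
import numpy as np
import torch

def prepare_target(path, resolution):
    # Square RGB image at the generator's resolution, shaped [3, H, W], range [0, 255].
    img = Image.open(path).convert('RGB').resize((resolution, resolution), Image.LANCZOS)
    arr = np.array(img, dtype=np.uint8)
    return torch.from_numpy(arr).permute(2, 0, 1)

# target = prepare_target('aligned_face.png', G.img_resolution)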
Great work!
I notice that there is a "first_inv_type" option in hyperparameters.py. You must have tried using e4e in the first stage. Can you tell us why you chose to use the original StyleGAN2 projector instead of e4e? Thanks!
The latest version (0.1.4) seems to break things; the lpips_layers argument is no longer accepted:
Traceback (most recent call last):
File "scripts/notebook.py", line 63, in
model_id = run_PTI(use_wandb=False, use_multi_id_training=use_multi_id_training)
File "/home/jp/Documents/gitWorkspace/PTI/scripts/run_pti.py", line 40, in run_PTI
coach = SingleIDCoach(dataloader, use_wandb)
File "/home/jp/Documents/gitWorkspace/PTI/training/coaches/single_id_coach.py", line 13, in init
super().init(data_loader, use_wandb)
File "/home/jp/Documents/gitWorkspace/PTI/training/coaches/base_coach.py", line 36, in init
self.lpips_loss = LPIPS(net=hyperparameters.lpips_type, lpips_layers=hyperparameters.pt_lpips_layers).to(global_config.device).eval()
TypeError: __init__() got an unexpected keyword argument 'lpips_layers'
I'd submit a PR - but my branch has drifted considerably.
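Not an official fix, but assuming the lpips_layers keyword came from a patched or forked lpips package, one workaround is to construct the loss with only the arguments the released lpips package accepts, matching the corresponding line in upstream PTI (a sketch of the line in base_coach.py):

# training/coaches/base_coach.py (sketch): use only the stock lpips API.
from lpips import LPIPS

self.lpips_loss = LPIPS(net=hyperparameters.lpips_type).to(global_config.device).eval()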
I'm trying to perform inversion on whole-body images as opposed to faces. Looking at the inference notebook you shared, I'm guessing the preprocessing function, which receives input from align_faces, will need a new function called align_body (for example) to provide input for a body image.
Or would the best solution be to skip the preprocessing step altogether?
When trying to run the model on a full image, it returns only the face part. How can I run the model on full images?
So I get the npz file (thanks for your help on the other ticket #24), and I see the new generator saved:
model_MWVZTEZFDDJB_1.pt
I did some inspecting and see the new generator
https://gist.github.com/johndpope/c5b77f8cc7d7d008be7f15079a9378bf
I want to spit out an updated ffhq pkl file in the correct shape and format so I can run the new generator in different use cases with other repos.
with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
    old_G = pickle.load(f)['G_ema'].cuda()  # this grabs the pickle for the ffhq file
with open(f'{paths_config.checkpoints_dir}/model_{model_id}_{image_name}.pt', 'rb') as f_new:
    new_G = torch.load(f_new).cuda()  # and this is grabbing the updated model_MWVZTEZFDDJB_1.pt
UPDATE 1 - Thus far I have this hack which saves out a pkl.
UPDATE 2 -
I actually load the new file into stylegan2-ada-pytorch and run approach.py in conjunction with projected_w.npz,
but it works badly - I wonder if it's because this pickle would need a new discriminator too?
UPDATE 3 -
I think I know how to solve it - I need to load the final pt which is spat out and do the hot-wiring - should be fine.
def export_updated_pickle(new_G, model_id):
    print("Exporting large updated pickle based off new generator and ffhq.pkl")
    with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
        d = pickle.load(f)
        old_G = d['G_ema'].cuda()  # tensor
        old_D = d['D'].eval().requires_grad_(False).cpu()

    tmp = {}
    tmp['G_ema'] = old_G.eval().requires_grad_(False).cpu()  # copy.deepcopy(new_G).eval().requires_grad_(False).cpu()
    tmp['G'] = new_G.eval().requires_grad_(False).cpu()      # copy.deepcopy(new_G).eval().requires_grad_(False).cpu()
    tmp['D'] = old_D
    tmp['training_set_kwargs'] = None
    tmp['augment_pipe'] = None

    with open(f'{paths_config.checkpoints_dir}/model_{model_id}.pkl', 'wb') as f:
        pickle.dump(tmp, f)
....
At the bottom of the notebook:
print(f'Displaying PTI inversion')
plot_image_from_w(w_pivot, new_G)
np.savez(f'projected_w.npz', w=w_pivot.cpu().detach().numpy())
export_updated_pickle(new_G,model_id)
https://drive.google.com/drive/folders/1l6Xvs6EPVyyw0sFowIpN1pd1lJbm56hD?usp=sharing
I get a new pkl / npz file.
I cherry-pick this file into the original stylegan2-ada-pytorch repo
https://github.com/l4rz/stylegan2-clip-approach
I rename the pkl file to ffhq-pti.pkl
and I run:
(torch) ➜ stylegan2-ada-pytorch git:(main) ✗ python approach.py --network ffhq-pti.pkl --w projected_w.npz --outdir ffhq-pti --num-steps 100 --text 'squint'
Hi, I got another error:
ModuleNotFoundError: No module named 'six.moves.collections_abc'
hi,
Are fine-tuned weights provided for testing?
I was thinking ffmpeg would be a good fit for handling this as an image sequence.
The code has support for multiple images - this would be interesting to explore.
Hi, I got an error, and after searching Google it may be a torch version problem. Could you provide your torch version? Mine is below:
torch==1.3.0
torchsummary==1.5.1
torchvision==0.4.1
RuntimeError: ./pretrained_models/e4e_ffhq_encode.pt is a zip archive (did you mean to use torch.jit.load()?)
I got all the models. Environment is fine.
Code runs. I checked the config… it seems OK. But I don't get anywhere using the two scripts:
scripts/run_pti.py
and the evaluation Python script.
I don't see any output.
I checked the output directories, but no files are spat out. I have an aligned folder with a sequence of face images…
Hi, thanks for your great work.
I notice that you said the editing directions you uploaded are trained on the pretrained StyleGAN. If I want more editing directions, what should I do?
Thank you.
Hello,
I used StyleGAN2 to find some directions, but the results are not great.
Do you use StyleGAN2-ADA to find directions through InterFaceGAN?
Thank you!
Amazing work.
I would like to know if it is possible to generate random faces with a seed, like NVIDIA's StyleGAN.
I tried to do this, but the generated images are full of artifacts:
new_G.synthesis(torch.from_numpy(np.random.rand(1,18,512)).float().to("cuda"),noise_mode='const')
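Not the author's answer, but a sketch of what normally works with stylegan2-ada-pytorch generators: sample z from a seeded Gaussian and pass it through the mapping network, rather than feeding uniform noise straight into synthesis:

import numpy as np
import torch

seed = 42
z = torch.from_numpy(np.random.RandomState(seed).randn(1, new_G.z_dim)).float().to("cuda")
w = new_G.mapping(z, None)                    # [1, num_ws, 512]; label is None for an unconditional model
img = new_G.synthesis(w, noise_mode='const')  # image tensor in [-1, 1]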
Hi guys, I want to make this testable online, but it's too slow because generating the embedding takes about 35 minutes.
RuntimeError Traceback (most recent call last)
in
----> 1 predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)
RuntimeError: Error deserializing object of type int64
while deserializing a floating point number.
while deserializing a dlib::matrix
while deserializing object of type std::vector
while deserializing object of type std::vector
while deserializing object of type std::vector
Dear Daniel,
Thank you very much for the great work.
I am trying to apply PTI to my own trained model, which has 256 x 256 resolution.
Could you give me any tips on which parts of the code I should modify?
Many thanks,
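Not a verified checklist, but based on the constants visible elsewhere on this page, the obvious places that assume FFHQ at 1024x1024 would need updating; a sketch (paths are placeholders):

# configs/paths_config.py (sketch): point at the 256x256 checkpoint instead of the FFHQ one
stylegan2_ada_ffhq = './pretrained_models/my_model_256.pkl'

# utils/align_data.py (sketch): match the alignment size to the generator resolution
IMAGE_SIZE = 256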
I have two images aligned in /content/PTI/image_processed.
I want to train a model on multiple images using use_multi_id_training = True,
but it gives this error:
100%|██████████| 2/2 [00:01<00:00, 1.39it/s]
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%
233M/233M [00:04<00:00, 50.4MB/s]
Loading model from: /usr/local/lib/python3.7/dist-packages/lpips/weights/v0.1/alex.pth
Downloading https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt ... done
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py:1051: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)
return forward_call(*input, **kwargs)
100%|██████████| 450/450 [01:21<00:00, 5.52it/s]
100%|██████████| 450/450 [01:21<00:00, 5.52it/s]
0%| | 0/350 [00:00<?, ?it/s]
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-33-b69631c99070> in <module>()
34 ## In order to run PTI and use StyleGAN2-ada, the cwd should the parent of 'torch_utils' and 'dnnlib'
35 os.chdir('/content/PTI')
---> 36 model_id = run_PTI(use_wandb=False, use_multi_id_training=True)
1 frames
/content/PTI/training/coaches/multi_id_coach.py in train(self)
58
59 self.optimizer.zero_grad()
---> 60 loss.backward()
61 self.optimizer.step()
62
AttributeError: 'tuple' object has no attribute 'backward'
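Not an official patch; assuming calc_loss returns a (loss, l2, lpips) tuple as in the single-image coach, a sketch of unpacking it in multi_id_coach.py before backpropagating (argument names follow the single-image coach and may differ in other versions):

# training/coaches/multi_id_coach.py (sketch): calc_loss returns a tuple,
# so take the combined loss term before calling backward().
loss, l2_loss_val, loss_lpips = self.calc_loss(generated_images, real_images,
                                               image_name, self.G,
                                               use_ball_holder, w_pivot)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()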
Hi,
I'm running the colab as is and I always get this error:
RuntimeError Traceback (most recent call last)
<ipython-input-18-4cd7c7dc57be> in <module>()
----> 1 pre_process_images(f'/content/PTI/{image_dir_name}_original')
/content/PTI/utils/align_data.py in pre_process_images(raw_images_path)
12
13 IMAGE_SIZE = 1024
---> 14 predictor = dlib.shape_predictor(paths_config.dlib)
15 os.chdir(raw_images_path)
16 images_names = glob.glob(f'*')
RuntimeError: Error deserializing object of type int
I did not make any changes, perhaps there is something extra I need to configure?
Just letting you know.
Thanks
Hi, Thanks for your great work!
Here I get one question, when you fine-tune the parameters of the pre-trained StyleGAN model and use InterFaceGAN methods to edit the image, do we need to find a new semantic direction based on the current parameter model? Intuitively, changing the model parameters will change the semantic direction found by InterFaceGAN.
In general, after you fine-tune the pre-trained StyleGAN model, do you retrain the semantic direction found by InterFaceGAN?
Thanks
Thanks for your contribution. What is the largest number of images that has been tried in a single training run? Ideally, each picture corresponds to one pivotally fine-tuned StyleGAN. I have 20,000+ images with different IDs, and the image quality varies. What happens if I train them all together to get one StyleGAN?
Thanks for your great work. I find that in multi_id_coach.py you optimize one image per step; why not optimize several images in a batch per step? Another question: if I have several images of one person in different poses, how do I get the best inversion result - one checkpoint per image, or one checkpoint for all images? Thank you in advance.
PTI/training/projectors/w_projector.py
Line 47 in fb1c485
The code in the Colab only allows using a mapper.
What I want is to write a text prompt to modify the image, as in StyleCLIP.
Thanks for sharing your great work! When I run the code in Colab, some errors come up in the part that downloads pre-trained models. It shows "NameError: name 'downloader' is not defined". Can you give me any suggestions?
Hi guys,
I ran align_data.py and I can't find the file below:
dlib = './pretrained_models/align.dat'
In your code ./configs/hyperparameters.py, use_locality_regularization is False by default? So the space_regulizer_loss is not calculated?
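For reference, a sketch of turning it on in configs/hyperparameters.py, assuming the flag gates the space_regulizer_loss term:

# configs/hyperparameters.py (sketch)
use_locality_regularization = True   # enables the space_regulizer_loss term during tuning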
Hi there!
I've been trying to invert some pictures using pre-trained models that don't generate faces. For obvious reasons, I've been skipping most of the pre-processing, such as dlib face alignment; resizing is the only part that I kept.
However, both the final embedding and the fine-tuned model are of poor quality, being either distorted or blurred. It seems the repository is specifically designed for faces, so I was wondering if you could share any best practices or advice about pre-processing pictures that aren't necessarily faces.
Thanks for the good work!
Regards
Thanks for your great work! Could you please share the original out-of-domain images you collected?
Hi, thanks for your great work!
I am curious about the input parameters such as 'noise_mode' and 'force_fp32' of self.G.synthesis(w, noise_mode='const', force_fp32=True).
I also want to know how to return the features of each layer in self.G.
Thanks
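Not an official API, but one common way to get per-layer features from a stylegan2-ada-pytorch generator is to register forward hooks on the synthesis blocks; a sketch, assuming the usual layout where G.synthesis exposes block_resolutions and blocks named b4, b8, ...:

import torch

features = {}

def make_hook(name):
    def hook(module, inputs, output):
        # SynthesisBlock.forward returns (x, img); keep the feature map x.
        x = output[0] if isinstance(output, tuple) else output
        features[name] = x.detach()
    return hook

handles = []
for res in G.synthesis.block_resolutions:              # e.g. [4, 8, 16, ..., 1024]
    block = getattr(G.synthesis, f'b{res}')
    handles.append(block.register_forward_hook(make_hook(f'b{res}')))

with torch.no_grad():
    img = G.synthesis(w, noise_mode='const', force_fp32=True)

for h in handles:
    h.remove()
# features['b4'], features['b8'], ... now hold each block's output feature map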
Hello, I was trying to implement PTI into eg3d/loss.py at main · NVlabs/eg3d, but I ran into problems when calling the projectors from PTI/training/projectors at main · danielroich/PTI.
Here is how I call the w / w_plus projector (see the function pti_projector):
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""Loss functions."""
import numpy as np
import torch
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import upfirdn2d
from training.dual_discriminator import filtered_resizing
#----------------------------------------------------------------------------
class Loss:
def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): # to be overridden by subclass
raise NotImplementedError()
#----------------------------------------------------------------------------
# ---------------------- project image into latent space --------------------- #
# modified code from https://github.com/oneThousand1000/EG3D-projector/tree/master/eg3d/projector
from training.projector import w_plus_projector, w_projector
from torchvision import transforms
import copy
def pti_projector(cur_G, cur_c, cur_image, device, latent_type='w_plus'):
# # put image back to cpu for transforms
# image = cur_image.cpu()
# # normalize image
# normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
# std=[0.5, 0.5, 0.5])
# id_image = normalize(image)
# id_image = torch.squeeze((id_image + 1) / 2, 0)
id_image = cur_image.to(device)
# c = c.to(device)
c = torch.reshape(cur_c, (1, 25)).to(device) # 25 is the camera pose dimension 16 + 9
G = cur_G
if latent_type == 'w_plus':
w = w_plus_projector.project(G, c, id_image, device=device, w_avg_samples=600)
else:
w = w_projector.project(G, c, id_image, device=device, w_avg_samples=600)
print('w shape: ', w.shape)
return w
class StyleGAN2Loss(Loss):
def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0, r1_gamma_init=0, r1_gamma_fade_kimg=0, neural_rendering_resolution_initial=64, neural_rendering_resolution_final=None, neural_rendering_resolution_fade_kimg=0, gpc_reg_fade_kimg=1000, gpc_reg_prob=None, dual_discrimination=False, filter_mode='antialiased'):
super().__init__()
self.device = device
self.G = G
self.D = D
self.augment_pipe = augment_pipe
self.r1_gamma = r1_gamma
self.style_mixing_prob = style_mixing_prob
self.pl_weight = pl_weight
self.pl_batch_shrink = pl_batch_shrink
self.pl_decay = pl_decay
self.pl_no_weight_grad = pl_no_weight_grad
self.pl_mean = torch.zeros([], device=device)
self.blur_init_sigma = blur_init_sigma
self.blur_fade_kimg = blur_fade_kimg
self.r1_gamma_init = r1_gamma_init
self.r1_gamma_fade_kimg = r1_gamma_fade_kimg
self.neural_rendering_resolution_initial = neural_rendering_resolution_initial
self.neural_rendering_resolution_final = neural_rendering_resolution_final
self.neural_rendering_resolution_fade_kimg = neural_rendering_resolution_fade_kimg
self.gpc_reg_fade_kimg = gpc_reg_fade_kimg
self.gpc_reg_prob = gpc_reg_prob
self.dual_discrimination = dual_discrimination
self.filter_mode = filter_mode
self.resample_filter = upfirdn2d.setup_filter([1,3,3,1], device=device)
self.blur_raw_target = True
assert self.gpc_reg_prob is None or (0 <= self.gpc_reg_prob <= 1)
def run_G(self, z, c, swapping_prob, neural_rendering_resolution, update_emas=False):
if swapping_prob is not None:
c_swapped = torch.roll(c.clone(), 1, 0)
c_gen_conditioning = torch.where(torch.rand((c.shape[0], 1), device=c.device) < swapping_prob, c_swapped, c)
else:
c_gen_conditioning = torch.zeros_like(c)
ws = self.G.mapping(z, c_gen_conditioning, update_emas=update_emas)
if self.style_mixing_prob > 0:
with torch.autograd.profiler.record_function('style_mixing'):
cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
gen_output = self.G.synthesis(ws, c, neural_rendering_resolution=neural_rendering_resolution, update_emas=update_emas)
return gen_output, ws
def run_D(self, img, c, blur_sigma=0, blur_sigma_raw=0, update_emas=False):
blur_size = np.floor(blur_sigma * 3)
if blur_size > 0:
with torch.autograd.profiler.record_function('blur'):
f = torch.arange(-blur_size, blur_size + 1, device=img['image'].device).div(blur_sigma).square().neg().exp2()
img['image'] = upfirdn2d.filter2d(img['image'], f / f.sum())
if self.augment_pipe is not None:
augmented_pair = self.augment_pipe(torch.cat([img['image'],
torch.nn.functional.interpolate(img['image_raw'], size=img['image'].shape[2:], mode='bilinear', antialias=True)],
dim=1))
img['image'] = augmented_pair[:, :img['image'].shape[1]]
img['image_raw'] = torch.nn.functional.interpolate(augmented_pair[:, img['image'].shape[1]:], size=img['image_raw'].shape[2:], mode='bilinear', antialias=True)
logits = self.D(img, c, update_emas=update_emas)
return logits
def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
if self.G.rendering_kwargs.get('density_reg', 0) == 0:
phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase)
if self.r1_gamma == 0:
phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase)
blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0
r1_gamma = self.r1_gamma
alpha = min(cur_nimg / (self.gpc_reg_fade_kimg * 1e3), 1) if self.gpc_reg_fade_kimg > 0 else 1
swapping_prob = (1 - alpha) * 1 + alpha * self.gpc_reg_prob if self.gpc_reg_prob is not None else None
if self.neural_rendering_resolution_final is not None:
alpha = min(cur_nimg / (self.neural_rendering_resolution_fade_kimg * 1e3), 1)
neural_rendering_resolution = int(np.rint(self.neural_rendering_resolution_initial * (1 - alpha) + self.neural_rendering_resolution_final * alpha))
else:
neural_rendering_resolution = self.neural_rendering_resolution_initial
real_img_raw = filtered_resizing(real_img, size=neural_rendering_resolution, f=self.resample_filter, filter_mode=self.filter_mode)
if self.blur_raw_target:
blur_size = np.floor(blur_sigma * 3)
if blur_size > 0:
f = torch.arange(-blur_size, blur_size + 1, device=real_img_raw.device).div(blur_sigma).square().neg().exp2()
real_img_raw = upfirdn2d.filter2d(real_img_raw, f / f.sum())
real_img = {'image': real_img, 'image_raw': real_img_raw}
# run PTI to get w/w_plus latent codes for real images
# print(real_img.shape, real_c.shape, gen_z.shape, gen_c.shape)
# torch.Size([8, 3, 512, 512]) torch.Size([8, 25]) torch.Size([8, 512]) torch.Size([8, 25])
# convert gen_z to real_z
batch_size = real_img['image'].shape[0]
real_z = []
for i in range(batch_size):
cur_img = real_img['image'][i]
cur_c = real_c[i]
cur_z = pti_projector(self.G, cur_c, cur_img, device=self.device)
real_z.append(cur_z)
real_z = torch.stack(real_z)
print('real_z', real_z.shape)
# Gmain: Maximize logits for generated images.
if phase in ['Gmain', 'Gboth']:
with torch.autograd.profiler.record_function('Gmain_forward'):
gen_img, _gen_ws = self.run_G(gen_z, gen_c, swapping_prob=swapping_prob, neural_rendering_resolution=neural_rendering_resolution)
gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
training_stats.report('Loss/scores/fake', gen_logits)
training_stats.report('Loss/signs/fake', gen_logits.sign())
loss_Gmain = torch.nn.functional.softplus(-gen_logits)
training_stats.report('Loss/G/loss', loss_Gmain)
with torch.autograd.profiler.record_function('Gmain_backward'):
loss_Gmain.mean().mul(gain).backward()
# Density Regularization
if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'l1':
if swapping_prob is not None:
c_swapped = torch.roll(gen_c.clone(), 1, 0)
c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
else:
c_gen_conditioning = torch.zeros_like(gen_c)
ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
if self.style_mixing_prob > 0:
with torch.autograd.profiler.record_function('style_mixing'):
cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * self.G.rendering_kwargs['density_reg_p_dist']
all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
sigma_initial = sigma[:, :sigma.shape[1]//2]
sigma_perturbed = sigma[:, sigma.shape[1]//2:]
TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
TVloss.mul(gain).backward()
# Alternative density regularization
if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'monotonic-detach':
if swapping_prob is not None:
c_swapped = torch.roll(gen_c.clone(), 1, 0)
c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
else:
c_gen_conditioning = torch.zeros_like(gen_c)
ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1 # Front
perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1/256) * self.G.rendering_kwargs['box_warp'] # Behind
all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
sigma_initial = sigma[:, :sigma.shape[1]//2]
sigma_perturbed = sigma[:, sigma.shape[1]//2:]
monotonic_loss = torch.relu(sigma_initial.detach() - sigma_perturbed).mean() * 10
monotonic_loss.mul(gain).backward()
if swapping_prob is not None:
c_swapped = torch.roll(gen_c.clone(), 1, 0)
c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
else:
c_gen_conditioning = torch.zeros_like(gen_c)
ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
if self.style_mixing_prob > 0:
with torch.autograd.profiler.record_function('style_mixing'):
cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1/256) * self.G.rendering_kwargs['box_warp']
all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
sigma_initial = sigma[:, :sigma.shape[1]//2]
sigma_perturbed = sigma[:, sigma.shape[1]//2:]
TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
TVloss.mul(gain).backward()
# Alternative density regularization
if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'monotonic-fixed':
if swapping_prob is not None:
c_swapped = torch.roll(gen_c.clone(), 1, 0)
c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
else:
c_gen_conditioning = torch.zeros_like(gen_c)
ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1 # Front
perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1/256) * self.G.rendering_kwargs['box_warp'] # Behind
all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
sigma_initial = sigma[:, :sigma.shape[1]//2]
sigma_perturbed = sigma[:, sigma.shape[1]//2:]
monotonic_loss = torch.relu(sigma_initial - sigma_perturbed).mean() * 10
monotonic_loss.mul(gain).backward()
if swapping_prob is not None:
c_swapped = torch.roll(gen_c.clone(), 1, 0)
c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
else:
c_gen_conditioning = torch.zeros_like(gen_c)
ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
if self.style_mixing_prob > 0:
with torch.autograd.profiler.record_function('style_mixing'):
cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1/256) * self.G.rendering_kwargs['box_warp']
all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
sigma_initial = sigma[:, :sigma.shape[1]//2]
sigma_perturbed = sigma[:, sigma.shape[1]//2:]
TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
TVloss.mul(gain).backward()
# Dmain: Minimize logits for generated images.
loss_Dgen = 0
if phase in ['Dmain', 'Dboth']:
with torch.autograd.profiler.record_function('Dgen_forward'):
gen_img, _gen_ws = self.run_G(gen_z, gen_c, swapping_prob=swapping_prob, neural_rendering_resolution=neural_rendering_resolution, update_emas=True)
gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True)
training_stats.report('Loss/scores/fake', gen_logits)
training_stats.report('Loss/signs/fake', gen_logits.sign())
loss_Dgen = torch.nn.functional.softplus(gen_logits)
with torch.autograd.profiler.record_function('Dgen_backward'):
loss_Dgen.mean().mul(gain).backward()
# Dmain: Maximize logits for real images.
# Dr1: Apply R1 regularization.
if phase in ['Dmain', 'Dreg', 'Dboth']:
name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1'
with torch.autograd.profiler.record_function(name + '_forward'):
real_img_tmp_image = real_img['image'].detach().requires_grad_(phase in ['Dreg', 'Dboth'])
real_img_tmp_image_raw = real_img['image_raw'].detach().requires_grad_(phase in ['Dreg', 'Dboth'])
real_img_tmp = {'image': real_img_tmp_image, 'image_raw': real_img_tmp_image_raw}
real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
training_stats.report('Loss/scores/real', real_logits)
training_stats.report('Loss/signs/real', real_logits.sign())
loss_Dreal = 0
if phase in ['Dmain', 'Dboth']:
loss_Dreal = torch.nn.functional.softplus(-real_logits)
training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)
loss_Dr1 = 0
if phase in ['Dreg', 'Dboth']:
if self.dual_discrimination:
with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp['image'], real_img_tmp['image_raw']], create_graph=True, only_inputs=True)
r1_grads_image = r1_grads[0]
r1_grads_image_raw = r1_grads[1]
r1_penalty = r1_grads_image.square().sum([1,2,3]) + r1_grads_image_raw.square().sum([1,2,3])
else: # single discrimination
with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp['image']], create_graph=True, only_inputs=True)
r1_grads_image = r1_grads[0]
r1_penalty = r1_grads_image.square().sum([1,2,3])
loss_Dr1 = r1_penalty * (r1_gamma / 2)
training_stats.report('Loss/r1_penalty', r1_penalty)
training_stats.report('Loss/D/reg', loss_Dr1)
with torch.autograd.profiler.record_function(name + '_backward'):
(loss_Dreal + loss_Dr1).mean().mul(gain).backward()
#----------------------------------------------------------------------------
And here is the modified projector script:
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Project given image to the latent space of pretrained network pickle."""
import copy
import os
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import dnnlib
import PIL
from camera_utils import LookAtPoseSampler
def project(
G,
c,
# outdir,
target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
*,
num_steps=1000,
w_avg_samples=10000,
initial_learning_rate=0.01,
initial_noise_factor=0.05,
lr_rampdown_length=0.25,
lr_rampup_length=0.05,
noise_ramp_length=0.75,
regularize_noise_weight=1e5,
verbose=False,
device: torch.device,
initial_w=None,
image_log_step=100,
# w_name: str
):
# os.makedirs(f'{outdir}/{w_name}_w_plus', exist_ok=True)
# outdir = f'{outdir}/{w_name}_w_plus'
assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
def logprint(*args):
if verbose:
print(*args)
G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float() # type: ignore
# Compute w stats.
w_avg_path = './w_avg.npy'
w_std_path = './w_std.npy'
if (not os.path.exists(w_avg_path)) or (not os.path.exists(w_std_path)):
print(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
# c_samples = c.repeat(w_avg_samples, 1)
# use avg look at point
camera_lookat_point = torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device)
cam2world_pose = LookAtPoseSampler.sample(3.14 / 2, 3.14 / 2, camera_lookat_point,
radius=G.rendering_kwargs['avg_camera_radius'], device=device)
focal_length = 4.2647 # FFHQ's FOV
intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
c_samples = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
c_samples = c_samples.repeat(w_avg_samples, 1)
w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples) # [N, L, C]
w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
# print('save w_avg to ./w_avg.npy')
# np.save('./w_avg.npy',w_avg)
w_avg_tensor = torch.from_numpy(w_avg).cuda()
w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
# np.save(w_avg_path, w_avg)
# np.save(w_std_path, w_std)
else:
# w_avg = np.load(w_avg_path)
# w_std = np.load(w_std_path)
raise Exception(' ')
# z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
# c_samples = c.repeat(w_avg_samples, 1)
# w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples) # [N, L, C]
# w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
# w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
# w_avg_tensor = torch.from_numpy(w_avg).cuda()
# w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
start_w = initial_w if initial_w is not None else w_avg
# Setup noise inputs.
noise_bufs = {name: buf for (name, buf) in G.backbone.synthesis.named_buffers() if 'noise_const' in name}
# Load VGG16 feature detector.
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
# url = './networks/vgg16.pt'
with dnnlib.util.open_url(url) as f:
vgg16 = torch.jit.load(f, map_location=device).eval().to(device)
# Features for target image.
target_images = target.unsqueeze(0).to(device).to(torch.float32)
if target_images.shape[2] > 256:
target_images = F.interpolate(target_images, size=(256, 256), mode='area')
target_features = vgg16(target_images, resize_images=False, return_lpips=True)
start_w = np.repeat(start_w, G.backbone.mapping.num_ws, axis=1)
w_opt = torch.tensor(start_w, dtype=torch.float32, device=device,
requires_grad=True) # pylint: disable=not-callable
optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999),
lr=0.1)
# Init noise.
for buf in noise_bufs.values():
buf[:] = torch.randn_like(buf)
buf.requires_grad = True
for step in tqdm(range(num_steps)):
# Learning rate schedule.
t = step / num_steps
w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
lr = initial_learning_rate * lr_ramp
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# Synth images from opt_w.
w_noise = torch.randn_like(w_opt) * w_noise_scale
ws = (w_opt + w_noise)
synth_images = G.synthesis(ws,c, noise_mode='const')['image']
# if step % image_log_step == 0:
# with torch.no_grad():
# vis_img = (synth_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
# PIL.Image.fromarray(vis_img[0].cpu().numpy(), 'RGB').save(f'{outdir}/{step}.png')
# Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
synth_images = (synth_images + 1) * (255 / 2)
if synth_images.shape[2] > 256:
synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
# Features for synth images.
synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
dist = (target_features - synth_features).square().sum()
# Noise regularization.
reg_loss = 0.0
for v in noise_bufs.values():
noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d()
while True:
reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
if noise.shape[2] <= 8:
break
noise = F.avg_pool2d(noise, kernel_size=2)
loss = dist + reg_loss * regularize_noise_weight
# if step % 10 == 0:
# with torch.no_grad():
# print({f'step {step}, first projection _{w_name}': loss.detach().cpu()})
# Step
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
# Normalize noise.
with torch.no_grad():
for buf in noise_bufs.values():
buf -= buf.mean()
buf *= buf.square().mean().rsqrt()
del G
return w_opt
I got errors as shown below:
Computing W midpoint and stddev using 600 samples...
0%| | 0/1000 [00:00<?, ?it/s]/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/nn/modules/module.py:1488: UserWarning: operator() profile_node %106 : int = prim::profile_ivalue(%104)
does not have profile information (Triggered internally at /opt/conda/conda-bld/pytorch_1674202356920/work/torch/csrc/jit/codegen/cuda/graph_fuser.cpp:105.)
return forward_call(*args, **kwargs)
0%| | 0/1000 [00:07<?, ?it/s]
Traceback (most recent call last):
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 396, in <module>
main() # pylint: disable=no-value-for-parameter
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 1128, in __call__
return self.main(*args, **kwargs)
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 1053, in main
rv = self.invoke(ctx)
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 1395, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 754, in invoke
return __callback(*args, **kwargs)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 391, in main
launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 101, in launch_training
subprocess_fn(rank=0, c=c, temp_dir=temp_dir)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 52, in subprocess_fn
training_loop.training_loop(rank=rank, **c)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/training_loop.py", line 286, in training_loop
loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, gain=phase.interval, cur_nimg=cur_nimg)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/loss.py", line 156, in accumulate_gradients
cur_z = pti_projector(self.G, cur_c, cur_img, device=self.device)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/loss.py", line 49, in pti_projector
w = w_plus_projector.project(G, c, id_image, device=device, w_avg_samples=600)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/projector/w_plus_projector.py", line 171, in project
loss.backward()
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/_tensor.py", line 488, in backward
torch.autograd.backward(
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/autograd/function.py", line 275, in apply
return user_fn(self, *args)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/torch_utils/ops/conv2d_gradfix.py", line 146, in backward
grad_weight = Conv2dGradWeight.apply(grad_output, input, weight)
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/autograd/function.py", line 508, in apply
return super().apply(*args, **kwargs)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/torch_utils/ops/conv2d_gradfix.py", line 173, in forward
return torch.ops.aten.convolution_backward(grad_output=grad_output, input=input, weight=weight, bias_sizes=None, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[False, True, False])[1]
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/_ops.py", line 499, in __call__
return self._op(*args, **kwargs or {})
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument weight in method wrapper_CUDA__convolution_backward)
It seems the error happens at loss.backward().
I checked most of the variables/losses/models to make sure they are on cuda:0,
but I still had no luck solving this. Do you know how to make the loss backpropagate properly?
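Not a diagnosis of this specific setup, but a small sketch that often helps localize such device mismatches: list any parameters or buffers still on the CPU before calling project():

import torch

def report_cpu_tensors(module, name='G'):
    # Print parameters and buffers that are not on a CUDA device.
    for pname, p in module.named_parameters():
        if not p.is_cuda:
            print(f'{name}.{pname} is on {p.device}')
    for bname, b in module.named_buffers():
        if not b.is_cuda:
            print(f'{name}.{bname} is on {b.device}')

# report_cpu_tensors(G)               # the generator handed to the projector
# print(id_image.device, c.device)    # the target image and camera tensors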
Well done, your work is amazing. The performance is the best as far as I know. But I found the inference time is slow, and every image requires saving a 127 MB .pt file, which is not practical. Any advice is welcome.
Line 25 in b4964aa
I am using a self-trained model trained with the StyleGAN2-ADA PyTorch repository.
When using use_multi_id_training=True I get a size-mismatch error in the forward call of G.synthesis.
The full trace is shown below:
Traceback (most recent call last):
File "test.py", line 6, in
run_PTI(use_multi_id_training=True)
File "/home/usman/Documents/Work/PTI/scripts/run_pti.py", line 44, in run_PTI
coach.train()
File "/home/usman/Documents/Work/PTI/training/coaches/multi_id_coach.py", line 58, in train
generated_images = self.forward(w_pivot)
File "/home/usman/Documents/Work/PTI/training/coaches/base_coach.py", line 130, in forward
generated_images = self.G.synthesis(w, noise_mode='const', force_fp32=True)
File "/home/usman/anaconda3/envs/eg3d/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "", line 460, in forward
File "/home/usman/Documents/Work/PTI/torch_utils/misc.py", line 93, in assert_shape
raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
AssertionError: Wrong size for dimension 1: got 18, expected 16
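Not a verified fix; the mismatch suggests the pivot latent carries 18 style vectors (the FFHQ-1024 layout, e.g. from an FFHQ encoder or a saved FFHQ embedding) while a 512-resolution generator expects 16. A sketch of aligning the pivot with the generator before synthesis (attribute names assume stylegan2-ada-pytorch):

# Sketch (not a verified fix): make the pivot's number of style vectors match
# what the generator expects before calling G.synthesis.
num_ws = self.G.num_ws                                    # 16 for a 512x512 StyleGAN2-ADA model
if w_pivot.shape[1] != num_ws:
    if w_pivot.shape[1] > num_ws:
        w_pivot = w_pivot[:, :num_ws, :]                  # trim extra layers (e.g. 18 -> 16)
    else:
        w_pivot = w_pivot[:, :1, :].repeat(1, num_ws, 1)  # broadcast a single W to W+
generated_images = self.G.synthesis(w_pivot, noise_mode='const', force_fp32=True)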