pti's Issues

interpolate faces

Amazing work! It would be great if you could add a script to interpolate between 2 faces.
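
Not something the repo ships, but a minimal sketch of what such a script could look like, assuming both faces have already been inverted with PTI (pivots w1 and w2 of shape [1, num_ws, 512]) and that a single fine-tuned generator G covers both identities (e.g. from a multi-ID run): linearly interpolate the latents and synthesize each step.

import torch

@torch.no_grad()
def interpolate_faces(G, w1, w2, steps=8, device='cuda'):
    # Linear interpolation between two inverted pivots in W/W+ space.
    frames = []
    for alpha in torch.linspace(0.0, 1.0, steps):
        w = (1 - alpha) * w1 + alpha * w2
        img = G.synthesis(w.to(device), noise_mode='const')      # NCHW in [-1, 1]
        img = (img.clamp(-1, 1) + 1) * 127.5                     # back to [0, 255]
        frames.append(img[0].permute(1, 2, 0).cpu().numpy().astype('uint8'))
    return frames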

SG and SG2 issue

Hi, thanks for the impressive work, which has had a great impact.
In the original paper, I wondered about the difference between SG, SG w+, and the first step of PTI.
Both SG and the first step of PTI optimize in the original W space, while SG w+ uses the W+ space.
SG and SG w+ take more steps to optimize, and all three methods employ noise regularization.
Is there any other difference that I missed?

Colab Unpickling Error

Running the colab notebook throws an unpickling error at "Use PTI with e4e backbone for StyleCLIP"

---------------------------------------------------------------------------
UnpicklingError                           Traceback (most recent call last)
<ipython-input-38-29c3e7342ea3> in <module>()
      1 hyperparameters.first_inv_type = 'w+'
      2 os.chdir('/content/PTI')
----> 3 model_id = run_PTI(use_wandb=False, use_multi_id_training=False)

5 frames
/usr/local/lib/python3.7/dist-packages/torch/serialization.py in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
    775             "functionality.")
    776 
--> 777     magic_number = pickle_module.load(f, **pickle_load_args)
    778     if magic_number != MAGIC_NUMBER:
    779         raise RuntimeError("Invalid magic number; corrupt file?")

UnpicklingError: invalid load key, '<'.
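
The "invalid load key, '<'" almost always means the file being unpickled is not a checkpoint at all but an HTML page starting with '<' (typically Google Drive's quota/rate-limit error saved in place of the model; see the get_download_model_command() issue further down). A quick hedged check, with example paths that should be adjusted to whatever paths_config points at:

def looks_like_html(path):
    # A real pickle starts with a binary magic number; an HTML error page starts with '<'.
    with open(path, 'rb') as f:
        head = f.read(64)
    return head.lstrip().startswith(b'<')

for ckpt in ['/content/PTI/pretrained_models/ffhq.pkl',
             '/content/PTI/pretrained_models/e4e_ffhq_encode.pt']:   # example paths
    if looks_like_html(ckpt):
        print(f'{ckpt} is an HTML page, not a checkpoint -- re-download it')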

one image one checkpoint?

Hello,
I have a question: when we fine-tune the generator, do we need to save a separate set of model parameters for each image?

Pivotal tuning for grey-scale (sketch) images

Hi Daniel, thank you for the great work.
I have trained my own StyleGAN2-ada model on sketch data, and it generates sketches quite well.
After that, to manipulate real images, I tested PTI, but the quality was not good.
When I remove the LPIPS loss (using only the L2 loss) during pivotal tuning, the reconstruction works well.
However, the manipulation still does not work very well.
Could you please provide any tips on this?
Should I train LPIPS on a different dataset, or is there any other loss function you can recommend?

About the initialization of the Inversion

Thanks for your great work. Regarding the inversion step in the paper, I want to know whether there is any initialization for W when searching for the pivot w_p, such as the average W or just a random vector?
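
For reference, the StyleGAN2-ada projector that PTI's first step builds on does not start from a random vector: it initializes the optimization at the mean W obtained by mapping many random z samples (the "Computing W midpoint and stddev" message in the logs). A minimal sketch of that initialization, assuming an unconditional generator so c=None:

import numpy as np
import torch

def compute_start_w(G, w_avg_samples=10000, device='cuda'):
    # Map many random z vectors and average the resulting w codes; the
    # optimization then starts from this w_avg (plus scheduled noise).
    z = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w = G.mapping(torch.from_numpy(z).to(device), None)      # [N, num_ws, 512]
    w = w[:, :1, :].cpu().numpy().astype(np.float32)          # one slot per sample
    return np.mean(w, axis=0, keepdims=True)                   # [1, 1, 512]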

Effect of lpips type

Hello,

I find that the option lpips_type = 'alex' tends to affect my inversion results:
with lpips_type = 'alex', it introduces undesirable checkerboard/noise-like artifacts, which gives me a sense of overfitting (input left, reconstruction right):
(images attached)
with lpips_type = 'vgg', the results are smoother, which gives me a sense of underfitting:
(images attached)
You may need to zoom in a bit to tell the difference.
Any suggestions in this case? Do I need to tune options like LPIPS_value_threshold = 0.06 to find a sweet spot in this trade-off?
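
For anyone experimenting with this trade-off, both knobs live in configs/hyperparameters.py and can be overridden before calling run_PTI, as the Colab does for first_inv_type; the values below are examples, not recommendations:

from configs import hyperparameters

hyperparameters.lpips_type = 'vgg'            # perceptual trunk: 'alex' (default) or 'vgg'
hyperparameters.LPIPS_value_threshold = 0.04  # stopping criterion for pivotal tuning (default 0.06)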

KeyError: 'FullyConnectedLayer'

Hi, I ran the run_PTI script but get the error KeyError: 'FullyConnectedLayer'.
Details:
KeyError Traceback (most recent call last)
in
----> 1 model_id = run_PTI(use_wandb=False, use_multi_id_training=use_multi_id_training)

~/ali_repo/PTI/scripts/run_pti.py in run_PTI(run_name, use_wandb, use_multi_id_training)
41 coach = MultiIDCoach(dataloader, use_wandb)
42 else:
---> 43 coach = SingleIDCoach(dataloader, use_wandb)
44
45 coach.train()

~/ali_repo/PTI/training/coaches/single_id_coach.py in __init__(self, data_loader, use_wandb)
10
11 def __init__(self, data_loader, use_wandb):
---> 12 super().__init__(data_loader, use_wandb)
13
14 def train(self):

~/ali_repo/PTI/training/coaches/base_coach.py in __init__(self, data_loader, use_wandb)
37 self.lpips_loss = LPIPS(net=hyperparameters.lpips_type).to(global_config.device).eval()
38
---> 39 self.restart_training()
40
41 # Initialize checkpoint dir

~/ali_repo/PTI/training/coaches/base_coach.py in restart_training(self)
46
47 # Initialize networks
---> 48 self.G = load_old_G()
49 toogle_grad(self.G, True)
50

~/ali_repo/PTI/utils/models_utils.py in load_old_G()
21 def load_old_G():
22 with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
---> 23 old_G = pickle.load(f)['G_ema'].to(global_config.device).eval()
24 old_G = old_G.float()
25 return old_G

~/ali_repo/PTI/torch_utils/persistence.py in _reconstruct_persistent_obj(meta)
191
192 assert meta.type == 'class'
--> 193 orig_class = module.__dict__[meta.class_name]
194 decorator_class = persistent_class(orig_class)
195 obj = decorator_class.__new__(decorator_class)

KeyError: 'FullyConnectedLayer'
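
A hedged note on two common causes of this KeyError (the Colab itself warns that the working directory must be the parent of 'torch_utils' and 'dnnlib'): the generator pickle is being deserialized from the wrong cwd, or the .pkl was not produced by the official StyleGAN2-ada(-pytorch) code, so the persisted module source does not define FullyConnectedLayer. A quick sanity check; paths and the import layout follow the repo and may need adjusting:

import os
import pickle

os.chdir('/path/to/PTI')                                  # repo root containing torch_utils/ and dnnlib/
assert os.path.isdir('torch_utils') and os.path.isdir('dnnlib')

from configs import paths_config                          # import after chdir so the repo's modules resolve
with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
    G = pickle.load(f)['G_ema']                           # should now reconstruct FullyConnectedLayer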

Style Clip editing on Hyperstyle + PTI output using global directions

Hey. HyperStyle saves weights when executed; afterwards I tune the inversion using PTI and then use StyleGAN for the output. The main issue I am facing is that StyleGAN loads the saved weights from HyperStyle, so the editing is done on the HyperStyle inversion and not on the HyperStyle + PTI tuned inversion. Is there a way to use global directions only and save the weights after the tuning through PTI has been performed?

(Colab) Files downloaded with get_download_model_command() subject to GDrive rate limit

Pretrained models downloaded with the get_download_model_command() function (align.dat, ffhq.pkl, etc.) can silently fail, instead downloading a file that contains a rate-limit error message (see below) and causing opaque errors later in execution. From googling, this rate limit seems to be related to the number of downloads from the host.

Some potential resolutions:

  • use a new host
  • have it detect the failure and give a more helpful error message (see the sketch at the end of this issue)
  • do nothing and the problem will naturally go away as engagement decreases

I've managed to get through the PTI step at least, and the inversion quality is quite incredible. 🤯

(Attached: the downloaded file is Google Drive's "Quota exceeded" HTML error page: "Sorry, you can't view or download this file at this time. Too many users have viewed or downloaded this file recently. Please try accessing the file again later. If the file you are trying to access is particularly large or is shared with many people, it may take up to 24 hours to be able to view or download the file.")
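
One possible shape for the detection idea above, offered as a sketch rather than as the repo's API: validate each downloaded file and fail loudly when Google Drive has returned its quota page instead of the model (paths and the size threshold are examples).

import os

def assert_valid_download(path, min_bytes=1_000_000):
    with open(path, 'rb') as f:
        head = f.read(512)
    if b'<!DOCTYPE html' in head or b'Quota exceeded' in head:
        raise RuntimeError(f'{path} is a Google Drive error page; retry later or mirror the file on another host')
    if os.path.getsize(path) < min_bytes:
        raise RuntimeError(f'{path} looks truncated ({os.path.getsize(path)} bytes)')

assert_valid_download('/content/PTI/pretrained_models/ffhq.pkl')
assert_valid_download('/content/PTI/pretrained_models/align.dat')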

run_pti: assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

Hi there 👋

Thanks a lot for the project. I am trying to use run_pti with an image, but I get this error:

(stylegan3) ➜  PTI git:(main) ✗ python scripts/run_pti.py
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
/home/zuppif/miniconda3/envs/stylegan3/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/zuppif/miniconda3/envs/stylegan3/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Loading model from: /home/zuppif/miniconda3/envs/stylegan3/lib/python3.9/site-packages/lpips/weights/v0.1/alex.pth
  0%|                                                                             | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/zuppif/Documents/DragGAN/PTI/scripts/run_pti.py", line 51, in <module>
    run_PTI(run_name='', use_wandb=False, use_multi_id_training=False)
  File "/home/zuppif/Documents/DragGAN/PTI/scripts/run_pti.py", line 45, in run_PTI
    coach.train()
  File "/home/zuppif/Documents/DragGAN/PTI/training/coaches/single_id_coach.py", line 39, in train
    w_pivot = self.calc_inversions(image, image_name)
  File "/home/zuppif/Documents/DragGAN/PTI/training/coaches/base_coach.py", line 93, in calc_inversions
    w = w_projector.project(self.G, id_image, device=torch.device(global_config.device), w_avg_samples=600,
  File "/home/zuppif/Documents/DragGAN/PTI/training/projectors/w_projector.py", line 41, in project
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
AssertionError

I first tried resizing the image to 1024x512 and then to 1024x1024, but the error persists.

Thanks a lot

Fra
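
For context, that assert expects a square RGB tensor of shape [G.img_channels, G.img_resolution, G.img_resolution], i.e. 3x1024x1024 for the FFHQ model, so a 1024x512 image, a grayscale image, or an RGBA image will all trip it. A hedged preprocessing sketch (for the FFHQ generator the image should also go through the repo's face-alignment step first):

from PIL import Image
import numpy as np
import torch

def to_target(path, G):
    img = Image.open(path).convert('RGB')                                   # force exactly 3 channels
    img = img.resize((G.img_resolution, G.img_resolution), Image.LANCZOS)   # square, matching G
    arr = np.array(img, dtype=np.uint8).transpose(2, 0, 1)                  # HWC -> CHW
    return torch.from_numpy(arr)                                             # dynamic range [0, 255]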

Why don't you use e4e in the first stage Inversion?

Great work!
I notice that there is a "first_inv_type" option in hyperparameters.py. You must have tried using e4e in the first stage. Can you tell us why you chose the original StyleGAN2 projector instead of e4e? Thanks!
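
For anyone who wants to try the e4e variant, the switch already exists; a minimal sketch of toggling it (the e4e checkpoint path in paths_config must also be set, as in the Colab's StyleCLIP section):

from configs import hyperparameters

hyperparameters.first_inv_type = 'w+'   # e4e backbone for the first inversion
# hyperparameters.first_inv_type = 'w'  # default: the StyleGAN2 W projector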

lpips seems broken

the latest version, 0.1.4, seems to break things with a missing layer argument:

Traceback (most recent call last):
File "scripts/notebook.py", line 63, in
model_id = run_PTI(use_wandb=False, use_multi_id_training=use_multi_id_training)
File "/home/jp/Documents/gitWorkspace/PTI/scripts/run_pti.py", line 40, in run_PTI
coach = SingleIDCoach(dataloader, use_wandb)
File "/home/jp/Documents/gitWorkspace/PTI/training/coaches/single_id_coach.py", line 13, in init
super().init(data_loader, use_wandb)
File "/home/jp/Documents/gitWorkspace/PTI/training/coaches/base_coach.py", line 36, in init
self.lpips_loss = LPIPS(net=hyperparameters.lpips_type, lpips_layers=hyperparameters.pt_lpips_layers).to(global_config.device).eval()
TypeError: init() got an unexpected keyword argument 'lpips_layers'

I'd submit a PR - but my branch has drifted considerably.
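
Note that upstream PTI constructs the loss with only net= (as in the earlier traceback above); lpips_layers/pt_lpips_layers come from a fork, and the lpips package's LPIPS class does not accept that keyword in the releases I know of. A minimal sketch of the usual fix, passing only arguments the installed lpips accepts:

from lpips import LPIPS
from configs import hyperparameters

lpips_loss = LPIPS(net=hyperparameters.lpips_type).eval()   # 'alex' or 'vgg'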

inversion on whole body images

I'm trying to perform inversion on whole-body images as opposed to faces. Looking at the inference notebook you shared, I'm guessing the preprocessing function, which receives its input from align_faces, will need a new counterpart (align_body, for example) to provide input for a body image.

Or would the best solution be to skip the preprocessing step altogether?

How to take the generated model .pt and mash it back into stylegan2-ada .pkl format for use in other apps [DRAFTED]

So I get the npz file (thanks for your help on the other ticket #24) and I see the new generator saved:
model_MWVZTEZFDDJB_1.pt

I did some inspecting and can see the new generator:
https://gist.github.com/johndpope/c5b77f8cc7d7d008be7f15079a9378bf

I want to spit out an updated ffhq .pkl file in the correct shape and format so I can run the new generator in different use cases with other repos.

  with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
    old_G = pickle.load(f)['G_ema'].cuda()  # this grabs the pickle for the ffhq file

  with open(f'{paths_config.checkpoints_dir}/model_{model_id}_{image_name}.pt', 'rb') as f_new:
    new_G = torch.load(f_new).cuda()  # and this grabs the updated model_MWVZTEZFDDJB_1.pt

UPDATE 1 - thus far I have this hack, which saves out a pkl.

UPDATE 2 -
I actually load the new file into stylegan2-ada-pytorch and run approach.py in conjunction with projected_w.npz,
but it works badly. I wonder if it's because this pickle would need a new discriminator too?

UPDATE 3 -
I think I know how to solve it - I need to load the final .pt that is spat out and do the hot wiring. Should be fine.

def export_updated_pickle(new_G,model_id):
  print("Exporting large updated pickle based off new generator and ffhq.pkl")
  with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
    d = pickle.load(f)
    old_G = d['G_ema'].cuda() ## tensor
    old_D = d['D'].eval().requires_grad_(False).cpu()

  tmp = {}
  tmp['G_ema'] = old_G.eval().requires_grad_(False).cpu()# copy.deepcopy(new_G).eval().requires_grad_(False).cpu()
  tmp['G'] = new_G.eval().requires_grad_(False).cpu() # copy.deepcopy(new_G).eval().requires_grad_(False).cpu()
  tmp['D'] = old_D
  tmp['training_set_kwargs'] = None
  tmp['augment_pipe'] = None


  with open(f'{paths_config.checkpoints_dir}/model_{model_id}.pkl', 'wb') as f:
      pickle.dump(tmp, f)

....
At the bottom of the notebook:
print(f'Displaying PTI inversion')
plot_image_from_w(w_pivot, new_G)
np.savez(f'projected_w.npz', w=w_pivot.cpu().detach().numpy())
export_updated_pickle(new_G,model_id)

(Attached results: the original image, the image_from_w reconstruction, and the edits 1_afro, 1_angry, 1_bobcut, 1_bowlcut, 1_mohawk, 1_surprised, 1_trump)

https://drive.google.com/drive/folders/1l6Xvs6EPVyyw0sFowIpN1pd1lJbm56hD?usp=sharing

I get the new pkl / npz files.

I cherry-pick these files into the original stylegan2-ada-pytorch repo together with
https://github.com/l4rz/stylegan2-clip-approach

I rename the pkl to ffhq-pti.pkl and run:
(torch) ➜ stylegan2-ada-pytorch git:(main) ✗ python approach.py --network ffhq-pti.pkl --w projected_w.npz --outdir ffhq-pti --num-steps 100 --text 'squint'
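
For what it's worth, a quick sanity-check sketch using stylegan2-ada-pytorch's own loader (run from that repo so dnnlib and legacy import). One thing worth double-checking in the export above: most entry points in that codebase (e.g. generate.py) read the 'G_ema' entry, while export_updated_pickle stores the tuned generator under 'G' and keeps the original under 'G_ema', which could by itself explain edits that look like they come from the untuned model.

import torch
import dnnlib
import legacy

with dnnlib.util.open_url('ffhq-pti.pkl') as f:
    data = legacy.load_network_pkl(f)

G = data['G_ema'].eval().to('cuda')                          # make sure the tuned weights actually live here
z = torch.randn([1, G.z_dim], device='cuda')
img = G(z, None, truncation_psi=0.7, noise_mode='const')     # NCHW in [-1, 1]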

checkpoint

Hi, are fine-tuned weights provided for testing?

Noob question - missing something

I got all the models and the environment is fine. The code runs and the config seems OK, but I don't get anywhere using the two scripts

scripts/run_pti.py

and the evaluation Python script.

(screenshot attached)

I don't see any output.
I checked the output directories, but no files are spat out. I have an aligned folder with a sequence of aligned face images…

(screenshot attached)

how to get more editing directions

Hi, thanks for your great work.
I notice you said that the editing directions you uploaded were trained on the pretrained StyleGAN. If I want more editing directions, what should I do?
Thank you.

How did you find the directions?

Hello,
I used StyleGAN2 to find some directions, but the results are not good.
Did you use StyleGAN2-ada to find directions through InterFaceGAN?
Thank you!

generate random faces

Amazing work.
I would like to know if it is possible to generate random faces from a seed, like NVIDIA's StyleGAN.
I tried to do this, but the generated images are full of artifacts:
new_G.synthesis(torch.from_numpy(np.random.rand(1,18,512)).float().to("cuda"), noise_mode='const')
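
The artifacts are expected with that call: synthesis expects W codes produced by the mapping network, not uniform random numbers. A sketch of seed-based sampling in the style of the NVIDIA repo, which still works on the PTI-tuned new_G since its interface is unchanged:

import numpy as np
import torch

@torch.no_grad()
def generate_from_seed(G, seed, truncation_psi=0.7, device='cuda'):
    # Sample a Gaussian z from the seed, map it to W (with truncation), then synthesize.
    z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).float().to(device)
    w = G.mapping(z, None, truncation_psi=truncation_psi)    # [1, num_ws, 512]
    return G.synthesis(w, noise_mode='const')                # NCHW in [-1, 1]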

run align_data error


RuntimeError Traceback (most recent call last)
in
----> 1 predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)

RuntimeError: Error deserializing object of type int64
while deserializing a floating point number.
while deserializing a dlib::matrix
while deserializing object of type std::vector
while deserializing object of type std::vector
while deserializing object of type std::vector

Run PTI on 256*256 images

Dear Daniel,
Thank you very much for the great work.
I am trying to apply PTI to my trained model, which has 256x256 resolution.
Could you give me any tips on which parts of the code I should change?
Many thanks,
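
A hedged sketch of the usual adjustments for a non-1024 model (assumptions rather than a documented recipe): point paths_config at your own .pkl and make sure the preprocessed images match G.img_resolution, since the first-stage projector asserts target.shape == (G.img_channels, G.img_resolution, G.img_resolution).

import glob
from PIL import Image
from configs import paths_config

paths_config.stylegan2_ada_ffhq = '/path/to/my_256_model.pkl'   # the config name is reused for the custom generator

for p in glob.glob('/path/to/processed_images/*.png'):          # adjust the glob to your image files
    img = Image.open(p).convert('RGB').resize((256, 256), Image.LANCZOS)
    img.save(p)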

use_multi_id_training=True error

I have two images aligned in /content/PTI/image_processed.
I want to train a model for multiple images using use_multi_id_training = True,
but it gives this error:

100%|██████████| 2/2 [00:01<00:00,  1.39it/s]
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%
233M/233M [00:04<00:00, 50.4MB/s]

Loading model from: /usr/local/lib/python3.7/dist-packages/lpips/weights/v0.1/alex.pth
Downloading https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt ... done
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py:1051: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
  return forward_call(*input, **kwargs)
100%|██████████| 450/450 [01:21<00:00,  5.52it/s]
100%|██████████| 450/450 [01:21<00:00,  5.52it/s]
  0%|          | 0/350 [00:00<?, ?it/s]
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-33-b69631c99070> in <module>()
     34 ## In order to run PTI and use StyleGAN2-ada, the cwd should the parent of 'torch_utils' and 'dnnlib'
     35 os.chdir('/content/PTI')
---> 36 model_id = run_PTI(use_wandb=False, use_multi_id_training=True)

1 frames
/content/PTI/training/coaches/multi_id_coach.py in train(self)
     58 
     59                 self.optimizer.zero_grad()
---> 60                 loss.backward()
     61                 self.optimizer.step()
     62 

AttributeError: 'tuple' object has no attribute 'backward'
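
The "'tuple' object has no attribute 'backward'" symptom means the variable being backpropagated is the whole tuple returned by calc_loss (base_coach's calc_loss returns the total loss together with the individual terms) rather than its first element. A sketch of the usual fix inside MultiIDCoach.train(), with argument names following single_id_coach and therefore possibly slightly off:

loss_tuple = self.calc_loss(generated_images, real_images, image_name,
                            self.G, use_ball_holder, w_pivot)
loss = loss_tuple[0] if isinstance(loss_tuple, tuple) else loss_tuple  # keep only the total loss

self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()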

Error in Colab

Hi,

I'm running the colab as is and I always get this error:

pre_process_images(f'/content/PTI/{image_dir_name}_original')

RuntimeError                              Traceback (most recent call last)
<ipython-input-18-4cd7c7dc57be> in <module>()
----> 1 pre_process_images(f'/content/PTI/{image_dir_name}_original')

/content/PTI/utils/align_data.py in pre_process_images(raw_images_path)
     12 
     13     IMAGE_SIZE = 1024
---> 14     predictor = dlib.shape_predictor(paths_config.dlib)
     15     os.chdir(raw_images_path)
     16     images_names = glob.glob(f'*')
RuntimeError: Error deserializing object of type int

I did not make any changes; perhaps there is something extra I need to configure?
Just letting you know.

Thanks

Edit direction

Hi, Thanks for your great work!

Here I have one question: when you fine-tune the parameters of the pre-trained StyleGAN model and use InterFaceGAN to edit the image, do we need to find a new semantic direction based on the fine-tuned model? Intuitively, changing the model parameters will change the semantic direction found by InterFaceGAN.

In general, after you fine-tune the pre-trained StyleGAN model, do you retrain the semantic directions found by InterFaceGAN?

Thanks

How many images have been tried at most to train together?

Thanks for your contribution. How many images have you tried training together at most? Ideally, each image corresponds to one pivotally fine-tuned StyleGAN. I have more than 20,000 images with different IDs, and the image quality varies. What happens if I train them all together to get a single StyleGAN?

opt multi image one time?

Thanks for your great work. I find that in multi_id_coach.py you optimize one image per step. Why not optimize several images in a batch per step? Another question: if I have several images of one person in different poses, how do I get the best inversion result, one checkpoint per image or one checkpoint for all images? Thank you in advance.

edit an image with text?

The code in the Colab only allows using a mapper.
What I want is to write a text prompt to modify the image, as in StyleCLIP.

errors in colab

Thanks for sharing your great work! When I run the code in Colab, some errors come up in the part that downloads the pre-trained models. It shows "NameError: name 'downloader' is not defined". Can you give me any suggestions?

[question] What kind of pre-processing would a model that doesn't generate faces require?

Hi there!

I've been trying to invert some pictures using pre-trained models that don't generate faces. For obvious reasons, I've been skipping most of the pre-processing, such as dlib face alignment; resizing is the only step I kept.

However, both the final embedding and the fine-tuned model are of poor quality, either distorted or blurred. The repository seems specifically designed for faces, so I was wondering if you could share any best practices or advice on pre-processing pictures that aren't necessarily faces.

Thanks for the good work!

Regards

The input parameters of the Generator

Hi, thanks for your great work!

I am curious about the input parameters such as 'noise_mode' and 'force_fp32' in self.G.synthesis(w, noise_mode='const', force_fp32=True).
I would also like to know how to return the features of each layer of self.G.

Thanks
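
Briefly, in stylegan2-ada-pytorch noise_mode chooses how the per-layer noise inputs are applied ('random', 'const', or 'none'), and force_fp32=True runs every synthesis layer in float32 instead of the mixed fp16/fp32 path used at the highest resolutions. For per-layer features there is no official hook in the repo, but a forward-hook sketch over the synthesis blocks (named b4, b8, ..., up to the output resolution) is one option; note each block returns a (features, rgb) tuple:

import torch

def synthesis_features(G, w):
    feats, hooks = {}, []
    for name, block in G.synthesis.named_children():          # b4, b8, ..., b1024
        hooks.append(block.register_forward_hook(
            lambda m, inp, out, name=name: feats.__setitem__(name, out)))
    with torch.no_grad():
        img = G.synthesis(w, noise_mode='const', force_fp32=True)
    for h in hooks:
        h.remove()
    return img, feats                                          # feats[name] is the (x, img) output of each block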

GPU error

Hello, I was trying to implement PTI inside eg3d's training/loss.py (NVlabs/eg3d), but I ran into problems when calling the projectors adapted from PTI/training/projectors (danielroich/PTI).

Here is how I call the w/w_plus projector (see the function pti_projector):

eg3d code
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Loss functions."""

import numpy as np
import torch
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import upfirdn2d
from training.dual_discriminator import filtered_resizing

#----------------------------------------------------------------------------

class Loss:
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): # to be overridden by subclass
        raise NotImplementedError()

#----------------------------------------------------------------------------

# ---------------------- project image into latent space --------------------- #
# modified code from https://github.com/oneThousand1000/EG3D-projector/tree/master/eg3d/projector
from training.projector import w_plus_projector, w_projector
from torchvision import transforms
import copy

def pti_projector(cur_G, cur_c, cur_image, device, latent_type='w_plus'):
    # # put image back to cpu for transforms
    # image = cur_image.cpu()
    # # normalize image
    # normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
    #                                     std=[0.5, 0.5, 0.5])
    # id_image = normalize(image)
    # id_image = torch.squeeze((id_image + 1) / 2, 0)
    
    id_image = cur_image.to(device)
    # c = c.to(device)
    c = torch.reshape(cur_c, (1, 25)).to(device) # 25 is the camera pose dimension 16 + 9
    G = cur_G

    if latent_type == 'w_plus':
        w = w_plus_projector.project(G, c, id_image, device=device, w_avg_samples=600)
    else:
        w = w_projector.project(G, c, id_image, device=device, w_avg_samples=600)
    print('w shape: ', w.shape)
    return w
    

class StyleGAN2Loss(Loss):
    def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0, r1_gamma_init=0, r1_gamma_fade_kimg=0, neural_rendering_resolution_initial=64, neural_rendering_resolution_final=None, neural_rendering_resolution_fade_kimg=0, gpc_reg_fade_kimg=1000, gpc_reg_prob=None, dual_discrimination=False, filter_mode='antialiased'):
        super().__init__()
        self.device             = device
        self.G                  = G
        self.D                  = D
        self.augment_pipe       = augment_pipe
        self.r1_gamma           = r1_gamma
        self.style_mixing_prob  = style_mixing_prob
        self.pl_weight          = pl_weight
        self.pl_batch_shrink    = pl_batch_shrink
        self.pl_decay           = pl_decay
        self.pl_no_weight_grad  = pl_no_weight_grad
        self.pl_mean            = torch.zeros([], device=device)
        self.blur_init_sigma    = blur_init_sigma
        self.blur_fade_kimg     = blur_fade_kimg
        self.r1_gamma_init      = r1_gamma_init
        self.r1_gamma_fade_kimg = r1_gamma_fade_kimg
        self.neural_rendering_resolution_initial = neural_rendering_resolution_initial
        self.neural_rendering_resolution_final = neural_rendering_resolution_final
        self.neural_rendering_resolution_fade_kimg = neural_rendering_resolution_fade_kimg
        self.gpc_reg_fade_kimg = gpc_reg_fade_kimg
        self.gpc_reg_prob = gpc_reg_prob
        self.dual_discrimination = dual_discrimination
        self.filter_mode = filter_mode
        self.resample_filter = upfirdn2d.setup_filter([1,3,3,1], device=device)
        self.blur_raw_target = True
        assert self.gpc_reg_prob is None or (0 <= self.gpc_reg_prob <= 1)

    def run_G(self, z, c, swapping_prob, neural_rendering_resolution, update_emas=False):
        if swapping_prob is not None:
            c_swapped = torch.roll(c.clone(), 1, 0)
            c_gen_conditioning = torch.where(torch.rand((c.shape[0], 1), device=c.device) < swapping_prob, c_swapped, c)
        else:
            c_gen_conditioning = torch.zeros_like(c)

        ws = self.G.mapping(z, c_gen_conditioning, update_emas=update_emas)
        if self.style_mixing_prob > 0:
            with torch.autograd.profiler.record_function('style_mixing'):
                cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
        gen_output = self.G.synthesis(ws, c, neural_rendering_resolution=neural_rendering_resolution, update_emas=update_emas)
        return gen_output, ws

    def run_D(self, img, c, blur_sigma=0, blur_sigma_raw=0, update_emas=False):
        blur_size = np.floor(blur_sigma * 3)
        if blur_size > 0:
            with torch.autograd.profiler.record_function('blur'):
                f = torch.arange(-blur_size, blur_size + 1, device=img['image'].device).div(blur_sigma).square().neg().exp2()
                img['image'] = upfirdn2d.filter2d(img['image'], f / f.sum())

        if self.augment_pipe is not None:
            augmented_pair = self.augment_pipe(torch.cat([img['image'],
                                                    torch.nn.functional.interpolate(img['image_raw'], size=img['image'].shape[2:], mode='bilinear', antialias=True)],
                                                    dim=1))
            img['image'] = augmented_pair[:, :img['image'].shape[1]]
            img['image_raw'] = torch.nn.functional.interpolate(augmented_pair[:, img['image'].shape[1]:], size=img['image_raw'].shape[2:], mode='bilinear', antialias=True)

        logits = self.D(img, c, update_emas=update_emas)
        return logits

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        if self.G.rendering_kwargs.get('density_reg', 0) == 0:
            phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase)
        if self.r1_gamma == 0:
            phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase)
        blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0
        r1_gamma = self.r1_gamma

        alpha = min(cur_nimg / (self.gpc_reg_fade_kimg * 1e3), 1) if self.gpc_reg_fade_kimg > 0 else 1
        swapping_prob = (1 - alpha) * 1 + alpha * self.gpc_reg_prob if self.gpc_reg_prob is not None else None

        if self.neural_rendering_resolution_final is not None:
            alpha = min(cur_nimg / (self.neural_rendering_resolution_fade_kimg * 1e3), 1)
            neural_rendering_resolution = int(np.rint(self.neural_rendering_resolution_initial * (1 - alpha) + self.neural_rendering_resolution_final * alpha))
        else:
            neural_rendering_resolution = self.neural_rendering_resolution_initial

        real_img_raw = filtered_resizing(real_img, size=neural_rendering_resolution, f=self.resample_filter, filter_mode=self.filter_mode)

        if self.blur_raw_target:
            blur_size = np.floor(blur_sigma * 3)
            if blur_size > 0:
                f = torch.arange(-blur_size, blur_size + 1, device=real_img_raw.device).div(blur_sigma).square().neg().exp2()
                real_img_raw = upfirdn2d.filter2d(real_img_raw, f / f.sum())

        real_img = {'image': real_img, 'image_raw': real_img_raw}

        # run PTI to get w/w_plus latent codes for real images
        # print(real_img.shape, real_c.shape, gen_z.shape, gen_c.shape)
        # torch.Size([8, 3, 512, 512]) torch.Size([8, 25]) torch.Size([8, 512]) torch.Size([8, 25])
        # convert gen_z to real_z

        batch_size = real_img['image'].shape[0]
        real_z = []
        for i in range(batch_size):
            cur_img = real_img['image'][i]
            cur_c = real_c[i]
            cur_z = pti_projector(self.G, cur_c, cur_img, device=self.device)
            real_z.append(cur_z)
        real_z = torch.stack(real_z)
        print('real_z', real_z.shape)
        

        # Gmain: Maximize logits for generated images.
        if phase in ['Gmain', 'Gboth']:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, swapping_prob=swapping_prob, neural_rendering_resolution=neural_rendering_resolution)
                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Gmain = torch.nn.functional.softplus(-gen_logits)
                training_stats.report('Loss/G/loss', loss_Gmain)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.mean().mul(gain).backward()

        # Density Regularization
        if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'l1':
            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)

            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
            initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
            perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * self.G.rendering_kwargs['density_reg_p_dist']
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1]//2]
            sigma_perturbed = sigma[:, sigma.shape[1]//2:]

            TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
            TVloss.mul(gain).backward()

        # Alternative density regularization
        if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'monotonic-detach':
            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)

            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)

            initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1 # Front

            perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1/256) * self.G.rendering_kwargs['box_warp'] # Behind
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1]//2]
            sigma_perturbed = sigma[:, sigma.shape[1]//2:]

            monotonic_loss = torch.relu(sigma_initial.detach() - sigma_perturbed).mean() * 10
            monotonic_loss.mul(gain).backward()


            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)

            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
            initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
            perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1/256) * self.G.rendering_kwargs['box_warp']
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1]//2]
            sigma_perturbed = sigma[:, sigma.shape[1]//2:]

            TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
            TVloss.mul(gain).backward()

        # Alternative density regularization
        if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'monotonic-fixed':
            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)

            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)

            initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1 # Front

            perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1/256) * self.G.rendering_kwargs['box_warp'] # Behind
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1]//2]
            sigma_perturbed = sigma[:, sigma.shape[1]//2:]

            monotonic_loss = torch.relu(sigma_initial - sigma_perturbed).mean() * 10
            monotonic_loss.mul(gain).backward()


            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)

            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
            initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
            perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1/256) * self.G.rendering_kwargs['box_warp']
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1]//2]
            sigma_perturbed = sigma[:, sigma.shape[1]//2:]

            TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
            TVloss.mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if phase in ['Dmain', 'Dboth']:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, swapping_prob=swapping_prob, neural_rendering_resolution=neural_rendering_resolution, update_emas=True)
                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(gen_logits)
            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if phase in ['Dmain', 'Dreg', 'Dboth']:
            name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                real_img_tmp_image = real_img['image'].detach().requires_grad_(phase in ['Dreg', 'Dboth'])
                real_img_tmp_image_raw = real_img['image_raw'].detach().requires_grad_(phase in ['Dreg', 'Dboth'])
                real_img_tmp = {'image': real_img_tmp_image, 'image_raw': real_img_tmp_image_raw}

                real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())

                loss_Dreal = 0
                if phase in ['Dmain', 'Dboth']:
                    loss_Dreal = torch.nn.functional.softplus(-real_logits)
                    training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if phase in ['Dreg', 'Dboth']:
                    if self.dual_discrimination:
                        with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                            r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp['image'], real_img_tmp['image_raw']], create_graph=True, only_inputs=True)
                            r1_grads_image = r1_grads[0]
                            r1_grads_image_raw = r1_grads[1]
                        r1_penalty = r1_grads_image.square().sum([1,2,3]) + r1_grads_image_raw.square().sum([1,2,3])
                    else: # single discrimination
                        with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                            r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp['image']], create_graph=True, only_inputs=True)
                            r1_grads_image = r1_grads[0]
                        r1_penalty = r1_grads_image.square().sum([1,2,3])
                    loss_Dr1 = r1_penalty * (r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                (loss_Dreal + loss_Dr1).mean().mul(gain).backward()

#----------------------------------------------------------------------------

and here is the modified projector script:

w_plus_projector.py
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Project given image to the latent space of pretrained network pickle."""

import copy
import os
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import dnnlib
import PIL
from camera_utils import LookAtPoseSampler

def project(
        G,
        c,
        # outdir,
        target: torch.Tensor,  # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
        *,
        num_steps=1000,
        w_avg_samples=10000,
        initial_learning_rate=0.01,
        initial_noise_factor=0.05,
        lr_rampdown_length=0.25,
        lr_rampup_length=0.05,
        noise_ramp_length=0.75,
        regularize_noise_weight=1e5,
        verbose=False,
        device: torch.device,
        initial_w=None,
        image_log_step=100,
        # w_name: str
):
    # os.makedirs(f'{outdir}/{w_name}_w_plus', exist_ok=True)
    # outdir = f'{outdir}/{w_name}_w_plus'
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float() # type: ignore

    # Compute w stats.
    w_avg_path = './w_avg.npy'
    w_std_path = './w_std.npy'
    if (not os.path.exists(w_avg_path)) or (not os.path.exists(w_std_path)):
        print(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
        z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
        # c_samples = c.repeat(w_avg_samples, 1)

        # use avg look at point

        camera_lookat_point = torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device)
        cam2world_pose = LookAtPoseSampler.sample(3.14 / 2, 3.14 / 2, camera_lookat_point,
                                                  radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        focal_length = 4.2647  # FFHQ's FOV
        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
        c_samples = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
        c_samples = c_samples.repeat(w_avg_samples, 1)

        w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples)  # [N, L, C]
        w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
        w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
        # print('save w_avg  to ./w_avg.npy')
        # np.save('./w_avg.npy',w_avg)
        w_avg_tensor = torch.from_numpy(w_avg).cuda()
        w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

        # np.save(w_avg_path, w_avg)
        # np.save(w_std_path, w_std)
    else:
        # w_avg = np.load(w_avg_path)
        # w_std = np.load(w_std_path)
        raise Exception(' ')

    # z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    # c_samples = c.repeat(w_avg_samples, 1)
    # w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples)  # [N, L, C]
    # w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
    # w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
    # w_avg_tensor = torch.from_numpy(w_avg).cuda()
    # w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    start_w = initial_w if initial_w is not None else w_avg

    # Setup noise inputs.
    noise_bufs = {name: buf for (name, buf) in G.backbone.synthesis.named_buffers() if 'noise_const' in name}

    # Load VGG16 feature detector.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    # url = './networks/vgg16.pt'
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f, map_location=device).eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    start_w = np.repeat(start_w, G.backbone.mapping.num_ws, axis=1)
    w_opt = torch.tensor(start_w, dtype=torch.float32, device=device,
                         requires_grad=True)  # pylint: disable=not-callable

    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999),
                                 lr=0.1)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in tqdm(range(num_steps)):

        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise)
        synth_images = G.synthesis(ws,c, noise_mode='const')['image']

        # if step % image_log_step == 0:
        #     with torch.no_grad():
        #         vis_img = (synth_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

        #         PIL.Image.fromarray(vis_img[0].cpu().numpy(), 'RGB').save(f'{outdir}/{step}.png')

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255 / 2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight

        # if step % 10 == 0:
        #     with torch.no_grad():
        #         print({f'step {step}, first projection _{w_name}': loss.detach().cpu()})

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    del G
    return w_opt

I got errors as shown below:

Computing W midpoint and stddev using 600 samples...
  0%|          | 0/1000 [00:00<?, ?it/s]/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/nn/modules/module.py:1488: UserWarning: operator() profile_node %106 : int = prim::profile_ivalue(%104)
 does not have profile information (Triggered internally at /opt/conda/conda-bld/pytorch_1674202356920/work/torch/csrc/jit/codegen/cuda/graph_fuser.cpp:105.)
  return forward_call(*args, **kwargs)
  0%|          | 0/1000 [00:07<?, ?it/s]
Traceback (most recent call last):
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 396, in <module>
    main() # pylint: disable=no-value-for-parameter
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 1128, in __call__
    return self.main(*args, **kwargs)
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 1053, in main
    rv = self.invoke(ctx)
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 1395, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 754, in invoke
    return __callback(*args, **kwargs)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 391, in main
    launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 101, in launch_training
    subprocess_fn(rank=0, c=c, temp_dir=temp_dir)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 52, in subprocess_fn
    training_loop.training_loop(rank=rank, **c)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/training_loop.py", line 286, in training_loop
    loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, gain=phase.interval, cur_nimg=cur_nimg)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/loss.py", line 156, in accumulate_gradients
    cur_z = pti_projector(self.G, cur_c, cur_img, device=self.device)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/loss.py", line 49, in pti_projector
    w = w_plus_projector.project(G, c, id_image, device=device, w_avg_samples=600)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/projector/w_plus_projector.py", line 171, in project
    loss.backward()
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/autograd/function.py", line 275, in apply
    return user_fn(self, *args)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/torch_utils/ops/conv2d_gradfix.py", line 146, in backward
    grad_weight = Conv2dGradWeight.apply(grad_output, input, weight)
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/autograd/function.py", line 508, in apply
    return super().apply(*args, **kwargs)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/torch_utils/ops/conv2d_gradfix.py", line 173, in forward
    return torch.ops.aten.convolution_backward(grad_output=grad_output, input=input, weight=weight, bias_sizes=None, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[False, True, False])[1]
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/_ops.py", line 499, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument weight in method wrapper_CUDA__convolution_backward)

It seems the error happens at loss.backward(). I checked most of the variables/losses/models to make sure they are on cuda:0, but I still had no luck solving this. Do you know how to make the loss backpropagate properly?
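
Not a fix, but a small diagnostic sketch that can narrow this class of error down: before the optimization loop in project(), list every parameter and buffer of the (deep-copied) G, the VGG16 feature extractor, and the optimized tensors that is not on the expected device, since the traceback points at a convolution weight that is still on the CPU during the backward pass.

def report_off_device(module, tag, device='cuda:0'):
    # Print any parameter or buffer that is not on the expected device.
    for name, t in list(module.named_parameters()) + list(module.named_buffers()):
        if str(t.device) != device:
            print(f'[{tag}] {name} is on {t.device}')

# e.g. inside project(), right before the step loop:
report_off_device(G, 'G')
report_off_device(vgg16, 'vgg16')
print('target:', target_images.device, 'w_opt:', w_opt.device)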

every image need 127M checkpoint

Well done, your work is amazing; the performance is the best as far as I know. But I found that inference is slow, and every image requires saving a 127 MB .pt file, which is not practical. Any advice is welcome.
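
Not a feature of the repo, but one sketch of shrinking per-image storage under the assumption that most of the ~127 MB is float32 generator weights: keep a single shared copy of the original generator and store only an fp16 delta of the tuned weights, reapplying it at load time (this roughly halves the size on its own, and the deltas also compress well).

import copy
import torch

def save_finetune_delta(old_G, new_G, path):
    # Store only the (fp16) difference between the tuned and original weights.
    old_sd = old_G.state_dict()
    delta = {k: (v.detach().cpu() - old_sd[k].detach().cpu()).half()
             for k, v in new_G.state_dict().items()
             if torch.is_floating_point(v)}
    torch.save(delta, path)

def load_finetuned(old_G, path, device='cuda'):
    # Rebuild the tuned generator from the shared original G plus a saved delta.
    G = copy.deepcopy(old_G)
    sd = G.state_dict()
    for k, d in torch.load(path, map_location='cpu').items():
        sd[k] = sd[k] + d.float().to(sd[k].device)
    G.load_state_dict(sd)
    return G.to(device)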

Unexpected size of W features in generated_images = self.G.synthesis(w, noise_mode='const', force_fp32=True)

I am using a self-trained model trained with the StyleGAN2-ada PyTorch repository.
When using use_multi_id_training=True, I get a size-mismatch error in the forward call of G.synthesis.

The full trace is shown:

Traceback (most recent call last):
File "test.py", line 6, in
run_PTI(use_multi_id_training=True)
File "/home/usman/Documents/Work/PTI/scripts/run_pti.py", line 44, in run_PTI
coach.train()
File "/home/usman/Documents/Work/PTI/training/coaches/multi_id_coach.py", line 58, in train
generated_images = self.forward(w_pivot)
File "/home/usman/Documents/Work/PTI/training/coaches/base_coach.py", line 130, in forward
generated_images = self.G.synthesis(w, noise_mode='const', force_fp32=True)
File "/home/usman/anaconda3/envs/eg3d/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "", line 460, in forward
File "/home/usman/Documents/Work/PTI/torch_utils/misc.py", line 93, in assert_shape
raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
AssertionError: Wrong size for dimension 1: got 18, expected 16
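
For context, the 18 vs. 16 mismatch is the number of per-layer W slots: 18 corresponds to a 1024px generator, while a 512px model expects 16 (G.mapping.num_ws). The proper fix is to produce the pivots with this generator's own num_ws (the W projector, for instance, repeats w_avg num_ws times), but a quick hedged check is to trim or pad an existing W+ code, accepting that trimming simply drops the extra layers' codes:

import torch

def match_num_ws(w, G):
    num_ws = G.mapping.num_ws                   # number of slots the synthesis network expects
    if w.shape[1] > num_ws:
        return w[:, :num_ws, :]                 # drop the surplus slots
    if w.shape[1] < num_ws:
        pad = w[:, -1:, :].repeat(1, num_ws - w.shape[1], 1)
        return torch.cat([w, pad], dim=1)       # repeat the last slot
    return w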
