faceswapper's Introduction

FaceSwapper - Official PyTorch Implementation

FaceSwapper: Learning Disentangled Representation for One-shot Progressive Face Swapping
Qi Li, Weining Wang, Chengzhong Xu, Zhenan Sun
arXiv preprint, 2022.

Paper: https://arxiv.org/abs/2203.12985/

Although face swapping has attracted much attention in recent years, it remains a challenging problem. The existing methods leverage a large number of data samples to explore the intrinsic properties of face swapping without taking into account the semantic information of face images. Moreover, the representation of the identity information tends to be fixed, leading to suboptimal face swapping. In this paper, we present a simple yet efficient method named FaceSwapper, for one-shot face swapping based on Generative Adversarial Networks. Our method consists of a disentangled representation module and a semantic-guided fusion module. The disentangled representation module is composed of an attribute encoder and an identity encoder, which aims to achieve the disentanglement of the identity and the attribute information. The identity encoder is more flexible and the attribute encoder contains more details of the attributes than its competitors. Benefiting from the disentangled representation, FaceSwapper can swap face images progressively. In addition, semantic information is introduced into the semantic-guided fusion module to control the swapped area and model the pose and expression more accurately. The experimental results show that our method achieves state-of-the-art results on benchmark datasets with fewer training samples.
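
To make the description above concrete, here is a minimal PyTorch sketch of the two-branch idea: an identity encoder compressing the source face into a compact code, and an attribute encoder keeping spatial feature maps that preserve pose and expression detail. Every module name, layer choice, and shape here is an illustrative assumption for exposition, not the repository's actual model.py.

import torch
import torch.nn as nn

class IdentityEncoder(nn.Module):
    # Compresses a face into a compact identity code (illustrative).
    def __init__(self, id_dim=512):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1))
        self.fc = nn.Linear(128, id_dim)

    def forward(self, x):
        return self.fc(self.backbone(x).flatten(1))

class AttributeEncoder(nn.Module):
    # Keeps spatial feature maps so attribute details (pose, expression) survive.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.ReLU())

    def forward(self, x):
        return self.net(x)

# A swapped face fuses the source identity code with the target attribute
# maps; the semantic-guided fusion module itself is omitted in this sketch.
id_code = IdentityEncoder()(torch.randn(1, 3, 256, 256))    # from the source face
att_maps = AttributeEncoder()(torch.randn(1, 3, 256, 256))  # from the target face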

Environment

Clone this repository:

git clone https://github.com/liqi-casia/FaceSwapper.git
cd FaceSwapper/

Install the dependencies:

conda create -n faceswapper python=3.6.7
conda activate faceswapper
conda install -y pytorch=1.4.0 torchvision=0.5.0 cudatoolkit=10.0 -c pytorch
conda install x264=='1!152.20180717' ffmpeg=4.0.2 -c conda-forge
pip install opencv-python==4.1.2.30 ffmpeg-python==0.2.0 scikit-image==0.16.2
pip install pillow==7.0.0 scipy==1.2.1 tqdm==4.43.0 munch==2.5.0
conda install -y -c anaconda pyyaml
pip install tensorboard tensorboardX
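
As a quick sanity check that the pinned versions installed correctly and CUDA is visible (this one-liner only uses packages installed above):

python -c "import torch, torchvision, cv2; print(torch.__version__, torchvision.__version__, cv2.__version__, torch.cuda.is_available())"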

Datasets and pre-trained checkpoints

We provide a link to download the datasets used in FaceSwapper and the corresponding pre-trained checkpoints. The datasets and checkpoints should be moved to the data and pretrained_checkpoints directories, respectively.

After storing all the files, the expected directory structures of ./data and ./pretrained_checkpoints are as follows.

./data
├── CelebA Dataset
│   ├── CelebA images
│   ├── CelebA landmark images
│   └── CelebA mask images
└── FF++ Dataset
    ├── ff++ images
    ├── ff++ landmark images
    ├── ff++ mask images
    └── ff++ parsing images

./pretrained_checkpoints
├── model_ir_se50.pth
├── wing.ckpt
└── faceswapper.ckpt

Generating swapped images

After downloading the pre-trained checkpoints, you can synthesize swapped images. The following commands will save generated images to the expr/results directory.

FaceForensics++ Dataset. To generate swapped images, specify the testing parameters in param.yaml (especially the parameters under #directory for testing; a hypothetical sketch of these entries appears after the Post Process paragraph below). Then run the following command:

python main.py 

There are three subfolders in expr/results/ff++/, named swapped_result_single, swapped_result_afterps and swapped_result_all. Each image is named source_FS_target.png, where the source image provides the identity information and the target image provides the attribute information (a small parsing example follows the list below).

--swapped_result_single: the swapped images.

--swapped_result_afterps: the swapped images after post-processing.

--swapped_result_all: concatenation of the source images, the target images, the swapped images, and the swapped images after post-processing.
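
Because every output follows the source_FS_target.png pattern, the pair of names can be recovered programmatically; a minimal example (the filename below is hypothetical):

import os

fname = 'id001_FS_id002.png'  # hypothetical output filename
source, target = os.path.splitext(fname)[0].split('_FS_')
print(source, target)         # -> id001 id002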

Other Datasets. First, automatically crop and align face images from other datasets so that the proportion of the face within the whole image is similar to that of the CelebA and FaceForensics++ datasets. Then, define the face swapping list similar to face_swap_list.txt (source_image_name target_image_name). The rest of the testing procedure is the same as for the FaceForensics++ dataset.
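
The repository does not ship an alignment script, so here is a hedged sketch of one common recipe: a 5-point similarity-transform alignment using scikit-image and OpenCV, both already in the dependency list. The landmark detector is left to you, and the TEMPLATE coordinates are an illustrative guess, not the authors' actual crop; tune them until the face proportion matches the CelebA/FF++ crops.

import numpy as np
import cv2
from skimage import transform as trans

# Illustrative 5-point template (eyes, nose tip, mouth corners) for a 256x256 crop.
TEMPLATE = np.float32([[89, 110], [167, 110], [128, 152], [98, 196], [158, 196]])

def align_face(img, landmarks, size=256):
    # Similarity-warp `img` (H x W x 3 array) so its 5 landmarks match TEMPLATE.
    tform = trans.SimilarityTransform()
    tform.estimate(np.float32(landmarks), TEMPLATE)  # landmarks -> template
    return cv2.warpAffine(img, tform.params[:2], (size, size))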

Post Process. If occlusion exists in the source image (e.g., hair, hat), we simply preserve the forehead and hair of the target image in the swapped image. Otherwise, we simply preserve the hair of the target image. Set post_process: True in param.yaml if you want post-processing.
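
For orientation, the testing-related entries in param.yaml might look like the sketch below. Apart from post_process (documented above), the key names are assumptions; consult the shipped param.yaml for the authoritative names.

# Hypothetical sketch of the '#directory for testing' block -- key names are assumed.
test_img_dir: data/FF++_Dataset/ff++     # assumed: directory of aligned test images
face_swap_list: face_swap_list.txt       # one 'source_image_name target_image_name' per line
result_dir: expr/results                 # swapped images are saved here (see above)
post_process: True                       # enable the post process described above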

Training networks

To train FaceSwapper from scratch, set the training parameters in param.yaml and run the following command. Generated images and network checkpoints will be stored in the expr/samples and expr/checkpoints directories, respectively. Training usually takes several days on a single Tesla V100 GPU, depending on the total number of training iterations.

python main.py 
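
Since tensorboard and tensorboardX are among the dependencies above, training can presumably be monitored with TensorBoard; the log directory below is an assumption, so point --logdir at wherever the solver actually writes event files:

tensorboard --logdir expr/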

License

The source code, pre-trained models, and dataset are available under the Creative Commons BY-NC 4.0 license. You can use, copy, transform, and build upon the material for non-commercial purposes as long as you give appropriate credit by citing our paper and indicate if changes were made. For technical, business, and other inquiries, please contact [email protected].

Citation

If you find this work useful for your research, please cite our paper:

@misc{li2022learning,
      title={Learning Disentangled Representation for One-shot Progressive Face Swapping}, 
      author={Qi Li and Weining Wang and Chengzhong Xu and Zhenan Sun},
      year={2022},
      eprint={2203.12985},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgements

The code is written based on the following projects; we would like to thank the authors for their contributions.

faceswapper's Issues

pre-trained models

Hello,

Can you share the pre-trained models on Google Drive? It seems impossible to download them from Baidu.

Thanks!

Problem about inferencing on custom data

Hi there, thanks for your great work!
I wrote a simple demo for running inference on custom data without pre-generated landmark or mask images; the only difference is that I implemented a function named get_mark_parsing to generate the landmark and parsing images. However, I get worse results than with the original inference pipeline. If you have time, please give me some suggestions about where I am going wrong. Thanks.
The full code is as follows:

from core.model import build_model
from core import utils
from core.checkpoint import CheckpointIO
from torchvision import transforms
from PIL import Image
import torchvision.utils as vutils
import numpy as np


# Face-parsing network (BiSeNet) used to generate mask and parsing images on the fly.
from parse import BiSeNet
import torch
n_classes = 19
net = BiSeNet(n_classes=n_classes)
net.cuda()
net.load_state_dict(torch.load('./pretrained_checkpoints/79999_iter.pth'))
net.eval()


# Build the FaceSwapper networks and load the pre-trained checkpoint.
config = dict(img_size=256, id_dim=512, wing_path='pretrained_checkpoints/wing.ckpt')
nets, nets_ema = build_model(config)
for name, module in nets.items():
    nets[name].to('cuda')
    if ('ema' not in name) and ('fan' not in name) and ('arcface' not in name):
        print('Initializing %s...' % name)
        module.apply(utils.he_init)
for name, module in nets_ema.items():
    nets_ema[name].to('cuda')
ckptios = [CheckpointIO('pretrained_checkpoints/', **nets_ema)]
for ckptio in ckptios:
    ckptio.load_test('faceswapper.ckpt')
    
    
# Source image provides the identity; target image provides the attributes.
src_dir = 'data/FF++_Dataset/ff++/raw_003_0.png'
tar_dir = src_dir.replace('raw_003_0.png', 'raw_000_0.png')

# ImageNet normalization for the parsing network.
to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

img_size = 256
# [-1, 1] normalization for the FaceSwapper networks.
transform = transforms.Compose([
        transforms.Resize([img_size, img_size]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5]),
])
# No normalization for masks / parsing maps.
transform_seg = transforms.Compose([
        transforms.Resize([img_size, img_size]),
        transforms.ToTensor(),
])

def get_mark_parsing(src_img):
    # Run BiSeNet on `src_img` and return (binary face mask, parsing map) as PIL images.
    image = src_img.resize((512, 512), Image.BILINEAR)
    image = to_tensor(image)
    image = torch.unsqueeze(image, 0).cuda()
    out = net(image)[0]

    parsing = out.squeeze(0).cpu().detach().numpy().argmax(0)

    # Mark the facial-region classes (1-6 and 10-13) as white, everything else black.
    mark = parsing.copy()
    index = np.where(((parsing > 0) & (parsing <= 6)) | ((parsing >= 10) & (parsing <= 13)))
    mark[index[0], index[1]] = 255
    index = np.where(mark != 255)
    mark[index[0], index[1]] = 0
    mark = Image.fromarray(np.uint8(mark))

    parsing = Image.fromarray(np.uint8(parsing))
    return mark, parsing

source_image = Image.open(src_dir).convert('RGB')
source_mask_image, source_parsing = get_mark_parsing(source_image)

tar_image = Image.open(tar_dir).convert('RGB')
tar_mask_image, tar_parsing = get_mark_parsing(tar_image)

# Landmark curve images from the pre-trained FAN (wing.ckpt).
source_lm_image = transform(source_image)
source_lm_image = nets_ema.fan.get_landmark_curve(source_lm_image.to('cuda').unsqueeze(0))
source_lm_image = source_lm_image[0]
source_lm_image = Image.fromarray(source_lm_image)

tar_lm_image = transform(tar_image)
tar_lm_image = nets_ema.fan.get_landmark_curve(tar_lm_image.to('cuda').unsqueeze(0))
tar_lm_image = tar_lm_image[0]
tar_lm_image = Image.fromarray(tar_lm_image)

# Convert everything to tensors; note the binary masks are inverted (1 - mask).
source_image = transform(source_image)
source_lm_image = transform(source_lm_image)
source_mask_image = transform_seg(source_mask_image)
source_parsing = transform_seg(source_parsing)

tar_image = transform(tar_image)
tar_lm_image = transform(tar_lm_image)
tar_mask_image = transform_seg(tar_mask_image)
tar_parsing_image = transform_seg(tar_parsing)

src = source_image.to('cuda')
tar = tar_image.to('cuda')
src_lm = source_lm_image.to('cuda')
tar_lm = tar_lm_image.to('cuda')
src_mask = (1 - source_mask_image).to('cuda')
tar_mask = (1 - tar_mask_image).to('cuda')
src_parsing = source_parsing.to('cuda')
tar_parsing = tar_parsing_image.to('cuda')

# srcid_taratt: source identity with target attributes (the swapped face we want).
srcid_taratt, tarid_srcatt, _, _, _, _ = nets_ema.generator(src.unsqueeze(dim=0),
                                                    tar.unsqueeze(dim=0),
                                                    src_lm.unsqueeze(dim=0),
                                                    tar_lm.unsqueeze(dim=0),
                                                    src_mask.unsqueeze(dim=0),
                                                    tar_mask.unsqueeze(dim=0))

def denormalize(x):
    # Map network output from [-1, 1] back to [0, 1] for saving.
    out = (x + 1) / 2
    return out.clamp_(0, 1)

x = srcid_taratt[0, :, :, :]
x = denormalize(x)
vutils.save_image(x.cpu(), 'test.jpg', nrow=1, padding=0)

how to fix this

File "main.py", line 22, in main
solver = Solver(config)
File "C:\Users\Green\Desktop\FaceSwapper\core\solver.py", line 74, in init
CheckpointIO(ospj(config['checkpoint_dir'], '{:06d}_nets.ckpt'), **self.nets),
File "C:\Users\Green\Desktop\FaceSwapper\core\checkpoint.py", line 13, in init
os.makedirs(os.path.dirname(fname_template), exist_ok=True)
File "C:\Users\Green\anaconda3\envs\faceswapper\lib\os.py", line 220, in makedirs
mkdir(name, mode)
FileNotFoundError: [WinError 3] The system cannot find the path specified: '{:'

Colab?

I'm excited to watch the demo video.
Please release the code on Google Colab so that we can try it easily.
I am looking forward to your next update.
Thank you.

How to deal with other datasets

In your readme: "Other Datasets. First, crop and align face images from other datasets automatically so that the proportion of face occupied in the whole is similar to that of CelebA dataset and FaceForensics++ dataset. "
I am confused about how you crop and align face images, and how I can obtain faces similar to those in the CelebA and FaceForensics++ datasets.
