Partial Convolution Layer for Padding and Image Inpainting

This is a PyTorch implementation of the partial convolution layer. It can serve as a new padding scheme, and it can also be used for image inpainting.

Partial Convolution based Padding
Guilin Liu, Kevin J. Shih, Ting-Chun Wang, Fitsum A. Reda, Karan Sapra, Zhiding Yu, Andrew Tao, Bryan Catanzaro
NVIDIA Corporation
Technical Report, 2018

Image Inpainting for Irregular Holes Using Partial Convolutions
Guilin Liu, Fitsum A. Reda, Kevin J. Shih, Ting-Chun Wang, Andrew Tao, Bryan Catanzaro
NVIDIA Corporation
In The European Conference on Computer Vision (ECCV) 2018

Comparison with Zero Padding

Installation

Installation instructions can be found at: https://github.com/pytorch/examples/tree/master/imagenet

Usage:

  • Using partial conv for padding:
# Typical convolution layer with zero padding
nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)

# Partial convolution based padding
PartialConv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
  • Using partial conv for image inpainting: set both multi_channel and return_mask to True
# Partial convolution for inpainting (using multiple channels and updating the mask)
PartialConv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False, multi_channel=True, return_mask=True)
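
To make the two usages above concrete, here is a minimal pure-Python sketch of a single-channel partial convolution. It is not the repository's PartialConv2d: the function name partial_conv2d and its arguments are hypothetical, and it assumes stride 1, a square kernel, and no bias. Each output is re-weighted by the window size divided by the number of valid (unmasked, non-padded) inputs, and the mask is set to 1 wherever the window saw at least one valid input.

```python
def partial_conv2d(image, mask, kernel, pad=1):
    """Single-channel partial convolution (stride 1, no bias), pure Python.

    image, mask: lists of rows with the same shape; mask is 1 for valid
    pixels and 0 for holes. Returns (output, updated_mask).
    """
    h, w = len(image), len(image[0])
    k = len(kernel)                      # square k x k kernel (assumption)
    window_size = k * k                  # plays the role of slide_winsize
    out_h, out_w = h + 2 * pad - k + 1, w + 2 * pad - k + 1
    output = [[0.0] * out_w for _ in range(out_h)]
    new_mask = [[0] * out_w for _ in range(out_h)]
    for i in range(out_h):
        for j in range(out_w):
            acc, valid = 0.0, 0
            for di in range(k):
                for dj in range(k):
                    y, x = i + di - pad, j + dj - pad
                    # Padded positions and holes both count as invalid.
                    if 0 <= y < h and 0 <= x < w and mask[y][x]:
                        acc += kernel[di][dj] * image[y][x]
                        valid += 1
            if valid:
                # Re-weight by window size over the number of valid inputs.
                output[i][j] = acc * window_size / valid
                new_mask[i][j] = 1       # this window saw valid data
    return output, new_mask


# Padding use: an all-ones mask, so only the padded border is "invalid".
ones = [[1.0] * 4 for _ in range(4)]
full_mask = [[1] * 4 for _ in range(4)]
box = [[1.0 / 9] * 3 for _ in range(3)]
padded_out, _ = partial_conv2d(ones, full_mask, box)

# Inpainting use: a mask with holes; the returned mask grows by one ring.
hole_mask = [[0] * 5 for _ in range(5)]
hole_mask[2][2] = 1
_, grown_mask = partial_conv2d([[1.0] * 5 for _ in range(5)], hole_mask, box)
```

With the box filter on the all-ones image, every output of the padding case stays at approximately 1.0, including the corners, whereas plain zero padding would yield 4/9 at a corner because only four of the nine window entries are non-zero.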

Mixed Precision Training with AMP for image inpainting

  • Installation: to train with mixed precision support, first install apex from: https://github.com/NVIDIA/apex
  • Required change #1 (Typical changes): the usual changes needed to enable AMP
  from apex import amp
  
  # Initialize the model and optimizer
  self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level=args.amp_opt_level)
  
  # Initialize the VGG loss function/extractor
  self.vgg_feat_loss = amp.initialize(self.vgg_feat_loss, opt_level=args.amp_opt_level)
  
  # Scale the loss before the backward pass
  with amp.scale_loss(total_loss, self.g_optimizer) as scaled_loss:
      scaled_loss.backward()

  • Required change #2 (Gram Matrix Loss): in the Gram matrix loss computation, replace the single one-step division with two smaller divisions
    input = torch.zeros(b, ch, ch).type(features.type())
    gram = torch.baddbmm(input, features, features_t, beta=0, alpha=1./(ch * h * w), out=None)
  • Required change #3 (Small Constant Number): increase the small constant (e.g. from 1e-8 to 1e-6)
    • change from 1e-8: self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)
    • to 1e-6: self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-6)
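
The source does not state why changes #2 and #3 help. One plausible reading is that they keep intermediate values inside the range of half precision (fp16). The sketch below checks this with the standard-library struct module. The helper to_fp16, the feature-map shape, and the particular two-step split are all assumptions made for illustration.

```python
import struct


def to_fp16(x):
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("e", struct.pack("e", x))[0]


FP16_MAX = 65504.0  # largest finite half-precision value

# Change #3: 1e-8 is below the smallest fp16 subnormal and rounds to zero,
# while 1e-6 survives as a (coarsely rounded) non-zero subnormal.
eps_small = to_fp16(1e-8)
eps_large = to_fp16(1e-6)

# Change #2: hypothetical feature-map shape, chosen only for illustration.
ch, h, w = 256, 64, 64
one_step_divisor = ch * h * w        # too large to represent in fp16
two_step_divisors = (ch * h, w)      # one possible split; both fit in fp16

print("1e-8 in fp16:", eps_small)
print("1e-6 in fp16:", eps_large)
print("one-step divisor fits:", one_step_divisor <= FP16_MAX)
print("two-step divisors fit:", all(d <= FP16_MAX for d in two_step_divisors))
```

Whether a given operation actually runs in fp16 depends on the chosen amp opt_level, so treat this as background on the number format, not as a description of what apex does internally.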

Usage of partial conv based padding to train ImageNet

  • ResNet50 using zero padding (default padding)
python main.py -a resnet50 --data_train /path/ILSVRC/Data/CLS-LOC/train --data_val /path/ILSVRC/Data/CLS-LOC/perfolder_val --batch-size 192 --workers 32 --prefix multigpu_b192 --ckptdirprefix experiment_1/
  • ResNet50 using partial conv based padding
python main.py -a pdresnet50 --data_train /path/ILSVRC/Data/CLS-LOC/train --data_val /path/ILSVRC/Data/CLS-LOC/perfolder_val --batch-size 192 --workers 32 --prefix multigpu_b192 --ckptdirprefix experiment_1/
  • vgg16_bn using zero padding (default padding)
python main.py -a vgg16_bn --data_train /path/ILSVRC/Data/CLS-LOC/train --data_val /path/ILSVRC/Data/CLS-LOC/perfolder_val --batch-size 192 --workers 32 --prefix multigpu_b192 --ckptdirprefix experiment_1/
  • vgg16_bn using partial conv based padding
python main.py -a pdvgg16_bn --data_train /path/ILSVRC/Data/CLS-LOC/train --data_val /path/ILSVRC/Data/CLS-LOC/perfolder_val --batch-size 192 --workers 32 --prefix multigpu_b192 --ckptdirprefix experiment_1/

Pretrained checkpoints (weights) for VGG and ResNet networks with partial convolution based padding:

https://www.dropbox.com/sh/t6flbuoipyzqid8/AACJ8rtrF6V5b9348aG5PIhia?dl=0

Comparison with Zero Padding, Reflection Padding and Replication Padding for 5 runs

Results are the best top-1 accuracies of each run with 1-crop testing. *_zero, *_pd, *_ref and *_rep denote the corresponding model with zero padding, partial convolution based padding, reflection padding and replication padding, respectively. *_best is the best validation score within each training run. Average is the mean accuracy over the 5 runs. The diff column is the difference from the corresponding network with zero padding. The stdev column is the standard deviation of the accuracies over the 5 runs. PT_official is the corresponding official accuracy published on the PyTorch website: https://pytorch.org/docs/stable/torchvision/models.html
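
As a rough illustration of how the Average, diff and stdev columns described above could be computed, here is a short sketch. The accuracy values are placeholders, not results from the paper or the repository, and using the sample standard deviation is an assumption because the text does not say which variant was used.

```python
from statistics import mean, stdev

# Placeholder accuracies (percent) for 5 runs; NOT results from the paper.
placeholder_zero_best = [70.0, 70.1, 70.2, 70.3, 70.4]
placeholder_pd_best = [70.5, 70.6, 70.7, 70.8, 70.9]


def summarize(runs, baseline_runs):
    """Return (average, diff vs. the baseline average, standard deviation)."""
    average = mean(runs)
    diff = average - mean(baseline_runs)
    # Assumption: sample standard deviation (n - 1 in the denominator).
    return average, diff, stdev(runs)


avg, diff, sd = summarize(placeholder_pd_best, placeholder_zero_best)
print(f"average={avg:.2f} diff={diff:+.2f} stdev={sd:.3f}")
```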

Citation

@inproceedings{liu2018partialpadding,
   author    = {Guilin Liu and Kevin J. Shih and Ting-Chun Wang and Fitsum A. Reda and Karan Sapra and Zhiding Yu and Andrew Tao and Bryan Catanzaro},
   title     = {Partial Convolution based Padding},
   booktitle = {arXiv preprint arXiv:1811.11718},   
   year      = {2018},
}
@inproceedings{liu2018partialinpainting,
   author    = {Guilin Liu and Fitsum A. Reda and Kevin J. Shih and Ting-Chun Wang and Andrew Tao and Bryan Catanzaro},
   title     = {Image Inpainting for Irregular Holes Using Partial Convolutions},
   booktitle = {The European Conference on Computer Vision (ECCV)},   
   year      = {2018},
}

Contact: Guilin Liu ([email protected])

Acknowledgments

We thank Jinwei Gu, Matthieu Le, Andrzej Sulecki, Marek Kolodziej and Hongfu Liu for helpful discussions.
