
pytorch-slim-cnn

This repository has a PyTorch implementation of SlimNet as described in the paper:

Slim-CNN: A Light-Weight CNN for Face Attribute Prediction, by Ankit Sharma and Hassan Foroosh (Paper).

Requirements

Python 3.6.x

torch==1.2.0
torchvision==0.4.0

Dataset

The CelebA face attributes dataset is available here. The default train arguments look for the dataset in the following layout under the current working directory:

data/list_attr_celeba.csv
data/list_eval_partition.csv
data/img_align_celeba/000001.jpg
data/img_align_celeba/000002.jpg
...

Alternatively, the argument passed as --data_dir to the train script should be the path (relative or absolute) to the directory containing the two CSV files and the folder of aligned face images.
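
For orientation, here is a minimal sketch of how the two CSV files relate images to the 40 binary attributes and to the train/val/test split. It is not part of the training script; it assumes pandas is installed and that the column names (image_id, partition) follow the Kaggle CSV export of CelebA, which may differ in other copies of the dataset.

import pandas as pd

DATA_DIR = 'data'  # the same directory that --data_dir points to

# list_attr_celeba.csv: one row per image, 40 attribute columns encoded as -1/1
attrs = pd.read_csv(f'{DATA_DIR}/list_attr_celeba.csv', index_col='image_id')
attrs = (attrs + 1) // 2  # map -1/1 to 0/1 for BCE-style targets

# list_eval_partition.csv: 0 = train, 1 = validation, 2 = test
parts = pd.read_csv(f'{DATA_DIR}/list_eval_partition.csv', index_col='image_id')

train_ids = parts.index[parts['partition'] == 0]
print(len(train_ids), 'training images,', attrs.shape[1], 'attributes')
print(attrs.loc['000001.jpg'].to_dict())  # labels for a single image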

Input and output options

  --data_dir   STR    Training data folder.      Default is `data`.
  --save_dir   STR    Model checkpoints folder.  Default is `checkpoints`.

Model options

  --save_every                  INT     Save frequency.                                    Default is 5.
  --num_epochs                  INT     Number of epochs.                                  Default is 20.
  --batch_size                  INT     Number of images per batch.                        Default is 64.
  --conv_filters                INT     Number of initial conv filters.                    Default is 20.
  --conv_filter_size            INT     Initial conv filter size.                          Default is 7.
  --conv_filter_stride          INT     Initial conv filter stride.                        Default is 2.
  --filter_counts               INT     List of filter counts for the Slim modules.        Default is 16 32 48 64.
  --depth_multiplier            INT     Depth multiplier for depthwise separable convs.    Default is 1.
  --num_classes                 INT     Number of class labels.                            Default is 40.
  --num_workers                 INT     Number of threads for data loading.                Default is 2.
  --weight_decay                FLOAT   Weight decay of Adam.                              Default is 0.0001.
  --learning_rate               FLOAT   Adam learning rate.                                Default is 0.0001.
  --decay_lr_every              FLOAT   Frequency at which to decay the learning rate.     Default is 0.
  --lr_decay                    FLOAT   Factor to decay the learning rate by.              Default is 0.1.
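
As a rough sketch only, the optimizer-related flags above map onto torch.optim roughly as follows. This is not the repository's actual training loop; the linear model is a placeholder standing in for SlimNet, and the interpretation of --decay_lr_every equal to 0 as "never decay" is an assumption.

import torch

# Placeholder model (hypothetical): stands in for SlimNet purely to illustrate the mapping
model = torch.nn.Linear(10, 40)

learning_rate = 1e-4    # --learning_rate
weight_decay = 1e-4     # --weight_decay
decay_lr_every = 0      # --decay_lr_every (treated here as "never decay" when 0)
lr_decay = 0.1          # --lr_decay

optimizer = torch.optim.Adam(model.parameters(),
                             lr=learning_rate, weight_decay=weight_decay)

scheduler = None
if decay_lr_every > 0:
    # decay the learning rate by lr_decay every decay_lr_every epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=decay_lr_every,
                                                gamma=lr_decay)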

Examples

Setup

git clone https://github.com/gtamba/pytorch-slim-cnn && cd pytorch-slim-cnn
pip install -r requirements.txt
mkdir checkpoints

Train a model with the CelebA dataset in the data/ folder

python train.py --num_epochs 6 --save_every 2 --batch_size 256

Sample out-of-the-box inference code

import torch
from slimnet import SlimNet
from torchvision import transforms
from PIL import Image
import numpy as np

PATH_TO_IMAGE = 'data/img_align_celeba/000001.jpg'
labels = np.array(['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
       'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
       'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
       'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
       'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
       'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
       'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
       'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
       'Wearing_Necklace', 'Wearing_Necktie', 'Young'])

# A GPU isn't necessary but could definitely speed things up; swap the comments to use the best hardware available
#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device('cpu')

transform = transforms.Compose([
    transforms.Resize((178, 218)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Convert to a normalized tensor, add a pseudo batch dimension, and move to the configured device
with open(PATH_TO_IMAGE, 'rb') as f:
    x = transform(Image.open(f)).unsqueeze(0).to(device)

model = SlimNet.load_pretrained('models/celeba_20.pth').to(device)

with torch.no_grad():
    model.eval()
    logits = model(x)
    sigmoid_logits = torch.sigmoid(logits)
    # move back to the CPU before converting to numpy (a no-op when already on the CPU)
    predictions = (sigmoid_logits > 0.5).squeeze().cpu().numpy()

print(labels[predictions.astype(bool)])
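
As a small follow-up (not part of the original snippet), the same sigmoid outputs can be read as per-attribute confidences rather than hard 0.5-thresholded predictions; this continues directly from the variables defined above.

# Per-attribute confidence scores, highest first
probabilities = sigmoid_logits.squeeze().cpu().numpy()
top5 = sorted(zip(labels, probabilities), key=lambda pair: -pair[1])[:5]
for name, prob in top5:
    print(f'{name}: {prob:.3f}')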

Benchmarks

Model Footprint

  • Model size ~ 7 MB
  • Number of parameters ~ 0.6 M

Simple timeit benchmarks (10000 loops, including I/O) on an NVIDIA K80 machine:

  • CPU : 0.1598 seconds per image ~ 6.25 frames per second
  • GPU : 0.0753 seconds per image ~ 13.3 frames per second
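
For context, a rough sketch (not the original benchmark script) of how a per-image latency like the above can be measured with timeit; it reuses model, transform, device, PATH_TO_IMAGE, Image, and torch from the inference snippet, and exact numbers will of course depend on hardware.

import timeit

def infer_once():
    # includes image I/O, preprocessing, and the forward pass, as in the numbers above
    with open(PATH_TO_IMAGE, 'rb') as f:
        x = transform(Image.open(f)).unsqueeze(0).to(device)
    with torch.no_grad():
        return model(x)

n_loops = 100  # the reported numbers used 10000 loops
seconds_per_image = timeit.timeit(infer_once, number=n_loops) / n_loops
print(f'{seconds_per_image:.4f} s per image ~ {1 / seconds_per_image:.2f} fps')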

Training Metrics

TODO: for now, see the notebook for training metric trends.

Notebooks

A notebook is provided to evaluate the model on the test set and to plot training metrics.

Notes

  • No data augmentation is used right now, although it would likely improve performance and robustness.
  • It is unclear whether weight_decay on the optimizer is a good substitute for adding an explicit L2 penalty on the network weights to the loss.
  • The number of epochs and the batch size were not specified in the paper. 20 epochs with a batch size of 256 is a conservative choice at best, but the loss trend shows an overfitting inflection around the 20th epoch, as can be seen in the notebook.
  • torchvision now provides an out-of-the-box Dataset for CelebA (which will also handle downloading the data), so a minor script edit may save you some file downloading and organizing; see the sketch after this list.
  • NVIDIA has not yet implemented depthwise separable convolutions in cuDNN, so the theoretical speedup won't be visible... yet.
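
A minimal sketch of the torchvision CelebA Dataset mentioned above, assuming torchvision >= 0.4.0 and that the download mirror is reachable; the transform is the same one used in the inference snippet, and the label handling may still need adjusting to match train.py.

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((178, 218)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# download=True fetches the images and annotation files into ./data
train_set = datasets.CelebA(root='data', split='train', target_type='attr',
                            transform=transform, download=True)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)

images, attrs = next(iter(train_loader))
print(images.shape, attrs.shape)  # e.g. torch.Size([64, 3, 178, 218]) and torch.Size([64, 40])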
