portrait-matting-unet-flask

Portrait matting implemented with U-Net in PyTorch.

Segmentation demo result: [image: Segmentation]
Matting demo result: [image: Matting]

For convenience of demonstration, I built an API service with Flask and deployed it in a WeChat Mini Program. The code for the WeChat Mini Program is in portrait-matting-wechat.

Dependencies

  • Python 3.6
  • PyTorch >= 1.1.0
  • Torchvision >= 0.3.0
  • Flask 1.1.1
  • future 0.18.2
  • matplotlib 3.1.3
  • numpy 1.16.0
  • Pillow 6.2.0
  • protobuf 3.11.3
  • tensorboard 1.14.0
  • tqdm 4.42.1

Data

The model was trained from scratch on 18,000 images, produced by augmenting 2,000 original images. The training dataset comes from Deep Automatic Portrait Matting. You can download it from Baidu Cloud: http://pan.baidu.com/s/1dE14537 (password: ndg8). The dataset is for academic exchange only; if you cite it, please inform the original authors.

We augment the number of images by perturbing them with rotation and scaling. Four rotation angles {−45°, −22°, 22°, 45°} and four scales {0.6, 0.8, 1.2, 1.5} are used. We also apply four different Gamma transforms to increase color variation. The Gamma values are {0.5, 0.8, 1.2, 1.5}. After these transforms, we have 18K training images.
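The offline augmentation parameters above can be sketched in plain Python. This is a minimal illustration, not the dataset authors' code: `augmentation_plan` and `gamma_lut` are hypothetical helpers. The sketch assumes each transform produces a separate copy of an original image, since the text does not say whether transforms were combined. It also assumes the gamma convention out = 255 · (in / 255)^γ for 8-bit pixels.

```python
# Hypothetical sketch of the offline augmentation parameters listed above.
ROTATIONS_DEG = (-45, -22, 22, 45)
SCALES = (0.6, 0.8, 1.2, 1.5)
GAMMAS = (0.5, 0.8, 1.2, 1.5)


def augmentation_plan():
    """List every (transform, parameter) pair, assuming one transform per copy."""
    plan = [("rotate", angle) for angle in ROTATIONS_DEG]
    plan += [("scale", factor) for factor in SCALES]
    plan += [("gamma", g) for g in GAMMAS]
    return plan


def gamma_lut(gamma):
    """256-entry lookup table mapping an 8-bit value v to 255 * (v / 255) ** gamma.

    With this (assumed) convention, gamma < 1 brightens and gamma > 1 darkens.
    """
    return [round(255 * (v / 255) ** gamma) for v in range(256)]
```

With Pillow, such a table could be applied to an RGB image via `Image.point(lut * 3)`, with one copy of the table per band. Rotation and scaling would use `Image.rotate` and `Image.resize` and must be applied identically to the mask. The gamma change presumably applies to the image only, because it targets color variation.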

Run locally

Note: use Python 3.

Prediction

You can easily test the output masks on your images via the CLI.

To predict a single image and save it:

$ python predict.py -i image.jpg -o output.jpg

To predict multiple images and display them without saving:

$ python predict.py -i image1.jpg image2.jpg --viz --no-save
> python predict.py -h
usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...]
                  [--output INPUT [INPUT ...]] [--viz] [--no-save]
                  [--mask-threshold MASK_THRESHOLD] [--scale SCALE]

Predict masks from input images

optional arguments:
  -h, --help            show this help message and exit
  --model FILE, -m FILE
                        Specify the file in which the model is stored
                        (default: MODEL.pth)
  --input INPUT [INPUT ...], -i INPUT [INPUT ...]
                        filenames of input images (default: None)
  --output INPUT [INPUT ...], -o INPUT [INPUT ...]
                        Filenames of ouput images (default: None)
  --viz, -v             Visualize the images as they are processed (default:
                        False)
  --no-save, -n         Do not save the output masks (default: False)
  --mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
                        Minimum probability value to consider a mask pixel
                        white (default: 0.5)
  --scale SCALE, -s SCALE
                        Scale factor for the input images (default: 0.5)

You can specify which model file to use with --model MODEL.pth.
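The `--scale` and `--mask-threshold` options can be illustrated with a small plain-Python sketch. It assumes the network outputs a per-pixel foreground probability in [0, 1]. It also assumes each side is resized to `int(scale * size)`, as the `preprocess` method in the dataset code on this page does. `scaled_size` and `threshold_mask` are hypothetical helpers, not functions from predict.py. Whether a probability exactly equal to the threshold counts as white is also an assumption.

```python
def scaled_size(width, height, scale=0.5):
    """Input size after applying --scale (truncating like int(scale * w))."""
    new_w, new_h = int(scale * width), int(scale * height)
    if new_w <= 0 or new_h <= 0:
        raise ValueError("Scale is too small")
    return new_w, new_h


def threshold_mask(probs, threshold=0.5):
    """Turn rows of foreground probabilities into a 0/255 mask.

    Pixels with probability strictly above the threshold become white (255);
    the strict comparison is an assumption.
    """
    return [[255 if p > threshold else 0 for p in row] for row in probs]
```

For example, a 1024×768 input at the default scale of 0.5 is processed at 512×384. A lower `--mask-threshold` marks more pixels as foreground.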

Training

> python train.py -h
usage: train.py [-h] [-e E] [-b [B]] [-l [LR]] [-f LOAD] [-s SCALE] [-v VAL]

Train the UNet on images and target masks

optional arguments:
  -h, --help            show this help message and exit
  -e E, --epochs E      Number of epochs (default: 5)
  -b [B], --batch-size [B]
                        Batch size (default: 1)
  -l [LR], --learning-rate [LR]
                        Learning rate (default: 0.1)
  -f LOAD, --load LOAD  Load model from a .pth file (default: False)
  -s SCALE, --scale SCALE
                        Downscaling factor of the images (default: 0.5)
  -v VAL, --validation VAL
                        Percent of the data that is used as validation (0-100)
                        (default: 10.0)

By default, the scale is 0.5; for better results (at the cost of more memory), set it to 1.

The input images and target masks should be in the data/imgs and data/masks folders respectively.
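The folder layout and the `-v/--validation` split can be sketched as follows. This is an illustration under assumptions, not train.py itself. `pair_by_id` and `split_train_val` are hypothetical helpers. The sketch assumes each mask has exactly the same base name as its image. The `glob`-based lookup in `BasicDataset` on this page matches by prefix, so this check is stricter. The shuffled percentage split mirrors the help text's "Percent of the data that is used as validation (0-100)".

```python
import random
from os.path import splitext


def pair_by_id(image_names, mask_names):
    """Match files from data/imgs and data/masks that share a base name (ID)."""
    masks = {splitext(n)[0]: n for n in mask_names if not n.startswith(".")}
    pairs = []
    for name in sorted(image_names):
        if name.startswith("."):  # skip hidden files such as .DS_Store
            continue
        stem = splitext(name)[0]
        if stem not in masks:
            raise FileNotFoundError(f"No mask found for ID {stem}")
        pairs.append((name, masks[stem]))
    return pairs


def split_train_val(items, val_percent=10.0, seed=0):
    """Shuffle and split items; val_percent is 0-100 like -v/--validation."""
    items = list(items)
    random.Random(seed).shuffle(items)
    n_val = int(len(items) * val_percent / 100)
    return items[n_val:], items[:n_val]
```

With the default `-v 10.0`, 100 image/mask pairs would give 90 training and 10 validation pairs.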

Start API service

$ python app.py

The model is then available through the API.
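A client for the API might look like the sketch below, which uses only the standard library. The route, port and request format are assumptions, not taken from app.py. `http://127.0.0.1:5000` is Flask's default development address, and `/predict` accepting raw JPEG bytes is hypothetical. Adjust both to what app.py actually defines; the WeChat Mini Program may, for example, send a multipart file upload instead.

```python
import urllib.request

# Assumptions (not taken from app.py): Flask's default development address
# and a hypothetical /predict route that accepts raw JPEG bytes.
API_URL = "http://127.0.0.1:5000/predict"


def build_request(image_bytes, url=API_URL):
    """Build a POST request carrying the image bytes."""
    return urllib.request.Request(
        url,
        data=image_bytes,
        headers={"Content-Type": "image/jpeg"},
        method="POST",
    )


def call_api(image_path, url=API_URL, timeout=60):
    """Send an image to the (hypothetical) endpoint and return the raw response body."""
    with open(image_path, "rb") as f:
        request = build_request(f.read(), url)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return response.read()
```

With the service from `python app.py` running, `call_api("image.jpg")` would return whatever the endpoint sends back, for example the encoded matting result.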

Run on server

  1. Create a virtual environment
  2. Install gunicorn in the virtual environment
  3. Put nginx in front of gunicorn as a reverse proxy
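Gunicorn can read its settings from a Python file, so the server steps above can be captured in a small `gunicorn.conf.py`. This is a minimal sketch; the values are assumptions to tune for your server. It also assumes app.py exposes a Flask object named `app`.

```python
# gunicorn.conf.py: a minimal sketch; all values below are assumptions.
# Start with:  gunicorn -c gunicorn.conf.py app:app
# (assumes app.py exposes a Flask object named `app`).

# Listen only on localhost; nginx forwards public traffic here, e.g. with
# `proxy_pass http://127.0.0.1:8000;` inside a `location / { ... }` block.
bind = "127.0.0.1:8000"

# Each sync worker is a separate process that typically loads its own copy of
# the model, so memory use grows with the worker count.
workers = 2

# Inference can be slow, especially on CPU; allow longer requests than
# gunicorn's 30-second default.
timeout = 120
```

Keeping gunicorn bound to localhost means only nginx is exposed publicly. Nginx can then handle TLS and request-size limits for uploaded images.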

Notes on memory

$ python train.py -e 200 -b 1 -l 0.1 -s 0.5 -v 15.0

The model was trained from scratch on an RTX 2080 Ti (11 GB) with the command above, on the 18,000-image training set, for more than 4 days.
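To put the training command in perspective, the arithmetic below counts optimizer steps and shows the effect of `-s`. It is a back-of-the-envelope sketch. It assumes all 18,000 images go through a percentage train/val split as `-v` describes, and that the last partial batch of each epoch is kept. `training_steps` and `pixel_fraction` are hypothetical helpers.

```python
import math


def training_steps(n_images, epochs, batch_size, val_percent):
    """Return (training images, total optimizer steps) for a run."""
    n_val = int(n_images * val_percent / 100)
    n_train = n_images - n_val
    steps_per_epoch = math.ceil(n_train / batch_size)
    return n_train, steps_per_epoch * epochs


def pixel_fraction(scale):
    """Fraction of pixels (and roughly of activation memory) kept by --scale."""
    return scale * scale
```

For `-e 200 -b 1 -v 15.0` on 18,000 images, this gives 15,300 training images and 3,060,000 optimizer steps. `-s 0.5` keeps a quarter of the pixels, which is why setting the scale to 1 needs substantially more GPU memory.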

Thanks

This project builds on the following projects:

  • Flask: The Python micro framework for building web applications
  • Pytorch-UNet: PyTorch implementation of the U-Net for image semantic segmentation with high quality images

Contributors

leijue222, dependabot[bot]


Issues

online augmentation

Replace util/dataset.py with:

from os.path import splitext
from os import listdir
from glob import glob
import logging

import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import imgaug as ia
from imgaug import augmenters as iaa


class BasicDataset(Dataset):
    def __init__(self, imgs_dir, masks_dir, scale=1):
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.scale = scale
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'

        self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
                    if not file.startswith('.')]
        logging.info(f'Creating dataset with {len(self.ids)} examples')

        # Build the augmentation pipeline once instead of on every sample.
        # Geometric steps (crop, rotate, flip) are applied to the image and the
        # mask with the same random parameters; blur and noise only change the image.
        self.seq = iaa.Sequential([
            iaa.Sometimes(0.5, iaa.Crop(px=(0, 16))),
            iaa.Affine(rotate=(-90, 90)),
            iaa.Sometimes(0.5, iaa.Fliplr(0.5)),
            iaa.Sometimes(0.5, iaa.GaussianBlur(sigma=(0, 0.5))),
            iaa.Sometimes(0.5, iaa.AdditiveGaussianNoise(
                loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5)),
        ])

    def __len__(self):
        return len(self.ids)

    @classmethod
    def preprocess(cls, pil_img, scale):
        """Resize a single image and convert it to a CHW array in [0, 1] (no augmentation)."""
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, 'Scale is too small'
        pil_img = pil_img.resize((newW, newH))

        img_nd = np.array(pil_img)

        if len(img_nd.shape) == 2:
            img_nd = np.expand_dims(img_nd, axis=2)
        # HWC to CHW
        img_trans = img_nd.transpose((2, 0, 1))
        if img_trans.max() > 1:
            img_trans = img_trans / 255

        return img_trans

    def preprocess_all(self, img, mask, scale):
        """Resize, jointly augment, and convert an image/mask pair to CHW arrays in [0, 1]."""
        w, h = img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, 'Scale is too small'
        img_nd = np.array(img.resize((newW, newH)))
        mask_nd = np.array(mask.resize((newW, newH)))

        # Wrapping the mask as a segmentation map makes imgaug apply the same
        # geometric transform to it as to the image.
        seg_map = ia.SegmentationMapsOnImage(mask_nd, shape=img_nd.shape)
        image_aug, seg_aug = self.seq(image=img_nd, segmentation_maps=seg_map)
        mask_aug = seg_aug.get_arr()

        # HWC to CHW, scaled to [0, 1]
        img_trans = image_aug.transpose((2, 0, 1))
        if img_trans.max() > 1:
            img_trans = img_trans / 255
        if mask_aug.ndim == 2:
            mask_aug = np.expand_dims(mask_aug, axis=2)
        mask_trans = mask_aug.transpose((2, 0, 1))
        if mask_trans.max() > 1:
            mask_trans = mask_trans / 255

        return img_trans, mask_trans

    def __getitem__(self, i):
        idx = self.ids[i]
        mask_file = glob(self.masks_dir + idx + '*')
        img_file = glob(self.imgs_dir + idx + '*')

        assert len(mask_file) == 1, \
            f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
        assert len(img_file) == 1, \
            f'Either no image or multiple images found for the ID {idx}: {img_file}'
        mask = Image.open(mask_file[0])
        img = Image.open(img_file[0])

        assert img.size == mask.size, \
            f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
        img, mask = self.preprocess_all(img, mask, self.scale)

        return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)}
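The key point of this change is that geometric augmentations (crop, rotation, flip) must use the same random parameters for the image and its mask. Otherwise the mask no longer lines up with the portrait. imgaug handles this when `segmentation_maps=` is passed in the same `seq(...)` call. Pixel-level augmenters such as blur and noise leave segmentation maps untouched. The plain-Python sketch below illustrates the idea with a hypothetical `joint_random_flip` helper; it is not part of imgaug or this repository.

```python
import random


def hflip(rows):
    """Mirror a 2-D grid (list of rows) left to right."""
    return [list(reversed(row)) for row in rows]


def joint_random_flip(image, mask, p=0.5, rng=random):
    """Flip image and mask together so the mask stays aligned with the image.

    A single random draw decides for both; separate draws could flip only one.
    """
    if rng.random() < p:
        return hflip(image), hflip(mask)
    return image, mask
```

Passing a seeded `random.Random` as `rng` makes the sequence of flips reproducible, which helps when debugging augmentation.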
