cntk-fully-convolutional-networks's Introduction

cntk_resnet_fcn

This is a CNTK implementation of the Fully Convolutional Network (FCN), a deep learning method for semantic segmentation proposed by J. Long et al. The original FCN was built on VGG; this implementation uses ResNet-18 as the base model instead.

Example Usage

Tested with the cntk-2.1-gpu-python3.5 Docker image.

import numpy as np
import cntk_resnet_fcn
import simulation
%matplotlib inline
import helper

import cntk as C
from cntk.learners import learning_rate_schedule, UnitType

Check some images/masks from simulation

# Generate some random images
input_images, target_masks = simulation.generate_random_data(192, 192, count=3)

print(input_images.shape, target_masks.shape)

# Change channel order from (C, H, W) to (H, W, C) and repeat to 3 channels for matplotlib
input_images_rgb = [(x.swapaxes(0, 2).swapaxes(0, 1).repeat(3, axis=2) * -255 + 255).astype(np.uint8) for x in input_images]

# Map each channel (i.e. class) to each color
target_masks_rgb = [helper.masks_to_colorimg(x) for x in target_masks]

# Left: Input image, Right: Target mask
helper.plot_side_by_side([input_images_rgb, target_masks_rgb])
(3, 1, 192, 192) (3, 6, 192, 192)

Figure: Images and masks from simulation (left: input image, right: target mask)
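
The channel reordering in the snippet above can be checked in isolation. The following is a minimal sketch that uses a random array in place of the simulation output (the array `x` and its value range of 0 to 1 are assumptions made for illustration). It suggests that the two `swapaxes` calls amount to a single `np.transpose(x, (1, 2, 0))`, which turns the channels-first (C, H, W) layout into the (H, W, C) layout that matplotlib expects.

```python
import numpy as np

# Hypothetical stand-in for one simulated image: 1 channel, 192x192, values in [0, 1).
x = np.random.RandomState(0).rand(1, 192, 192).astype(np.float32)

# Approach used above: two axis swaps, (C, H, W) -> (W, H, C) -> (H, W, C).
hwc_swapped = x.swapaxes(0, 2).swapaxes(0, 1)

# Equivalent single transpose.
hwc_transposed = np.transpose(x, (1, 2, 0))

# Repeat the single channel three times and invert intensities into the 0..255 range.
rgb = (hwc_transposed.repeat(3, axis=2) * -255 + 255).astype(np.uint8)

print(hwc_swapped.shape, rgb.shape, np.array_equal(hwc_swapped, hwc_transposed))
```

Either form works; the single transpose just makes the target axis order explicit.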

Prepare the resnet-fcn model

from cntk.device import try_set_default_device, gpu
try_set_default_device(gpu(0))

def slice_minibatch(data_x, data_y, i, minibatch_size):
    sx = data_x[i * minibatch_size:(i + 1) * minibatch_size]
    sy = data_y[i * minibatch_size:(i + 1) * minibatch_size]

    return sx, sy

def measure_error(data_x, data_y, x, y, trainer, minibatch_size):
    errors = []
    for i in range(0, int(len(data_x) / minibatch_size)):
        data_sx, data_sy = slice_minibatch(data_x, data_y, i, minibatch_size)

        errors.append(trainer.test_minibatch({x: data_sx, y: data_sy}))

    return np.mean(errors)

def train(images, masks, use_existing=False):
    # Use the function arguments rather than the global input_images
    shape = images[0].shape
    data_size = images.shape[0]

    # Split data
    test_portion = int(data_size * 0.1)
    indices = np.random.permutation(data_size)
    test_indices = indices[:test_portion]
    training_indices = indices[test_portion:]

    test_data = (images[test_indices], masks[test_indices])
    training_data = (images[training_indices], masks[training_indices])

    # Create model
    x = C.input_variable(shape)
    y = C.input_variable(masks[0].shape)

    z = cntk_resnet_fcn.create_model(x, masks.shape[1])
    dice_coef = cntk_resnet_fcn.dice_coefficient(z, y)

    # Load the saved model if specified
    checkpoint_file = "cntk-resnet-fcn.dnn"
    if use_existing:
        z.load_model(checkpoint_file)

    # Prepare model and trainer
    lr = learning_rate_schedule(0.0001, UnitType.sample)
    momentum = C.learners.momentum_as_time_constant_schedule(0.9)
    trainer = C.Trainer(z, (-dice_coef, -dice_coef), C.learners.adam(z.parameters, lr=lr, momentum=momentum))

    # Get minibatches of training data and perform model training
    minibatch_size = 8
    num_epochs = 50

    training_errors = []
    test_errors = []

    for e in range(0, num_epochs):
        for i in range(0, int(len(training_data[0]) / minibatch_size)):
            data_x, data_y = slice_minibatch(training_data[0], training_data[1], i, minibatch_size)

            trainer.train_minibatch({x: data_x, y: data_y})

        # Measure training error
        training_error = measure_error(training_data[0], training_data[1], x, y, trainer, minibatch_size)
        training_errors.append(training_error)

        # Measure test error
        test_error = measure_error(test_data[0], test_data[1], x, y, trainer, minibatch_size)
        test_errors.append(test_error)

        print("epoch #{}: training_error={}, test_error={}".format(e, training_errors[-1], test_errors[-1]))

        trainer.save_checkpoint(checkpoint_file)

    return trainer, training_errors, test_errors
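
Note that the loops above iterate over `int(len(data) / minibatch_size)` minibatches, so trailing samples are skipped whenever the data size is not a multiple of the minibatch size. With 1024 samples and a 10% test split, that is 922 training and 102 test samples, leaving 2 and 6 samples unused respectively. The sketch below uses NumPy only and a hypothetical helper named `iter_minibatches` (not part of the original code) to show one way of including the final partial batch.

```python
import numpy as np

def iter_minibatches(data_x, data_y, minibatch_size):
    """Yield (x, y) slices, including a final partial batch.

    Hypothetical helper for illustration; not part of cntk_resnet_fcn.
    """
    for start in range(0, len(data_x), minibatch_size):
        end = start + minibatch_size
        yield data_x[start:end], data_y[start:end]

# Toy data: 20 samples, so a minibatch size of 8 leaves a remainder of 4.
toy_x = np.arange(20).reshape(20, 1)
toy_y = np.arange(20).reshape(20, 1) * 2

batch_sizes = [len(bx) for bx, _ in iter_minibatches(toy_x, toy_y, 8)]
full_batches_only = len(toy_x) // 8  # number of batches the loops above would visit

print(batch_sizes, full_batches_only)
```

If partial batches are included when measuring error, the per-batch values should probably be weighted by batch size rather than averaged with a plain `np.mean`.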

Training

input_images, target_masks = simulation.generate_random_data(192, 192, count=1024)

trainer, training_errors, test_errors = train(input_images, target_masks)
epoch #0: training_error=-0.017798021160390066, test_error=-0.018451113952323794
epoch #1: training_error=-0.1391007762240327, test_error=-0.14523974185188612
epoch #2: training_error=-0.3251049741454746, test_error=-0.3291884511709213
epoch #3: training_error=-0.40855012069577756, test_error=-0.41476351022720337
epoch #4: training_error=-0.44601072746774423, test_error=-0.4511391098300616
epoch #5: training_error=-0.4810214545415795, test_error=-0.48489775508642197
epoch #6: training_error=-0.5151172067808069, test_error=-0.5200231522321701
epoch #7: training_error=-0.5922727802525396, test_error=-0.5973579933245977
epoch #8: training_error=-0.749630199826282, test_error=-0.7541888852914175
epoch #9: training_error=-0.7754635240720666, test_error=-0.778565322359403
epoch #10: training_error=-0.8706006376639657, test_error=-0.8741355637709299
epoch #11: training_error=-0.9253758440846982, test_error=-0.9278035958607992
epoch #12: training_error=-0.9409363124681556, test_error=-0.943161795536677
epoch #13: training_error=-0.9504859722178916, test_error=-0.9518442749977112
epoch #14: training_error=-0.9561804066533628, test_error=-0.9564324915409088
epoch #15: training_error=-0.9596312388129856, test_error=-0.958900640408198
epoch #16: training_error=-0.9619116700213889, test_error=-0.9606296718120575
epoch #17: training_error=-0.963296625925147, test_error=-0.9618712464968363
epoch #18: training_error=-0.964468306562175, test_error=-0.962962140639623
epoch #19: training_error=-0.9656051786049552, test_error=-0.9633625497420629
epoch #20: training_error=-0.9661645360614942, test_error=-0.9637840042511622
epoch #21: training_error=-0.9670840688373732, test_error=-0.9645407944917679
epoch #22: training_error=-0.9675297908160998, test_error=-0.9647647142410278
epoch #23: training_error=-0.968075982902361, test_error=-0.9654373526573181
epoch #24: training_error=-0.9680173241573832, test_error=-0.9652755657831827
epoch #25: training_error=-0.96848623752594, test_error=-0.9659683257341385
epoch #26: training_error=-0.9682907306629679, test_error=-0.9664795845746994
epoch #27: training_error=-0.9695260338161302, test_error=-0.9665666818618774
epoch #28: training_error=-0.969839212168818, test_error=-0.9669433981180191
epoch #29: training_error=-0.9700202615364738, test_error=-0.9667912622292837
epoch #30: training_error=-0.9708342692126398, test_error=-0.9675299723943075
epoch #31: training_error=-0.9703854773355567, test_error=-0.9673667003711065
epoch #32: training_error=-0.9717840562696042, test_error=-0.9684023261070251
epoch #33: training_error=-0.9726218985474628, test_error=-0.9691992004712423
epoch #34: training_error=-0.9721553678097932, test_error=-0.9685578594605128
epoch #35: training_error=-0.9730600165284198, test_error=-0.9691728303829829
epoch #36: training_error=-0.9736596802006597, test_error=-0.9698172907034556
epoch #37: training_error=-0.9731561370517896, test_error=-0.9691229710976282
epoch #38: training_error=-0.9742445463719576, test_error=-0.9703827102979025
epoch #39: training_error=-0.972710659192956, test_error=-0.9692197690407435
epoch #40: training_error=-0.9743008660233539, test_error=-0.9704541166623434
epoch #41: training_error=-0.9747222724168197, test_error=-0.9709257930517197
epoch #42: training_error=-0.9754152588222338, test_error=-0.9714237848917643
epoch #43: training_error=-0.9743199861567954, test_error=-0.9697967072327932
epoch #44: training_error=-0.9753414858942446, test_error=-0.9713153938452402
epoch #45: training_error=-0.9763206186501876, test_error=-0.9717517246802648
epoch #46: training_error=-0.9767339353976042, test_error=-0.9718629717826843
epoch #47: training_error=-0.972210144996643, test_error=-0.9703837434450785
epoch #48: training_error=-0.9680927250696265, test_error=-0.967069461941719
epoch #49: training_error=-0.9752375457597815, test_error=-0.9707983434200287
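
The logged errors are negative because the trainer uses `-dice_coef` as both the loss and the evaluation metric, so values approaching -1 indicate better overlap between prediction and target. The definition of `cntk_resnet_fcn.dice_coefficient` is not shown here; the NumPy sketch below assumes the common smoothed form, 2|X∩Y| / (|X| + |Y|), with a hypothetical smoothing constant, purely to illustrate how such a metric behaves.

```python
import numpy as np

def dice_coefficient_np(pred, target, smooth=1.0):
    """Smoothed Dice coefficient over all elements.

    Illustrative assumption; not necessarily the formula used by
    cntk_resnet_fcn.dice_coefficient.
    """
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    intersection = np.sum(pred * target)
    return (2.0 * intersection + smooth) / (np.sum(pred) + np.sum(target) + smooth)

# Toy masks: 2 classes on a 4x4 grid.
target = np.zeros((2, 4, 4))
target[0, :2, :] = 1.0  # class 0 covers the top half
target[1, 2:, :] = 1.0  # class 1 covers the bottom half

perfect = dice_coefficient_np(target, target)
disjoint = dice_coefficient_np(target[::-1], target)  # classes swapped, no overlap

print(perfect, disjoint)
```

Under this assumed definition a perfect prediction scores 1 and a fully disjoint one scores close to 0, which matches the direction of the log: the negated value falls toward -1 as training progresses.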

Learning curve (Training/Test error)

helper.plot_errors({"training": training_errors, "test": test_errors}, title="Simulation Learning Curve")

Figure: Learning curve

Use the trained model

# Generate some random images
input_images, target_masks = simulation.generate_random_data(192, 192, count=10)

# Predict
pred = trainer.model.eval(input_images)

print(input_images.shape, target_masks.shape, pred.shape)
(10, 1, 192, 192) (10, 6, 192, 192) (10, 6, 192, 192)

# Change channel order from (C, H, W) to (H, W, C) and repeat to 3 channels for matplotlib
input_images_rgb = [(x.swapaxes(0, 2).swapaxes(0, 1).repeat(3, axis=2) * -255 + 255).astype(np.uint8) for x in input_images]

# Map each channel (i.e. class) to each color
target_masks_rgb = [helper.masks_to_colorimg(x) for x in target_masks]
pred_rgb = [helper.masks_to_colorimg(x) for x in pred]

# Left: Input image, Middle: Correct mask (Ground-truth), Right: Predicted mask
helper.plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb])

Figure: Predicted masks from the trained model (left: input image, middle: correct mask (ground truth), right: predicted mask)
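
The source of `helper.masks_to_colorimg` is not included here. As a rough idea of what such a helper might do, the sketch below maps a (C, H, W) stack of per-class masks to an (H, W, 3) color image. The function name `masks_to_colorimg_sketch`, the palette, and the 0.5 threshold are all assumptions for illustration and may differ from the real helper.

```python
import numpy as np

def masks_to_colorimg_sketch(masks, threshold=0.5):
    """Map a (C, H, W) stack of per-class masks to an (H, W, 3) uint8 image.

    Hypothetical re-implementation for illustration; the real
    helper.masks_to_colorimg may use different colors or blending.
    """
    # Arbitrary palette, one RGB color per class (assumes at most 6 classes).
    palette = np.array([
        (201, 58, 64), (242, 207, 1), (0, 152, 75),
        (101, 172, 228), (56, 34, 132), (160, 194, 56),
    ], dtype=np.uint8)
    num_classes, height, width = masks.shape
    colorimg = np.full((height, width, 3), 255, dtype=np.uint8)  # white background
    for c in range(num_classes):
        # Pixels above the threshold take the class color; later classes overwrite earlier ones.
        colorimg[masks[c] > threshold] = palette[c]
    return colorimg

# Toy prediction: 6 classes on an 8x8 grid, class 2 active in the top-left corner.
toy_pred = np.zeros((6, 8, 8), dtype=np.float32)
toy_pred[2, :4, :4] = 0.9

img = masks_to_colorimg_sketch(toy_pred)
print(img.shape, img[0, 0], img[7, 7])
```

Thresholding is one simple way to turn continuous model outputs into a displayable mask; taking the per-pixel argmax over classes would be an alternative when classes are mutually exclusive.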
