ultrasound-nerve-segmentation's Introduction

ultrasound-nerve-segmentation

Kaggle ultrasound nerve segmentation challenge using Keras. Read my blog for details and insights.

#Install (Ubuntu {14,16}, GPU)

cuDNN required.

###Tensorflow backend

###Keras

  • sudo apt-get install libhdf5-dev
  • sudo pip install h5py
  • sudo pip install keras

In ~/.keras/keras.json

{
    "image_dim_ordering": "th",
    "epsilon": 1e-07,
    "floatx": "float32",
    "backend": "tensorflow"
}

###Python deps

  • sudo apt-get install python-opencv
  • sudo apt-get install python-sklearn

#Prepare

Download the data from https://www.kaggle.com/c/ultrasound-nerve-segmentation/data and place it in input/train and input/test folders respectively.

Run

python data.py

to generate data within the input folder. This is a one-time operation.

#Training

python train.py

Results will be generated in the "results/" folder. The best model is saved as results/net.hdf5.

#Submission

python submission.py

will generate a run-length-encoded submission that can be uploaded directly to Kaggle.
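
For readers unfamiliar with run-length encoding, here is a minimal sketch of the idea. It is not the code in submission.py; the function name `run_length_encode` is hypothetical, and it assumes the competition's convention that pixels are numbered from 1, top to bottom and then left to right, with the output written as "start length" pairs.

```python
import numpy as np


def run_length_encode(mask):
    """Encode a binary 2-D mask as 'start length start length ...'.

    Assumption: pixels are numbered from 1, going down each column first
    (column-major order).
    """
    # Flatten column by column so consecutive indices run down each column.
    flat = np.asarray(mask).astype(bool).flatten(order="F")
    # Pad with a zero on both sides so every run has a visible start and end.
    padded = np.concatenate([[False], flat, [False]]).astype(np.int8)
    changes = np.diff(padded)
    starts = np.where(changes == 1)[0] + 1   # 1-based index where a run begins
    ends = np.where(changes == -1)[0] + 1    # 1-based index just past the run
    lengths = ends - starts
    return " ".join(f"{s} {n}" for s, n in zip(starts, lengths))


if __name__ == "__main__":
    demo = np.array([[0, 1],
                     [1, 1],
                     [0, 0]])
    print(run_length_encode(demo))
```

An all-zero mask encodes to an empty string, which is what the presence gating described in the Model section produces for images predicted to contain no nerve.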

#Model

I used a U-Net-like architecture (http://arxiv.org/abs/1505.04597) with a few tweaks.

  • The main idea was to use two training heads, one optimizing BCE (binary cross-entropy) for nerve presence and the other optimizing Dice for segmentation. At test time, simply zero out masks whose nerve-presence probability is < 0.5. This was necessary because a large number of samples contained no masks, and the BCE/Dice score alone would simply be optimized by outputting all zeros for masks.
  • The network contains ~8.25 million parameters. A single epoch took 4 minutes on a Titan X with 12 GB of memory.
  • Reduced the learning rate by a factor of 0.25 when stagnation occurred within the last 4 epochs.
  • Logs are written to the 'logs/' folder and monitored via TensorBoard. I examined histograms to detect saturation. Note that, as of Keras 1.1.1, you need to use a fixed set rather than a generator to get histograms, due to a known issue.
  • Weight regularization prevented convergence (perhaps a smaller lambda was needed). I used dropout instead to prevent weight saturation (which tended to occur without it).
  • he_normal weight initialization.
  • Conv with 2 x 2 stride instead of max pooling to downsample, in light of recent results with VAEs and GANs.
  • ELU activation, batchnorm everywhere.
  • Used 1 x 1 conv instead of dense layers, in the spirit of the paper "Striving for Simplicity: The All Convolutional Net".
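
The two-head idea in the first bullet can be illustrated with a small NumPy sketch. This is not the repository's Keras code: the function names, the smoothing constant `smooth=1.0`, and the array shapes are assumptions made for illustration. It shows a soft Dice coefficient, a binary cross-entropy for the presence head, and the test-time gating step.

```python
import numpy as np


def dice_coef(y_true, y_pred, smooth=1.0):
    """Soft Dice coefficient over flattened masks (smooth is an assumed constant)."""
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true * y_pred)
    return (2.0 * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)


def binary_crossentropy(y_true, prob, eps=1e-7):
    """Mean binary cross-entropy for the nerve-presence head."""
    y_true = np.asarray(y_true, dtype=np.float64)
    prob = np.clip(np.asarray(prob, dtype=np.float64), eps, 1.0 - eps)
    return float(np.mean(-(y_true * np.log(prob) + (1.0 - y_true) * np.log(1.0 - prob))))


def gate_masks(masks, has_mask_prob, threshold=0.5):
    """Zero out each predicted mask whose presence probability is below threshold."""
    gated = np.array(masks, dtype=np.float64)  # copy, leave the input untouched
    keep = np.asarray(has_mask_prob).reshape(-1) >= threshold
    gated[~keep] = 0.0
    return gated


if __name__ == "__main__":
    empty = np.zeros((4, 4))
    # An all-zero prediction scores perfectly on an empty ground-truth mask,
    # which is why Dice alone rewards predicting nothing on such samples.
    print(dice_coef(empty, empty))
    preds = np.ones((2, 1, 4, 4))
    print(gate_masks(preds, [0.9, 0.2]).sum(axis=(1, 2, 3)))
```

With many empty masks in the training set, the segmentation head on its own has little incentive to predict anything, so the separate presence probability decides whether a mask is emitted at all.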

Augmentation:

  • Parallel augmentation generation on CPU.
  • Random rotation (+/- 5 deg).
  • Random translations (+/- 10 px).
  • Elastic deformation didn't help much.
  • Larger rotations/translations prevented learning.
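
As a rough sketch of the rotation and translation augmentations listed above, the following NumPy-only example applies the same random transform to an image and its mask so the two stay aligned. It is not the repository's generator code: the function names are hypothetical, nearest-neighbour sampling is a simplification, and a real pipeline would more likely use OpenCV.

```python
import numpy as np


def affine_nearest(img, angle_deg, dx, dy):
    """Rotate a 2-D array about its centre, then shift it by (dx, dy) pixels.

    Uses nearest-neighbour sampling; pixels with no source inside the image are 0.
    """
    h, w = img.shape
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    theta = np.deg2rad(angle_deg)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    ys, xs = np.mgrid[0:h, 0:w]
    # Inverse mapping: for every output pixel, locate the source pixel.
    x0 = xs - dx - cx
    y0 = ys - dy - cy
    src_x = np.rint(cos_t * x0 + sin_t * y0 + cx).astype(int)
    src_y = np.rint(-sin_t * x0 + cos_t * y0 + cy).astype(int)
    valid = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out = np.zeros_like(img)
    out[valid] = img[src_y[valid], src_x[valid]]
    return out


def random_augment(image, mask, rng, max_angle=5.0, max_shift=10):
    """Draw one random rotation/translation and apply it to both image and mask."""
    angle = rng.uniform(-max_angle, max_angle)
    dx, dy = rng.integers(-max_shift, max_shift + 1, size=2)
    return affine_nearest(image, angle, dx, dy), affine_nearest(mask, angle, dx, dy)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    image = rng.random((64, 80))
    mask = (image > 0.5).astype(np.uint8)
    aug_image, aug_mask = random_augment(image, mask, rng)
    print(aug_image.shape, aug_mask.shape)
```

The defaults mirror the ranges in the list (+/- 5 degrees, +/- 10 px); the observation that larger ranges prevented learning comes from the author's experiments, not from this sketch.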

Validation:

  • 10% of the examples, stratified split by mask/no-mask
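
A stratified split by mask/no-mask can be sketched as follows. This is an illustrative NumPy version with a hypothetical function name; the repository may well use scikit-learn for this instead.

```python
import numpy as np


def stratified_split(masks, val_fraction=0.1, seed=0):
    """Return (train_idx, val_idx), keeping the mask/no-mask ratio in both parts."""
    masks = np.asarray(masks)
    # A sample counts as "has mask" if any pixel of its mask is non-zero.
    has_mask = masks.reshape(len(masks), -1).any(axis=1)
    rng = np.random.default_rng(seed)
    train_idx, val_idx = [], []
    for label in (False, True):
        idx = np.flatnonzero(has_mask == label)
        rng.shuffle(idx)
        n_val = int(round(val_fraction * len(idx)))
        val_idx.extend(idx[:n_val])
        train_idx.extend(idx[n_val:])
    return np.sort(train_idx), np.sort(val_idx)


if __name__ == "__main__":
    masks = np.zeros((100, 8, 8))
    masks[60:, 2:5, 2:5] = 1  # last 40 samples contain a nerve mask
    train_idx, val_idx = stratified_split(masks)
    print(len(train_idx), len(val_idx))
```

Splitting each group separately keeps the share of empty masks the same in the training and validation sets, so validation scores are not skewed by an unlucky draw.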

Visual inspection:

  • utils.examine_generator() can be used to visually inspect augmented samples.
  • utils.inspect_set() can be used to examine test time predictions on train/val set.
  • I am in the process of generalizing the layer visualization code. Otherwise, I used various gradient-ascent-style visualizations to sanity-check whether the network was learning the right thing.

#Credits

Borrowed starter code from https://github.com/jocicmarko/ultrasound-nerve-segmentation/, particularly data prep and submission portion.

ultrasound-nerve-segmentation's People

Contributors

raghakot

ultrasound-nerve-segmentation's Issues

Getting no pixel values in mask.

Hi,

I changed the Keras dimension ordering to 'theano' with the 'tensorflow' backend, placed all images in the "input" folder, and then ran your code. The code executes without any error, but I am not getting any pixel values (masks) for the test set.

I even ran your code after commenting out lines {65, 66} of submission.py:
#if has_masks[i, 0] < 0.5:
#masks[i, 0] *= 0.

But there is still no pixel value (mask) for any image in the test set. Please help.

Thanks...

InvalidArgumentError: Incompatible shapes: [819200] vs. [5120]

I am getting this error at line 65 of train.py:
model.fit_generator(train_generator, validation_data=val_generator, nb_val_samples=X_val.shape[0] * 2,
samples_per_epoch=X_train.shape[0], nb_epoch=nb_epoch, verbose=1,
callbacks=[model_checkpoint, reduce_lr, tb], max_q_size=10000)

It occurs at the callbacks argument of the line above.

I read the documentation for fit_generator on the official Keras website and found that it runs on the CPU and GPU in parallel. But I don't have a GPU.
So, is this error caused by the absence of a GPU?

Thanks...

a question regarding customdatagenerator

Hi Raghavendra,

Thanks for sharing the code. I have a question regarding generator.py. It looks like CustomDataGenerator(Iterator) contains only two methods: __init__ and next. You mentioned that this is a modified Keras ImageDataGenerator, which seems to me to include a lot of other methods.

More specifically, I am trying to understand how your CustomDataGenerator(Iterator) works. Thanks.

Also, in train.py, train_generator and val_generator are set up differently. May I know the difference?

train_generator = CustomDataGenerator(X_train, y_train, transform, batch_size)

val_generator = CustomDataGenerator(X_val, y_val, lambda x, y: transform(x, y, augment=False), batch_size)
