ResNet-50 Image Classifier

This is an image classifier built on the 50-layer Residual Network (ResNet-50) architecture. It classifies objects into 101 categories and reaches about 93% accuracy on held-out test images (see Results).

In recent years, neural networks have become deeper, with state-of-the-art networks going from just a few layers (e.g., AlexNet) to over a hundred layers. The main benefit of a very deep network is that it can represent very complex functions. It can also learn features at many different levels of abstraction, from edges (in the shallower layers) to very complex features (in the deeper layers). However, using a deeper network doesn't always help. A huge barrier to training very deep networks is vanishing gradients: as the gradient is backpropagated through many layers, it can shrink toward zero, which makes gradient descent prohibitively slow.

In ResNets, a "shortcut" or "skip connection" adds a block's input directly to its output, which lets the gradient be backpropagated directly to earlier layers.
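
To make the effect concrete, here is a toy sketch in plain Python (it is not this repository's code). It assumes every layer is a scalar function with the same small local derivative, so the gradient through `depth` plain layers is a product of small numbers. A residual layer computes `x + f(x)`, so its derivative is `1 + f'(x)`, and the `1` contributed by the shortcut keeps the product from collapsing. The function names and the value `0.01` are illustrative assumptions.

```python
def plain_gradient(local_grad: float, depth: int) -> float:
    """Gradient of the output w.r.t. the input through `depth` plain layers."""
    grad = 1.0
    for _ in range(depth):
        grad *= local_grad        # chain rule: multiply the local derivatives
    return grad


def residual_gradient(local_grad: float, depth: int) -> float:
    """Same, but each layer computes x + f(x), so its derivative is 1 + f'(x)."""
    grad = 1.0
    for _ in range(depth):
        grad *= 1.0 + local_grad  # the "1" comes from the shortcut path
    return grad


# Assumed toy values: 50 layers, local derivative 0.01 at every layer.
print("plain:   ", plain_gradient(0.01, 50))
print("residual:", residual_gradient(0.01, 50))
```

With these values the plain product is vanishingly small, while the residual product stays above 1. Real networks are not this uniform, but the identity term in the derivative is the mechanism the shortcut provides.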

The "identity block" is the standard block used in ResNets. It corresponds to the case where the input activation (say a[l]) has the same dimensions as the output activation (say a[l+2]), so the input can be added to the output unchanged.

The ResNet "convolutional block" is the other type of block. You can use it when the input and output dimensions don't match. The difference from the identity block is a CONV2D layer in the shortcut path, which resizes the input so that it can be added to the main path's output.
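
The difference between the two block types can be sketched with a toy example in plain Python. This is an illustration under simplifying assumptions, not the model's Keras code: small matrix-vector products stand in for the CONV2D layers, Batch-Norm is omitted, and all function and parameter names are hypothetical. The point is only where the shortcut needs its own transform.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply matrix m (rows x cols) by vector v (length cols)."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def relu(v: Vector) -> Vector:
    return [max(0.0, x) for x in v]


def add(a: Vector, b: Vector) -> Vector:
    if len(a) != len(b):
        raise ValueError("main path and shortcut must have the same dimension")
    return [x + y for x, y in zip(a, b)]


def identity_block(x: Vector, w1: Matrix, w2: Matrix) -> Vector:
    """Shortcut is x itself, so the main path must preserve the dimension."""
    main = matvec(w2, relu(matvec(w1, x)))
    return relu(add(main, x))


def projection_block(x: Vector, w1: Matrix, w2: Matrix, w_shortcut: Matrix) -> Vector:
    """Shortcut applies its own transform (CONV2D in the real block) to match dimensions."""
    main = matvec(w2, relu(matvec(w1, x)))
    shortcut = matvec(w_shortcut, x)
    return relu(add(main, shortcut))


x = [1.0, -2.0]
eye2 = [[1.0, 0.0], [0.0, 1.0]]
print(identity_block(x, eye2, eye2))  # 2 inputs -> 2 outputs

# 2 inputs -> 3 outputs: the shortcut needs a 3x2 projection.
w1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
eye3 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
w_shortcut = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
print(projection_block(x, w1, eye3, w_shortcut))
```

Passing the 3-output weights to `identity_block` would raise the `ValueError` in `add`, which is exactly the mismatch the convolutional block's shortcut layer resolves.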

The detailed structure of this ResNet-50 model is shown below (I've also added extra Batch-Norm and Dropout layers since they're absolutely awesome):

[Figure: ResNet-50]

Dataset

We'll be using the Caltech 101 Multi-Classification Object Dataset, a collection of pictures of objects belonging to 101 different categories. Each category has about 40 to 800 images, and most have about 50. Since that is too few images per class to train our model, we'll use Data Augmentation to obtain more. This is our directory structure:

datasets/
    caltech_101/
        train/
            accordion/
                image_0001.jpg
                image_0002.jpg
                ...
            airplanes/
                image_0001.jpg
                image_0002.jpg
                ...
            ...

Our preprocessing script prepro.py will handle the rest.
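
Before running the preprocessing, it can be useful to check how many images each class has. The following is a minimal sketch that assumes the directory layout above (one sub-folder per class, `.jpg` files inside). The helper `images_per_class` is a hypothetical name and is not part of prepro.py.

```python
from pathlib import Path
from typing import Dict


def images_per_class(train_dir: str) -> Dict[str, int]:
    """Count .jpg files in each class sub-folder of train_dir."""
    counts: Dict[str, int] = {}
    for class_dir in sorted(Path(train_dir).iterdir()):
        if not class_dir.is_dir():
            continue  # skip stray files next to the class folders
        counts[class_dir.name] = sum(
            1 for p in class_dir.iterdir() if p.suffix.lower() == ".jpg"
        )
    return counts
```

For example, `images_per_class("datasets/caltech_101/train")` returns a dictionary mapping each class folder name to its image count, which makes the classes that need the most augmentation easy to spot.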

Getting Started

To train the model and make predictions, first install the required Python packages:

pip install -r requirements.txt

Now, we need to do some preprocessing (Data Augmentation and Train/Val/Test Split) of our dataset:

python prepro.py --dataset-path datasets/caltech_101
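
This README does not describe how prepro.py performs the split. As a hedged illustration only, the sketch below holds out a fixed number of files per class for validation and for testing, and keeps the rest for training. The function name, the `holdout` default of 20, and the seed are assumptions. The Results section reports 2020 validation and 2020 test examples, which would be consistent with 20 per class across 101 classes, but that is an inference and not something the source states.

```python
import random
from typing import Dict, List, Tuple

Sample = Tuple[str, str]  # (file name, class label)


def split_per_class(
    files_by_class: Dict[str, List[str]], holdout: int = 20, seed: int = 0
) -> Tuple[List[Sample], List[Sample], List[Sample]]:
    """Return (train, val, test) with `holdout` files per class in val and in test."""
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    train: List[Sample] = []
    val: List[Sample] = []
    test: List[Sample] = []
    for label in sorted(files_by_class):
        files = sorted(files_by_class[label])
        if len(files) <= 2 * holdout:
            raise ValueError(f"class {label!r} has too few files for this split")
        rng.shuffle(files)
        val += [(f, label) for f in files[:holdout]]
        test += [(f, label) for f in files[holdout:2 * holdout]]
        train += [(f, label) for f in files[2 * holdout:]]
    return train, val, test
```

Splitting per class keeps every category represented in all three sets, which matters here because the class sizes are very uneven.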

Once you're done with all that, you can open up a terminal and start training the model (FYI: it takes a while):

python train.py -lr 0.005 --num-epochs 50 --batch-size 32 --save-every 5 --tensorboard-vis
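
For readers who want to see how flags like these are typically wired up, here is a minimal `argparse` sketch. It is a guess at the interface based only on the command above and is not the actual contents of train.py. The default values and help strings are assumptions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Parser mirroring the flags used in the training command (defaults are assumed)."""
    parser = argparse.ArgumentParser(description="Train the ResNet-50 classifier")
    parser.add_argument("-lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--num-epochs", type=int, default=10, help="number of epochs")
    parser.add_argument("--batch-size", type=int, default=32, help="mini-batch size")
    parser.add_argument("--save-every", type=int, default=1,
                        help="save a checkpoint every N epochs")
    parser.add_argument("--tensorboard-vis", action="store_true",
                        help="log metrics for TensorBoard")
    return parser


# Parse the same arguments as the command shown above.
args = build_parser().parse_args(
    ["-lr", "0.005", "--num-epochs", "50", "--batch-size", "32",
     "--save-every", "5", "--tensorboard-vis"]
)
print(args.lr, args.num_epochs, args.batch_size, args.save_every, args.tensorboard_vis)
```

Note that `argparse` converts dashes in long option names to underscores, so `--num-epochs` is read back as `args.num_epochs`.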

Passing the --tensorboard-vis flag allows you to view the training/validation loss and accuracy in your browser using:

tensorboard --logdir=./logs

Once you're done training, run the prediction script, which loads the trained model and makes a prediction on your test image:

python predict.py images/test.jpg

Results

Training:

number of training examples: 72893
X_train shape: (72893, 64, 64, 3)
Y_train shape: (72893, 101)
number of validation examples: 2020
X_train shape: (2020, 64, 64, 3)
Y_train shape: (2020, 101)
Epoch 50/50:
2187/2187 [==============================] - 393s 175ms/step - loss: 0.0341 - acc: 0.9888 - val_loss: 0.3118 - val_acc: 0.9311

Val Loss = 0.3118
Val Accuracy = 93.11% (0.9311)

Testing:

number of test examples: 2020
X_test shape: (2020, 64, 64, 3)
Y_test shape: (2020, 101)
68/68 [==============================] - 30s 437ms/step
Loss = 0.3213
Test Accuracy = 92.97% (0.9297)

Model Parameters:

Total params: 23,794,661
Trainable params: 23,741,541
Non-trainable params: 53,120
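
The three counts above are consistent with each other, which the short check below confirms. It also works out, purely as an assumption about the architecture, how many parameters a single 101-way dense output layer on a 2048-dimensional feature vector would contribute. The 2048 figure is the usual ResNet-50 feature width, but this README does not state the layout of the final layers.

```python
# Values copied from the model summary above.
total_params = 23_794_661
trainable_params = 23_741_541
non_trainable_params = 53_120

# Total is the sum of the trainable and non-trainable counts.
assert total_params == trainable_params + non_trainable_params

# Assumption: one dense layer mapping 2048 features to 101 classes.
features, classes = 2048, 101
head_params = features * classes + classes  # weights + biases
rest_params = total_params - head_params    # everything before that layer

print("dense head:", head_params)
print("remaining: ", rest_params)
```

Under that assumption, the output layer accounts for less than 1% of the parameters and nearly all of the capacity sits in the convolutional layers.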

Built With

  • Python
  • Keras
  • TensorFlow
  • NumPy

Contributors

  • pskrunner14
