
shufflenet's Introduction

ShuffleNet

An implementation of ShuffleNet in TensorFlow. According to the authors, ShuffleNet is a computationally efficient CNN architecture designed specifically for mobile devices with very limited computing power. It outperforms Google MobileNet by a small error percentage at much lower FLOPs.

Link to the original paper: ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices

ShuffleNet Unit



Group Convolutions

The paper uses the group convolution operator. However, that operator is not implemented in the TensorFlow backend, so I implemented it using graph operations.

This issue was discussed here: Support Channel groups in convolutional layers #10482
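The idea of building a group convolution from ordinary operations can be sketched as: split the channels into G groups, convolve each group separately, and concatenate the results. The NumPy sketch below assumes NHWC layout and 1x1 (pointwise) kernels, the kind of group convolution ShuffleNet relies on; the function name group_conv1x1 is hypothetical, and this is an illustration of the technique, not the repository's TensorFlow code.

```python
import numpy as np


def group_conv1x1(x, weights, groups):
    """Pointwise (1x1) group convolution on an NHWC array.

    x:       array of shape (N, H, W, C_in)
    weights: list of `groups` arrays, each of shape (C_in // groups, C_out // groups)
    Returns an array of shape (N, H, W, C_out).
    """
    c_in = x.shape[-1]
    if c_in % groups != 0:
        raise ValueError("input channels must be divisible by groups")
    if len(weights) != groups:
        raise ValueError("expected one weight matrix per group")
    # Split the channel axis into `groups` equal slices.
    x_groups = np.split(x, groups, axis=-1)
    # A 1x1 convolution is a matrix multiply over the channel axis;
    # each group only sees its own slice of the input channels.
    outputs = [xg @ wg for xg, wg in zip(x_groups, weights)]
    # Concatenate the per-group outputs back along the channel axis.
    return np.concatenate(outputs, axis=-1)


if __name__ == "__main__":
    rng = np.random.RandomState(0)
    x = rng.randn(1, 4, 4, 6)                 # C_in = 6
    ws = [rng.randn(2, 3) for _ in range(3)]  # G = 3, C_out = 9
    y = group_conv1x1(x, ws, groups=3)
    print(y.shape)  # (1, 4, 4, 9)
```

Because each group connects only C_in/G input channels to C_out/G output channels, the layer holds C_in*C_out/G weights instead of C_in*C_out, a G-fold saving in both parameters and multiply-adds. The result is the same as an ordinary 1x1 convolution whose weight matrix is block-diagonal.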

Channel Shuffling



Channel Shuffling can be achieved by applying three operations:

  1. Reshaping the input tensor from (N, H, W, C) into (N, H, W, G, C').

  2. Transposing the two dimensions (G, C') into (C', G).

  3. Reshaping the tensor back into (N, H, W, C).

    N: Batch size, H: Feature map height, W: Feature map width, C: Number of channels, G: Number of groups, C': Number of channels / Number of groups

    Note that the number of channels must be divisible by the number of groups.
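The three steps map directly onto reshape and transpose calls. Below is a minimal NumPy sketch assuming NHWC layout; the function name channel_shuffle is a hypothetical choice, and in TensorFlow the same steps would be written with tf.reshape and tf.transpose.

```python
import numpy as np


def channel_shuffle(x, groups):
    """Shuffle the channels of an NHWC array across `groups` groups."""
    n, h, w, c = x.shape
    if c % groups != 0:
        raise ValueError("number of channels must be divisible by number of groups")
    # 1. (N, H, W, C) -> (N, H, W, G, C')
    x = x.reshape(n, h, w, groups, c // groups)
    # 2. Swap the (G, C') axes -> (N, H, W, C', G)
    x = x.transpose(0, 1, 2, 4, 3)
    # 3. Flatten back to (N, H, W, C)
    return x.reshape(n, h, w, c)


if __name__ == "__main__":
    x = np.arange(6).reshape(1, 1, 1, 6)  # channels 0..5, two groups of three
    print(channel_shuffle(x, groups=2).ravel())  # [0 3 1 4 2 5]
```

With two groups {0, 1, 2} and {3, 4, 5}, the output interleaves one channel from each group, so the next group convolution receives channels from every input group.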

Usage

Main Dependencies

Python 3 or above
tensorflow 1.3.0
numpy 1.13.1
tqdm 4.15.0
easydict 1.7
matplotlib 2.0.2

Train and Test

  1. Prepare your data and modify the load_data() method of the DataLoader class in data_loader.py.
  2. Modify config/test.json to suit your needs.

Run

python main.py --config config/test.json

Results

The model has successfully overfitted TinyImageNet-200, the dataset presented in CS231n: Convolutional Neural Networks for Visual Recognition. I'm working on ImageNet training.

Benchmarking

The paper reports 140 MFLOPs for the vanilla version. Using the group convolution operator implemented in TensorFlow, I measured approximately 270 MFLOPs. The paper counts a multiplication plus an addition as one unit, whereas the TensorFlow profiler counts them as two operations, so dividing 270 by two gives roughly 135 MFLOPs, in line with what the paper reports.
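The factor of two can be made concrete with a small worked example. The sketch below counts the multiply-accumulate operations (one multiplication plus one addition, the unit the paper uses) of a single 1x1 group convolution, and the corresponding count when multiplications and additions are tallied separately. The layer sizes are hypothetical, chosen only for illustration.

```python
def pointwise_group_conv_macs(h, w, c_in, c_out, groups):
    """Multiply-accumulates of a 1x1 group convolution (one MAC = one mult + one add)."""
    if c_in % groups or c_out % groups:
        raise ValueError("channels must be divisible by groups")
    # Each of the h*w*c_out outputs sums over the c_in // groups inputs of its group.
    return h * w * c_out * (c_in // groups)


def macs_to_flops(macs):
    """Counting multiplications and additions separately doubles the figure."""
    return 2 * macs


if __name__ == "__main__":
    # Hypothetical layer: 28x28 feature map, 240 -> 240 channels, 3 groups.
    macs = pointwise_group_conv_macs(28, 28, 240, 240, groups=3)
    print(macs, macs_to_flops(macs))  # 15052800 30105600
    # The measurement above: ~270 MFLOPs counted separately -> ~135 M mult-adds.
    print(270 / 2)  # 135.0
```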

To calculate the FLOPs in TensorFlow, set the batch size to 1 and run the following code once the model is loaded into memory:

import tensorflow as tf

flops = tf.profiler.profile(
    tf.get_default_graph(),
    options=tf.profiler.ProfileOptionBuilder.float_operation(), cmd='scope')
print('Total FLOPs:', flops.total_float_ops)

TODO

  • Training on ImageNet dataset. In progress...

Updates

  • Inference and training are working properly.

License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.

Acknowledgments

Thanks to all who helped me in my work, and special thanks to my colleagues Mo'men Abdelrazek and Mohamed Zahran.

shufflenet's People

Contributors

mg2033, msiam, pomonam


shufflenet's Issues

Number of FLOPs

Hi!
Does anyone know why the number of FLOPs is about twice the number reported by the authors?

Thanks

HELP: why does the loss stay high on the flowers dataset?

I used the flowers dataset (5 classes, 2500 images for training and 500 for validation) to create TFRecords files as the training input, but the loss does not decrease and the validation accuracy stays at 20%. Does my code have a bug when reading the TFRecords?

import tensorflow as tf
from tqdm import tqdm
import numpy as np
from utils import load_obj
import matplotlib.pyplot as plt

IMAGE_SIZE = 224
NUM_CLASSES = 5
TRAIN_TFRECORDS = "/home/coolpad/juzhitao/shufflenet/mg2033/ShuffleNet/train1.tfrecords"
VAL_TFRECORDS = "/home/coolpad/juzhitao/shufflenet/mg2033/ShuffleNet/val1.tfrecords"


class Train:
    """Trainer class for the CNN.
    It's also responsible for loading/saving the model checkpoints from/to experiments/experiment_name/checkpoint_dir"""

    def __init__(self, sess, model, data, summarizer):
        self.sess = sess
        self.model = model
        self.args = self.model.args
        self.saver = tf.train.Saver(max_to_keep=self.args.max_to_keep,
                                    keep_checkpoint_every_n_hours=10,
                                    save_relative_paths=True)
        # Summarizer references
        self.data = data
        self.summarizer = summarizer

        # Build the train/val input pipelines once; building them inside
        # train()/test() would add new queues to the graph on every call.
        self.train_images, self.train_labels = self.__input_pipeline(
            TRAIN_TFRECORDS, capacity=2500, min_after_dequeue=250)
        self.val_images, self.val_labels = self.__input_pipeline(
            VAL_TFRECORDS, capacity=200, min_after_dequeue=50)

        # Initializing the model (the only place variables are initialized)
        self.init = None
        self.__init_model()

        # Loading the model checkpoint if exists
        self.__load_imagenet_weights()
        self.__load_model()

        # Start the queue runners that fill both input pipelines
        self.coord = tf.train.Coordinator()
        self.threads = tf.train.start_queue_runners(sess=self.sess, coord=self.coord)

    ######################################################################
    # Model related methods
    def __init_model(self):
        print("Initializing the model...")
        self.init = tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer())
        self.sess.run(self.init)
        print("Model initialized\n\n")

    def save_model(self):
        """
        Save Model Checkpoint
        :return:
        """
        print("Saving a checkpoint")
        self.saver.save(self.sess, self.args.checkpoint_dir, self.model.global_step_tensor)
        print("Checkpoint Saved\n\n")

    def __load_model(self):
        latest_checkpoint = tf.train.latest_checkpoint(self.args.checkpoint_dir)
        if latest_checkpoint:
            print("Loading model checkpoint {} ...\n".format(latest_checkpoint))
            self.saver.restore(self.sess, latest_checkpoint)
            print("Checkpoint loaded\n\n")
        else:
            print("First time to train!\n\n")

    def __load_imagenet_weights(self):
        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        print("No pretrained ImageNet weights exist. Skipping...\n\n")

    ######################################################################
    # Input pipeline
    def read_and_decode(self, filename_queue):
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            # Defaults are not specified since both keys are required.
            features={
                'image': tf.FixedLenFeature([], tf.string),
                'target': tf.FixedLenFeature([], tf.int64),
            })

        # Decode the JPEG bytes and scale the pixels to float32 in [0, 1]
        image = tf.image.decode_jpeg(features['image'], channels=3)
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.resize_image_with_crop_or_pad(image, IMAGE_SIZE, IMAGE_SIZE)
        image = tf.clip_by_value(image, 0.0, 1.0)

        # Convert the label from an int64 scalar to an int32 scalar
        label = tf.cast(features['target'], tf.int32)

        return image, label

    def __input_pipeline(self, path, capacity, min_after_dequeue):
        filename_queue = tf.train.string_input_producer([path])
        image, label = self.read_and_decode(filename_queue)
        # Use the configured batch size so it matches num_iterations below
        return tf.train.shuffle_batch([image, label],
                                      batch_size=self.args.batch_size,
                                      num_threads=2,
                                      capacity=capacity,
                                      min_after_dequeue=min_after_dequeue)

    ######################################################################
    # Train and Test methods
    def train(self):
        num_iterations = self.args.train_data_size // self.args.batch_size
        for cur_epoch in range(self.model.global_epoch_tensor.eval(self.sess) + 1, self.args.num_epochs + 1, 1):
            print("Epoch-{} | num_iterations: {} | train_data_size: {} | batch_size: {}".format(
                cur_epoch, num_iterations, self.args.train_data_size, self.args.batch_size))

            # Initialize classification accuracy and loss lists
            loss_list = []
            acc_list = []

            for _ in tqdm(range(num_iterations), total=num_iterations, desc="Epoch-" + str(cur_epoch) + "-"):
                # Get the current iteration for summarizing it
                cur_step = self.model.global_step_tensor.eval(self.sess)

                # Pull one batch of NumPy arrays out of the input queue
                image_train, label_train = self.sess.run([self.train_images, self.train_labels])
                # Feed the evaluated arrays (not the tensors) to the network
                feed_dict = {self.model.X: image_train,
                             self.model.y: label_train,
                             self.model.is_training: True}
                # Run the feed_forward
                _, loss, acc = self.sess.run(
                    [self.model.train_op, self.model.loss, self.model.accuracy],
                    feed_dict=feed_dict)
                # Append loss and accuracy
                loss_list.append(loss)
                acc_list.append(acc)

                # Update the Global step
                self.model.global_step_assign_op.eval(session=self.sess,
                                                      feed_dict={self.model.global_step_input: cur_step + 1})

            avg_loss = np.mean(loss_list)
            avg_acc = np.mean(acc_list)

            # Update the Current Epoch tensor
            self.model.global_epoch_assign_op.eval(session=self.sess,
                                                   feed_dict={self.model.global_epoch_input: cur_epoch + 1})

            # Print in console
            print("Epoch-" + str(cur_epoch) + " | " + "loss: " + str(avg_loss) + " -" + " acc: " + str(avg_acc)[:7])

            # Save the current checkpoint
            if cur_epoch % self.args.save_model_every == 0 and cur_epoch != 0:
                self.save_model()

            # Test the model on validation or test data
            if cur_epoch % self.args.test_every == 0:
                self.test('val')

        # Stop the input threads once training is finished
        self.coord.request_stop()
        self.coord.join(self.threads)

    def test(self, test_type='val'):
        # Only the validation TFRecords are wired up; no variables are
        # re-initialized here, so the trained weights are kept.
        num_iterations = self.args.test_data_size // self.args.batch_size

        # Initialize classification accuracy and loss lists
        loss_list = []
        acc_list = []

        for _ in tqdm(range(num_iterations), total=num_iterations, desc='Testing'):
            image_val, label_val = self.sess.run([self.val_images, self.val_labels])
            # Feed the evaluated arrays to the network
            feed_dict = {self.model.X: image_val,
                         self.model.y: label_val,
                         self.model.is_training: False}
            # Run the feed_forward
            loss, acc = self.sess.run(
                [self.model.loss, self.model.accuracy],
                feed_dict=feed_dict)

            # Append loss and accuracy
            loss_list.append(loss)
            acc_list.append(acc)

        avg_loss = np.mean(loss_list)
        avg_acc = np.mean(acc_list)
        print('Test results | test_loss: ' + str(avg_loss) + ' - test_acc: ' + str(avg_acc)[:7])

pretrained model

Hi @MG2033,

Did you finish training on ImageNet?
Could you share your pretrained model?

Thanks!

How's the ImageNet training

I'm trying to train this model on ImageNet, but after 60k iterations the loss converges slowly and stays at approximately 2.5.

Did you get a lower loss value, or did you see similar behavior?
I just want to check that I am training this model correctly.
Thanks.
