
Introduction

This repository contains the implementation of the paper "EC-GAN: Low-Sample Classification using Semi-Supervised Algorithms and GANs" by Ayaan Haque (Saratoga High School), published at AAAI 2021.

Our proposed model combines a Generative Adversarial Network with a classifier, leveraging artificial GAN generations to increase the size of restricted, fully-supervised datasets in a semi-supervised manner.

Watch Oral Presentation

Abstract

Semi-supervised learning has been gaining attention as it allows for performing image analysis tasks such as classification with limited labeled data. Some popular algorithms using Generative Adversarial Networks (GANs) for semi-supervised classification share a single architecture for classification and discrimination. However, this may require a model to converge to a separate data distribution for each task, which may reduce overall performance. While progress in semi-supervised learning has been made, less addressed are small-scale, fully-supervised tasks where even unlabeled data is unavailable and unattainable. We therefore propose a novel GAN model, External Classifier GAN (EC-GAN), that utilizes GANs and semi-supervised algorithms to improve classification in fully-supervised regimes. Our method leverages a GAN to generate artificial data used to supplement supervised classification. More specifically, we attach an external classifier, hence the name EC-GAN, to the GAN's generator, as opposed to sharing an architecture with the discriminator. Our experiments demonstrate that EC-GAN's performance is comparable to the shared architecture method, far superior to the standard data augmentation and regularization-based approach, and effective on a small, realistic dataset.

Model

Figure

We propose an algorithm to improve classification utilizing GAN-generated images in restricted, fully-supervised regimes. Our approach consists of three separate models: a generator, a discriminator, and a classifier. At every training iteration, the generator is given random vectors and generates corresponding images. The discriminator is then updated to better distinguish between real and generated samples.
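The adversarial update described above can be sketched as follows. This is a minimal illustration with toy fully-connected networks; the notebook uses larger DCGAN-style architectures, and the layer sizes and learning rates here are assumptions for the sketch.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the notebook's DCGAN generator/discriminator (illustrative only).
latent_dim = 64

generator = nn.Sequential(
    nn.Linear(latent_dim, 128), nn.ReLU(),
    nn.Linear(128, 28 * 28), nn.Tanh(),
)
discriminator = nn.Sequential(
    nn.Linear(28 * 28, 128), nn.LeakyReLU(0.2),
    nn.Linear(128, 1),
)

bce = nn.BCEWithLogitsLoss()
opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

def gan_step(real_images):
    """One adversarial iteration: update the discriminator, then the generator."""
    batch = real_images.size(0)
    noise = torch.randn(batch, latent_dim)
    fake_images = generator(noise)

    # Discriminator: push real samples toward 1, generated samples toward 0.
    d_loss = bce(discriminator(real_images), torch.ones(batch, 1)) \
           + bce(discriminator(fake_images.detach()), torch.zeros(batch, 1))
    opt_d.zero_grad(); d_loss.backward(); opt_d.step()

    # Generator: try to make the discriminator predict "real" on fakes.
    g_loss = bce(discriminator(fake_images), torch.ones(batch, 1))
    opt_g.zero_grad(); g_loss.backward(); opt_g.step()
    return d_loss.item(), g_loss.item()
```

The generated `fake_images` from this step are what the external classifier consumes, as described next.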

Simultaneously, a classifier is trained in standard fashion on available real data and their respective labels (note these datasets are fully labeled). We then use generated images as inputs for supplementing classification during training. This is the semi-supervised portion of our algorithm, as the generated images do not have associated labels. To create labels, we use a pseudo-labeling scheme which assumes a label based on the most likely class according to the current state of the classifier. The generated images and labels are only retained if the model predicts the class of the sample with high confidence, or a probability above a certain threshold. This loss is multiplied by a hyperparameter, which controls the relative importance of generated data compared to true samples.

Note that this classifier is its own network, as opposed to a shared architecture with the discriminator. This is a key contribution of our paper, as most GAN-based classification methods employ a shared discriminator-classifier architecture. We aim to empirically show that an external classifier performs better than a shared architecture.

Results

A brief summary of the results is shown below. EC-GAN is compared to the shared-architecture method on SVHN at different dataset sizes. In each cell, the left value is the accuracy of a standard classifier (same architecture as its GAN counterpart), followed by the accuracy of the GAN classification algorithm.

Figure

Code

The code is written in Python using the PyTorch framework. Training requires a GPU. We provide a Jupyter Notebook, which can be run in Google Colab, containing a usable version of the algorithm. Open EC-GAN.ipynb and run it from top to bottom; the notebook includes annotations to follow along.

Citation

If you find this repo or the paper useful, please cite:

Ayaan Haque, "EC-GAN: Low-Sample Classification using Semi-Supervised Algorithms and GANs," 2021.

@article{Haque_2021, 
      title={EC-GAN: Low-Sample Classification using Semi-Supervised Algorithms and GANs}, 
      volume={35}, 
      url={https://ojs.aaai.org/index.php/AAAI/article/view/17895},
      number={18}, 
      journal={Proceedings of the AAAI Conference on Artificial Intelligence}, 
      author={Haque, Ayaan}, 
      year={2021}, 
      month={May}, 
      pages={15797-15798} 
}


Issues

Some questions about the quality of the images generated by the generator

After 100 epochs, the quality of the fake images is still far inferior to the real ones, which I suspect is why my experiment could not reach the performance reported in the paper (86% on SVHN at 10% dataset size rather than 91.15%). Do you have any tips for training the generator? Thanks for the answer!

Reproducing the paper's results

I have been trying to reproduce the results of the paper with 10% of the SVHN dataset. Unfortunately, the EC-GAN from the notebook does not reach the 91.15% mentioned in the paper; it is even worse than the standard classifier. I might be doing something wrong, so I am curious to find out what.

EC-GAN
Accuracy of the network on the 26032 test images: 85%
Train Accuracy: 98.4375

Standard classifier
Accuracy of the network on the 26032 test images: 88%
Train Accuracy: 98.4375

Regarding the notebook/paper, I also have the following questions:

  1. Which confidence threshold was used in the paper?
  2. Why was the adversarial weight set to 0.1? Was this value optimal?
  3. Why did you reload the dataset in the notebook twice? Since no seed is used, this means the input data of the two networks is different.
  4. What is the reason for different learning-rate settings between the GAN classifier and the standard classifier?

A question about the pseudo-labeling

Hi, thanks for your work.
I notice that to label the generated images, you employ the classifier combined with a predefined threshold to filter out low-probability generations. Have you considered using a cGAN, which may be a better choice for generating images conditioned on a specific category? Thank you.
