spectral-repr-cnns's Introduction

Final Project for ECBM 4040 - Columbia University

Aarshay Jain, Jared Samet, Alex Wainger {aj2713, jss2272, atw2131}@columbia.edu

Spectral Representations for Convolutional Neural Networks

This repo is an implementation of Rippel, Snoek, and Adams 2015 (https://arxiv.org/pdf/1506.03767.pdf). The project contains Python modules and notebooks that implement the paper's three proposals (spectral pooling, frequency dropout, and spectral parameterization of convolutional filters) and replicate its key findings and experiments.

Final Report

A copy of our Final Report is included in this repo.

Requirements

The project was developed using TensorFlow 1.3.0 and NumPy 1.13. Some notebooks also require the Pillow 4.3 library (sudo pip3 install Pillow).

Since the code uses the NCHW format to perform convolutions, it will only run on a GPU-enabled machine.
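NCHW stores a minibatch channels-first as (batch, channels, height, width), while NHWC, the default `data_format` of TensorFlow's `tf.nn.conv2d`, puts channels last. As far as we know, TensorFlow 1.x only ships NCHW convolution kernels for the GPU, which is where the GPU requirement comes from. Converting between the two layouts is just an axis permutation; the numpy helpers below (`nhwc_to_nchw`, `nchw_to_nhwc`) are hypothetical illustrations, not functions from this repo.

```python
import numpy as np


def nhwc_to_nchw(batch):
    """Reorder (batch, height, width, channels) to (batch, channels, height, width)."""
    return np.transpose(batch, (0, 3, 1, 2))


def nchw_to_nhwc(batch):
    """Inverse of nhwc_to_nchw."""
    return np.transpose(batch, (0, 2, 3, 1))


# A CIFAR-sized minibatch: 16 RGB images of 32 x 32 pixels, channels last.
images_nhwc = np.zeros((16, 32, 32, 3), dtype=np.float32)
images_nchw = nhwc_to_nchw(images_nhwc)
print(images_nchw.shape)  # (16, 3, 32, 32)
```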

Unlike CIFAR-10, the CIFAR-100 dataset is not split into batches, so loading it requires a machine with at least 32 GB of RAM.

Running saved models

Two of the notebooks load saved models containing the weights that produced our best accuracies. Before running them, unzip src/best_model_10.tar.gz and src/best_model_100.tar.gz.

Code Organization

All code is located in the src folder. Within that folder, Python functions and classes that are shared between multiple notebooks are all located in the modules folder.

Notebooks

approximation-loss.ipynb - This notebook demonstrates spectral pooling and frequency dropout in action on a minibatch. It also replicates the results for the approximation loss from the original paper.

cnn_spectral_parameterization.ipynb - This notebook replicates the comparison of convergence time (measured in epochs) between traditionally and spectrally parameterized CNNs.

figure2.ipynb - This notebook uses the spectral_pool function to downsample an image.

hyperparameter-search.ipynb - This notebook performs hyperparameter search on the CIFAR-10 dataset to identify the best hyperparameters.

full-training-10.ipynb - This notebook uses the best identified hyperparameters to train the network on the entire CIFAR-10 dataset and compute the test accuracy. It also shows the improved results we got from manual tuning of the hyperparameters. Before running this notebook, unzip src/best_model_10.tar.gz.

full-training-100.ipynb - This notebook uses the best identified hyperparameters to train the network on the entire CIFAR-100 dataset and compute the test accuracy. It also shows the improved results we got from manual tuning of the hyperparameters. Before running this notebook, unzip src/best_model_100.tar.gz.

Modules

cnn_with_spectral_parameterization.py - This class builds and trains the generic and deep CNN architectures as described in section 5.2 of the paper with and without spectral parameterization of the filter weights. It was adapted from the homework assignment for this class on CNNs.
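In a spectrally parameterized layer, the trainable parameters are each filter's frequency-domain coefficients, and the spatial filter used by the convolution is recovered with an inverse DFT. The numpy sketch below is not taken from cnn_with_spectral_parameterization.py: storing only the one-sided half-spectrum and mapping it back with `np.fft.irfft2` is our assumption, chosen because it always yields real-valued filters. `spectral_to_spatial` is a hypothetical name.

```python
import numpy as np


def spectral_to_spatial(spectral_weights, kernel_size):
    """Turn half-spectrum filter parameters into real spatial filters.

    spectral_weights has shape (..., kernel_size, kernel_size // 2 + 1), the
    layout produced by np.fft.rfft2. irfft2 treats its input as one half of a
    conjugate-symmetric spectrum, so its output is always real-valued.
    """
    return np.fft.irfft2(spectral_weights, s=(kernel_size, kernel_size), axes=(-2, -1))


rng = np.random.default_rng(0)
k = 3
# Hypothetical filter bank: 8 output channels, 4 input channels, 3 x 3 kernels.
spatial = rng.standard_normal((8, 4, k, k))
spectral = np.fft.rfft2(spatial, axes=(-2, -1))  # complex, shape (8, 4, 3, 2)
recovered = spectral_to_spatial(spectral, k)
print(np.allclose(recovered, spatial))  # True
```

During training, the complex `spectral` array would be the variable that gradient descent updates, and the spatial filters would be recomputed from it on every forward pass, so gradients flow through the inverse FFT.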

cnn_with_spectral_pooling.py - This class builds the spectral pooling CNN architectures as described in section 5.1 of the paper. It was adapted from the homework assignment for this class on CNNs.

create_images.py - These functions allowed us to experiment with and understand the behavior of the Fourier transform as applied to sample images.

frequency_dropout.py - These functions implement the frequency dropout operation by building a dropout mask from an input tensor that specifies the truncation frequency.
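A minimal numpy sketch of the idea: a mask zeroes every frequency above a truncation frequency, and that frequency is drawn at random. The square (per-axis) cutoff and the uniform draw between `min_cutoff` and the Nyquist frequency are assumptions for illustration; `frequency_mask` and `frequency_dropout` are hypothetical names, and the sketch assumes square NCHW images.

```python
import numpy as np


def frequency_mask(size, cutoff):
    """Boolean mask over an unshifted size x size FFT grid.

    Keeps the frequencies whose absolute index is <= cutoff on both axes and
    drops everything above the truncation frequency.
    """
    k = np.arange(size)
    abs_freq = np.minimum(k, size - k)  # |frequency| of each FFT bin: 0, 1, ..., 1
    return (abs_freq[:, None] <= cutoff) & (abs_freq[None, :] <= cutoff)


def frequency_dropout(images, min_cutoff, rng):
    """Apply one randomly drawn truncation frequency to a batch of NCHW images."""
    size = images.shape[-1]
    cutoff = rng.integers(min_cutoff, size // 2 + 1)  # uniform in [min_cutoff, size // 2]
    spectrum = np.fft.fft2(images) * frequency_mask(size, cutoff)
    # The mask is symmetric under k -> size - k, so the spectrum stays
    # conjugate-symmetric and the inverse FFT is real up to rounding error.
    return np.fft.ifft2(spectrum).real, cutoff


rng = np.random.default_rng(0)
batch = rng.standard_normal((2, 3, 8, 8))
dropped, cutoff = frequency_dropout(batch, min_cutoff=1, rng=rng)
print(frequency_mask(8, 1).sum())  # 9: frequencies {-1, 0, 1} on each axis
```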

image_generator.py - This class creates a data generator that supports image augmentation. It was adapted from the homework assignment for this class on CNNs.
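The general pattern of a data generator with augmentation can be sketched in a few lines of numpy. This is not image_generator.py: `augmented_batches` is a hypothetical function, and random horizontal flipping is just one example of an augmentation it might apply.

```python
import numpy as np


def augmented_batches(images, labels, batch_size, rng):
    """Yield shuffled NCHW minibatches, flipping about half the images left-right."""
    order = rng.permutation(len(images))
    for start in range(0, len(images), batch_size):
        idx = order[start:start + batch_size]
        batch = images[idx]                    # fancy indexing returns a copy
        flip = rng.random(len(idx)) < 0.5      # choose which images to flip
        batch[flip] = batch[flip][..., ::-1]   # reverse the width axis
        yield batch, labels[idx]


rng = np.random.default_rng(0)
images = rng.standard_normal((10, 3, 32, 32))
labels = np.arange(10)
for x, y in augmented_batches(images, labels, batch_size=4, rng=rng):
    print(x.shape, y)
```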

layers.py - The classes in this file implement the various layers that we use to create the CNN architectures described in the paper. Some layers were adapted from the homework assignment on CNNs. The layers defined in this file are:

  • default_conv_layer: A standard convolutional layer with traditionally-parameterized weights
  • fc_layer: A fully connected dense layer
  • spectral_pool_layer: A layer implementing spectral pooling and frequency dropout
  • spectral_conv_layer: A convolutional layer with spectrally-parameterized weights
  • global_average_layer: A layer implementing global averaging as described in Lin et al.
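In Lin et al.'s global average pooling, each feature map of the last convolutional layer is averaged over its spatial positions to a single number, and with one feature map per class the result can go straight into a softmax. A minimal numpy sketch for NCHW input; `global_average` and `softmax` are hypothetical helpers, and whether the repo's layer feeds a softmax directly is not stated here.

```python
import numpy as np


def global_average(feature_maps):
    """Average every NCHW feature map over height and width, giving shape (N, C)."""
    return feature_maps.mean(axis=(2, 3))


def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# E.g. a final layer with one 8 x 8 feature map per CIFAR-10 class.
features = np.random.default_rng(0).standard_normal((4, 10, 8, 8))
probs = softmax(global_average(features))
print(probs.shape)  # (4, 10)
```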

spectral_pool.py - A function implementing spectral pooling that is shared by multiple source files.
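A minimal numpy sketch of spectral pooling: transform each image, keep a centered block of low frequencies, and transform back. It is not the repo's spectral_pool function; `spectral_pool_sketch` is a hypothetical name, and it only accepts square inputs and odd output sizes, which sidesteps the even-size corner cases discussed in the issues below.

```python
import numpy as np


def spectral_pool_sketch(images, out_size):
    """Downsample NCHW images to out_size x out_size by cropping their spectrum.

    Assumes square inputs and an odd out_size smaller than the input size, so
    the retained frequencies (-half..half on each axis) are conjugate-symmetric
    and the inverse FFT is real up to rounding error.
    """
    h, w = images.shape[-2:]
    if h != w or out_size % 2 == 0 or out_size >= h:
        raise ValueError("sketch supports square inputs and odd out_size < input size")
    # Centre the spectrum so the low frequencies form one contiguous block.
    spectrum = np.fft.fftshift(np.fft.fft2(images), axes=(-2, -1))
    centre, half = h // 2, out_size // 2
    cropped = spectrum[..., centre - half:centre + half + 1,
                       centre - half:centre + half + 1]
    pooled = np.fft.ifft2(np.fft.ifftshift(cropped, axes=(-2, -1)))
    # numpy's ifft2 divides by out_size**2 rather than h * w; rescale so the
    # mean intensity of each image is preserved.
    return pooled.real * (out_size * out_size) / (h * w)


x = np.random.default_rng(0).standard_normal((2, 3, 8, 8))
print(spectral_pool_sketch(x, 5).shape)  # (2, 3, 5, 5)
```

Unlike max pooling, the output size can be any odd value below the input size, not just an integer fraction of it.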

utils.py - Various utility functions. Some were adapted from the homework assignment on CNNs.

spectral-repr-cnns's People

Contributors

alexwainger, oracleofnj, thismlguy

spectral-repr-cnns's Issues

About the ifft

In your code, the inverse transform is computed as im_out = tf.real(tf.ifft2d(im_transformed)).
But because of the truncation ops, I think tf.abs(...) would be more reasonable.
Another question: the truncation also changes the range of the pooled values.
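For context on the real-part versus magnitude question: when the retained spectrum is conjugate-symmetric, the imaginary part of the inverse FFT is only rounding noise, so the real part recovers the signed values, while the magnitude turns negative samples positive. A 1-D numpy sketch (not the repo's TensorFlow code) with a symmetric truncation that keeps frequencies -1, 0 and +1:

```python
import numpy as np

signal = np.array([1.0, -2.0, 3.0, -4.0, 0.5])
spectrum = np.fft.fft(signal)  # conjugate-symmetric because signal is real

# Truncate symmetrically: keep only frequencies 0, +1 and -1.
truncated = np.zeros_like(spectrum)
truncated[[0, 1, -1]] = spectrum[[0, 1, -1]]

recon = np.fft.ifft(truncated)
print(np.abs(recon.imag).max() < 1e-12)  # True: the imaginary part is rounding noise
via_real = recon.real    # keeps the sign of each sample
via_abs = np.abs(recon)  # magnitude only: negative samples come back positive
```

If a truncation breaks the conjugate symmetry, the imaginary part is no longer negligible; for even output sizes this is the situation the corner-case handling in the next issue is concerned with.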

Middle column for obeying the complex symmetries with even size of output

Thank you for sharing this implementation and your results; it is very helpful. I was wondering how you compute the middle values for spectral pooling in the even case. I can see that the spectral representation (after the (r)fft, without shifting the DC component to the center) has the low frequencies in the corners, with frequencies increasing toward the center. Spectral pooling therefore produces a compressed output by slicing values from the corners. For the even case, however, I cannot figure out why you compute the middle values as follows:

middle_left = tf.cast(0.5 ** 0.5, tf.complex64) * (
            images[:, :, n, :n] + images[:, :, -n, :n])

I printed the frequency-domain values after the (r)fft in both TensorFlow and PyTorch but could not find a formula for the middle values, and the one you provided does not seem to hold in my examples. For instance:

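import torch             # PyTorch < 1.8 (torch.rfft was later removed)
import tensorflow as tf  # TensorFlow 1.x (uses tf.Session)
b = tf.constant([[1, 2, 3, 4], [0, -1, -3, 4], [5, -2, 1, 3.9], [0.1, 2, 1, -1]], dtype=tf.float32)
bfft = tf.spectral.rfft2d(b)  # assumed definition: one-sided 2-D FFT, consistent with the 4 x 3 output below
sess = tf.Session()
sess.run(bfft)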
array([[ 20.       +0. j,   4.1      +9.9j,  -3.7999992+0. j],
       [  2.1      +2.1j,   2.       -7.8j,  -6.1      +6.1j],
       [ 15.799999 +0. j,  -0.0999999+5.9j,   7.9999995+0. j],
       [  2.1      -2.1j, -14.       +0. j,  -6.1      -6.1j]],
      dtype=complex64)
# Print the full two-sided output of rfft (onesided=False keeps the redundant half)
bTorch = torch.tensor([[1,2,3, 4], [0, -1,-3,4],[5,-2,1,3.9], [0.1, 2, 1, -1]])
torch.rfft(bTorch, signal_ndim=2, onesided=False)
tensor([[[ 20.0000,   0.0000],
         [  4.1000,   9.9000],
         [ -3.8000,   0.0000],
         [  4.1000,  -9.9000]],

        [[  2.1000,   2.1000],
         [  2.0000,  -7.8000],
         [ -6.1000,   6.1000],
         [-14.0000,  -0.0000]],

        [[ 15.8000,   0.0000],
         [ -0.1000,   5.9000],
         [  8.0000,   0.0000],
         [ -0.1000,  -5.9000]],

        [[  2.1000,  -2.1000],
         [-14.0000,   0.0000],
         [ -6.1000,  -6.1000],
         [  2.0000,   7.8000]]])

c = tf.cast(0.5 ** 0.5, tf.complex64)
sess.run(c)
(0.70710677+0j)
sess.run(c * (bfft[0,1] + bfft[0,1]))
(5.7982755+14.000713j)  # != (-3.7999992+0. j)

Could you please let me know how you found the formula for computing the middle values or where I could find more information about the corner cases?
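For comparison: the quoted expression adds two different rows of the full spectrum, n and -n (restricted to the first n columns), whereas the check at the end of the transcript adds bfft[0, 1] to itself, so the two computations are not equivalent. The numpy sketch below evaluates the quoted expression on the same matrix with n = 1 (an illustrative choice, giving an even output size of 2). Because b is real, spectrum[-n, 0] is the complex conjugate of spectrum[n, 0], so their sum in the first column is real. This sketch does not explain the 1/sqrt(2) factor.

```python
import numpy as np

b = np.array([[1, 2, 3, 4], [0, -1, -3, 4], [5, -2, 1, 3.9], [0.1, 2, 1, -1]])
spectrum = np.fft.fft2(b)  # full two-sided spectrum, DC at [0, 0]

n = 1  # illustrative choice: even output size 2 * n = 2
# Rows n and -n, as in the quoted code; here spectrum[1, 0] is about 2.1+2.1j
# and spectrum[-1, 0] is its conjugate, about 2.1-2.1j.
rows_added = spectrum[n, :n] + spectrum[-n, :n]
middle_left = 0.5 ** 0.5 * rows_added  # real-valued, about 4.2 / sqrt(2)
```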
