Markov-Lipschitz Deep Learning (MLDL)


The general architecture of an MLDL neural network.

Comparison of the training processes of three autoencoders (ML-AE, TopoAE, vanilla AE) on the Spheres dataset.


This is a PyTorch implementation of the MLDL paper:

@article{Li-MarLip-2020,
  title={Markov-Lipschitz Deep Learning},
  author={Stan Z Li and Zelin Zang and Lirong Wu},
  journal={arXiv preprint arXiv:2006.08256},
  year={2020}
}

The main features of MLDL for manifold learning and generation in comparison to other popular methods are summarized below:

|                                                | MLDL (ours) | AE/TopoAE | MLLE | ISOMAP | t-SNE |
|------------------------------------------------|-------------|-----------|------|--------|-------|
| Manifold learning without a decoder            | Yes         | No        | Yes  | Yes    | Yes   |
| Learned NLDR model applicable to test data     | Yes         | Yes       | No   | No     | No    |
| Able to generate data of the learned manifold  | Yes         | No        | No   | No     | No    |
| Compatible with other DL frameworks            | Yes         | No        | No   | No     | No    |
| Scalable to large datasets                     | Yes         | Yes       | No   | No     | No    |

The code includes the following modules:

  • Datasets (Swiss Roll, S-Curve, MNIST, Spheres) -- see the data-generation sketch after this list
  • Training for ML-Enc and ML-AE (ML-Enc + ML-Dec)
  • Test for manifold learning (ML-Enc)
  • Test for manifold generation (ML-Dec)
  • Visualization
  • Evaluation metrics
  • The compared methods include: AutoEncoder (AE), Topological AutoEncoder (TopoAE), Modified Locally Linear Embedding (MLLE), ISOMAP, t-SNE. (Note: we modified the original TopoAE source code so that it can run on the Swiss Roll dataset, adding a Swiss Roll data generation function and adjusting the network structure for a fair comparison.)
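
For orientation, the toy manifolds can be generated directly with scikit-learn. The sketch below is illustrative only; the function name, sample counts, and noise level are assumptions and may differ from what LoadData() in dataset.py actually does.

```python
# Illustrative sketch of generating the toy manifolds with scikit-learn.
# The function name, sample counts, and noise level are placeholders and
# may differ from the actual LoadData() implementation in dataset.py.
import numpy as np
from sklearn.datasets import make_s_curve, make_swiss_roll


def load_toy_manifold(name="SwissRoll", n_samples=800, seed=0):
    if name == "SwissRoll":
        data, color = make_swiss_roll(n_samples=n_samples, noise=0.0, random_state=seed)
    elif name == "Scurve":
        data, color = make_s_curve(n_samples=n_samples, noise=0.0, random_state=seed)
    else:
        raise ValueError("Unknown toy dataset: {}".format(name))
    return data.astype(np.float32), color  # (n_samples, 3) points plus a 1-D color code


points, color = load_toy_manifold("SwissRoll", n_samples=800)
```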

Requirements

  • pytorch == 1.3.1
  • scipy == 1.4.1
  • numpy == 1.18.5
  • scikit-learn == 0.21.3
  • csv == 1.0
  • matplotlib == 3.1.1
  • imageio == 2.6.0

Description

  • main.py
    • SetParam() -- Parameters for training
    • Train() -- Train a new model (encoder and/or decoder)
    • Train_MultiRun() -- Run the training multiple times, each with a different seed
    • Generation() -- Test generation of new data on the learned manifold
    • Generalization() -- Test dimension reduction on unseen data from the learned manifold
    • InlinePlot() -- Plot intermediate results inline during training
  • dataset.py
    • LoadData() -- Load data of selected dataset
  • loss.py
    • MLDL_Loss() -- Calculate six losses: ℒEnc, ℒDec, ℒAE, ℒlis, ℒpush, ℒang (a simplified sketch follows this list)
  • model.py
    • Encoder() -- For latent feature extraction
    • Decoder() -- For generating new data on the learned manifold
  • eval.py -- Calculate performance metrics from the results, each averaged over 10 seeds
  • utils.py
    • GIFPloter() -- Auxiliary tool for online plot
    • CompPerformMetrics() -- Auxiliary tool for evaluating metrics
    • Sampling() -- Sampling in the latent space for generating new data on the learned manifold
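
Roughly speaking, ℒlis encourages pairwise distances between neighboring points to be preserved across layers (local isometry), while ℒpush keeps non-neighbors from collapsing together in the latent space. The PyTorch sketch below only illustrates this idea; it is not the exact formulation in loss.py, and the neighborhood threshold and margin are made-up hyper-parameters.

```python
# Simplified illustration of distance-preservation ("lis") and push-away
# losses. Not the exact formulation used in loss.py; neighbor_eps and
# push_margin are placeholder hyper-parameters.
import torch


def lipschitz_style_losses(x, z, neighbor_eps=1.0, push_margin=5.0):
    """x: input batch (N, d_in); z: latent embedding (N, d_latent)."""
    d_x = torch.cdist(x, x, p=2)              # pairwise distances in the input space
    d_z = torch.cdist(z, z, p=2)              # pairwise distances in the latent space
    neighbors = (d_x < neighbor_eps).float()  # local neighborhood mask
    non_neighbors = 1.0 - neighbors

    # Local isometry: preserve distances of neighboring pairs across layers.
    loss_lis = (neighbors * (d_x - d_z) ** 2).sum() / neighbors.sum().clamp(min=1.0)

    # Push-away: keep non-neighbors at least push_margin apart in the latent space.
    loss_push = (non_neighbors * torch.relu(push_margin - d_z)).sum() / non_neighbors.sum().clamp(min=1.0)
    return loss_lis, loss_push
```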

Running the code

  1. Clone this repository
git clone https://github.com/westlake-cairi/Markov-Lipschitz-Deep-Learning
  2. Install the required dependency packages
  3. To get the results for 10 seeds, run
python main.py -MultiRun
  4. To get the metrics for ML-Enc and ML-AE
python eval.py -M ML-Enc
python eval.py -M ML-AE

The evaluation metrics are available in ./pic/PerformMetrics.csv

  5. To choose a dataset among SwissRoll, Scurve, MNIST, Spheres5500, and Spheres10000, and one of the two modes (ML-Enc and ML-AE)
python main.py -D "dataset name" -M "mode"
  6. To test the generalization to unseen data
python main.py -M Test

The results are available in ./pic/file_name/Test.png

  7. To test the manifold generation
python main.py -M Generation

The results are available in ./pic/file_name/Generation.png
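
Conceptually, generation draws new samples in the latent space (for instance, inside the bounding box of the training embeddings) and maps them back through the trained ML-Dec. The sketch below only illustrates this idea; the function and its arguments are hypothetical and do not reflect the actual Sampling() routine in utils.py.

```python
# Hypothetical illustration of manifold generation: sample latent codes
# uniformly inside the bounding box of the training embeddings, then
# decode them with a trained ML-Dec. Not the repository's Sampling() API.
import torch


def generate_from_latent(decoder, train_embeddings, n_samples=500):
    low = train_embeddings.min(0)[0]       # per-dimension minimum of the embeddings
    high = train_embeddings.max(0)[0]      # per-dimension maximum of the embeddings
    z = low + (high - low) * torch.rand(n_samples, train_embeddings.shape[1])
    with torch.no_grad():
        generated = decoder(z)             # map latent samples back to the data manifold
    return generated
```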

Results

1. ML-Enc: Dimension reduction results -- embeddings in latent spaces

  • Swiss Roll and S-Curve

    A symbol √ or X marks a success or failure in unfolding the manifold. The methods in the upper row, including the ML-Enc, succeed, and by calculation the ML-Enc best maintains the true aspect ratio.

  • MNIST (10 digits)

2. ML-Enc: Performance metrics for dimension reduction on Swiss Roll (800 points) data

This table shows that the ML-Enc outperforms all the other methods on every evaluation metric, with particularly large margins in the isometry-related (LGD, RRE, Cont, and Trust) and Lipschitz-related (K-Min and K-Max) metrics.

|        | #Succ | L-KL   | RRE      | Trust  | Cont   | LGD     | K-Min | K-Max   | MPE    |
|--------|-------|--------|----------|--------|--------|---------|-------|---------|--------|
| ML-Enc | 10    | 0.0184 | 0.000414 | 0.9999 | 0.9985 | 0.00385 | 1.00  | 2.14    | 0.0262 |
| TopoAE | 0     | 0.0349 | 0.022174 | 0.9661 | 0.9884 | 0.13294 | 1.27  | 189.95  | 0.1307 |
| t-SNE  | 0     | 0.0450 | 0.006108 | 0.9987 | 0.9843 | 3.40665 | 11.1  | 1097.62 | 0.1071 |
| MLLE   | 6     | 0.1251 | 0.030702 | 0.9455 | 0.9844 | 0.04534 | 7.37  | 238.74  | 0.1709 |
| HLLE   | 6     | 0.1297 | 0.034619 | 0.9388 | 0.9859 | 0.04542 | 7.44  | 218.38  | 0.0978 |
| LTSA   | 6     | 0.1296 | 0.034933 | 0.9385 | 0.9859 | 0.04542 | 7.44  | 215.93  | 0.0964 |
| ISOMAP | 6     | 0.0234 | 0.009650 | 0.9827 | 0.9950 | 0.02376 | 1.11  | 34.35   | 0.0429 |
| LLE    | 0     | 0.1775 | 0.014249 | 0.9753 | 0.9895 | 0.04671 | 6.17  | 451.58  | 0.1400 |
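
As a point of reference, two of the metrics above (Trust and Cont) can be approximated directly with scikit-learn. This is a hedged sketch: the n_neighbors value is arbitrary, and continuity is computed here by swapping the roles of the two spaces, which may differ in detail from the implementation in eval.py.

```python
# Approximate Trust and Cont with scikit-learn's trustworthiness().
# n_neighbors is illustrative and may not match eval.py.
from sklearn.manifold import trustworthiness


def trust_and_cont(X, Z, n_neighbors=10):
    """X: original data (N, d_in); Z: latent embedding (N, d_latent)."""
    trust = trustworthiness(X, Z, n_neighbors=n_neighbors)
    # Continuity is commonly computed as trustworthiness with the roles of
    # the original and embedded spaces swapped.
    cont = trustworthiness(Z, X, n_neighbors=n_neighbors)
    return trust, cont
```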

3. ML-Enc: Ability to generalize on unseen data of the learned manifold

The learned ML-Enc network can unfold unseen data of the learned manifold, demonstrated using the Swiss-roll with a hole, whereas the compared methods cannot.

4. ML-AE: For dimension reduction and manifold data generation

In the learning phase, the ML-AE takes (a) the training data as input, outputs (b) the embedding in the learned latent space, and then reconstructs (c) the data back in the input space. In the generation phase, the ML-Dec takes (d) random samples in the latent space as input and maps them onto the manifold (e).
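
The data flow can be pictured as a plain encoder/decoder pair; the minimal PyTorch sketch below mirrors that flow only, with layer sizes chosen arbitrarily rather than taken from model.py.

```python
# Minimal encoder/decoder pair mirroring the ML-AE data flow:
# input -> latent embedding -> reconstruction. Layer sizes are
# illustrative and do not match model.py.
import torch.nn as nn


class ToyMLAE(nn.Module):
    def __init__(self, d_in=3, d_latent=2, d_hidden=100):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(d_in, d_hidden), nn.LeakyReLU(),
            nn.Linear(d_hidden, d_latent))
        self.decoder = nn.Sequential(
            nn.Linear(d_latent, d_hidden), nn.LeakyReLU(),
            nn.Linear(d_hidden, d_in))

    def forward(self, x):
        z = self.encoder(x)        # (b) embedding in the latent space
        x_rec = self.decoder(z)    # (c) reconstruction back to the input space
        return z, x_rec
```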

5. ML-AE: Evolution of the training process

The ML-AE training gradually unfolds the manifold from the input layer to the latent layer and reconstructs the latent embedding back to data in the input space.

Feedback

If you have any issues with the implementation, please feel free to contact us by email:

