
Neural Implicit Dictionary via Mixture-of-Expert Training

License: MIT

The official implementation of the ICML 2022 paper Neural Implicit Dictionary via Mixture-of-Expert Training.

Peihao Wang, Zhiwen Fan, Tianlong Chen, Zhangyang (Atlas) Wang

Introduction

Representing visual signals with coordinate-based, fully-connected deep networks has proven more effective at fitting complex details and solving inverse problems than discrete grid-based representations. However, acquiring such a continuous Implicit Neural Representation (INR) requires tedious per-scene training on a large number of signal measurements, which limits its practicality. In this paper, we present a generic INR framework that achieves both data and training efficiency by learning a Neural Implicit Dictionary (NID) from a data collection and representing an INR as a functional combination of basis functions sampled from the dictionary. Our NID assembles a group of coordinate-based subnetworks that are tuned to span the desired function space. After training, one can instantly and robustly acquire an unseen scene representation by solving for the coding coefficients. To optimize a large group of networks in parallel, we borrow the idea of Mixture-of-Experts (MoE) and train our network with a sparse gating mechanism. Our experiments show that NID can speed up the reconstruction of 2D images or 3D scenes by two orders of magnitude with up to 98% less input data. We further demonstrate applications of NID in image inpainting and occlusion removal, which are considered challenging for vanilla INR.
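The core idea above — an unseen signal expressed as a sparse combination of dictionary basis functions selected by a gating mechanism — can be sketched in a few lines of numpy. Random Fourier features stand in for the paper's coordinate-based subnetworks; all names and shapes here are illustrative only:

```python
import numpy as np

rng = np.random.default_rng(0)

num_experts, top_k = 8, 2
coords = rng.uniform(-1.0, 1.0, size=(64, 2))  # 2D query coordinates in [-1, 1]

# Each "expert" is a tiny coordinate-based basis function; in the paper these
# are subnetworks, here random sinusoids stand in for illustration.
freqs = rng.normal(size=(num_experts, 2))
basis = np.sin(coords @ freqs.T)  # (64, num_experts): each column is one basis

# Per-scene sparse gating: keep only the top-k coding coefficients.
scores = rng.normal(size=num_experts)
keep = np.argsort(scores)[-top_k:]
coeffs = np.zeros(num_experts)
coeffs[keep] = scores[keep]

signal = basis @ coeffs  # reconstructed signal at the query coordinates
```

Fitting a new scene then amounts to solving for `coeffs` while the dictionary stays fixed, which is why acquisition is fast after pre-training.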

Getting Started

Installation

We recommend using conda to set up the running environment. For basic usage, the following dependencies are required:

pytorch
torchvision
cudatoolkit
tensorboard
pytorch-geometric
opencv
imageio
imageio-ffmpeg
configargparse
scipy
matplotlib
tqdm
lpips
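One way to capture these in a single conda `environment.yml` (a sketch only — channel choices and exact package names, e.g. installing pytorch-geometric and lpips via pip, may need adjusting for your CUDA setup):

```yaml
name: nid
channels:
  - pytorch
  - conda-forge
dependencies:
  - pytorch
  - torchvision
  - cudatoolkit
  - tensorboard
  - opencv
  - imageio
  - imageio-ffmpeg
  - configargparse
  - scipy
  - matplotlib
  - tqdm
  - pip
  - pip:
      - torch-geometric
      - lpips
```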

To enable 3D data loading and visualization, please install the following packages:

pytorch3d
open3d
plyfile
trimesh

Data Preparation

To run our code, you need to download the CelebA and ShapeNet datasets. After downloading ShapeNet, please follow here to collate the data.

Usage

Instant Image Regression

First of all, train a dictionary on the training set of one specified dataset:

python train_images.py --config <config_path> --data_dir <path_to_data> --gpuid <gpu_id> --log_dir <log_dir>

where <config_path> specifies the path to the configuration file. We provide an example in configs/celeba_train.txt.
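Configuration files follow the configargparse format, mapping long option names to values one per line. A hypothetical sketch (these paths and values are placeholders, not the contents of the released configs/celeba_train.txt):

```
# placeholder values; see configs/celeba_train.txt for the real settings
data_dir = /path/to/celeba
log_dir = ./logs/celeba_train
gpuid = 0
```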

To finetune coefficients on the evaluation set, add the --finetune flag:

python train_images.py --config <config_path> --gpuid <gpu_id> --log_dir <log_dir> --finetune

We experiment on the CelebA dataset, which can be downloaded from Large-scale CelebFaces Attributes (CelebA) Dataset. A pre-trained checkpoint on CelebA can be downloaded from HuggingFace Hub.

Facial Image Inpainting

To generate corrupted images, use the script:

python scripts/preprocess_face.py --data_dir <path_to_data>  --dataset <data_type>  --out_dir <path_to_save> <perturb_params>

Afterwards, with the pre-trained dictionary, one can recover the corrupted images by finetuning on them with an L1 loss:

python train_images.py --config <config_path> --gpuid <gpu_id> --log_dir <log_dir> --loss_type l1  --finetune
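As a reminder of why the L1 objective helps here: its minimizer ignores gross corruptions that a squared loss would chase. A toy numpy illustration (not taken from the repo):

```python
import numpy as np

# Clean constant signal with 10% of measurements grossly corrupted (occluded).
corrupted = np.full(100, 2.0)
corrupted[:10] = 50.0

# Fitting a constant under L2 gives the mean; under L1 it gives the median.
l2_fit = corrupted.mean()      # dragged far from 2.0 by the outliers
l1_fit = np.median(corrupted)  # stays at the clean value 2.0
```

The same robustness lets the dictionary coefficients lock onto the uncorrupted pixels while ignoring occluded ones.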

Robust PCA on Video

To train a dictionary on video clips, first convert the frames to individual images and save them under a folder <path_to_data>. Then train the dictionary with the modified regularization:

python train_images.py --config <config_path> --data_dir <path_to_data> --gpuid <gpu_id> --log_dir <log_dir> --loss_type l1 --l1_exp 0.5 --loss_l1 0.01 --loss_cv 0.0
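"Robust PCA" here means decomposing the frame matrix into a low-rank background plus a sparse foreground; the repo realizes this through the regularizers above, but the decomposition itself can be sketched with a toy alternating-thresholding scheme in numpy (the thresholds `lam` and `tau` are illustrative, not the repo's hyperparameters):

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic "video" matrix: columns are frames, a rank-1 static background
# plus a sparse moving foreground.
background = rng.normal(size=(40, 1)) @ rng.normal(size=(1, 30))
foreground = np.zeros((40, 30))
foreground.flat[rng.choice(40 * 30, size=30, replace=False)] = 5.0
M = background + foreground

def rpca(M, lam=0.15, tau=1.0, n_iter=200):
    """Alternate singular-value thresholding (low-rank part) with
    entrywise soft thresholding (sparse part)."""
    S = np.zeros_like(M)
    for _ in range(n_iter):
        U, s, Vt = np.linalg.svd(M - S, full_matrices=False)
        L = (U * np.maximum(s - tau, 0.0)) @ Vt            # low-rank update
        R = M - L
        S = np.sign(R) * np.maximum(np.abs(R) - lam, 0.0)  # sparse update
    return L, S

L_hat, S_hat = rpca(M)  # background estimate, foreground estimate
```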

Computed Tomography (CT)

Similar to image regression, one first trains a dictionary on the Shepp-Logan phantom dataset and then fits on new measurements.

python train_ct.py --config <config_path> --gpuid <gpu_id> --log_dir <log_dir> --num_thetas <num_views>

It is recommended to use dense views to pre-train the dictionary. When testing on new CT images, reconstruction can then be done from only a few views.

python train_ct.py --config <config_path> --gpuid <gpu_id> --log_dir <log_dir> --num_thetas <num_views> --finetune
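For intuition, the measurements are projections (a Radon transform) taken at num_thetas angles; a toy numpy/scipy sketch of dense-view versus few-view measurements of a disc phantom (an illustration only, not the repo's data pipeline):

```python
import numpy as np
from scipy import ndimage

# Toy phantom: a centered disc on a 64x64 grid.
n = 64
y, x = np.mgrid[:n, :n]
phantom = ((x - n / 2) ** 2 + (y - n / 2) ** 2 < (n / 4) ** 2).astype(float)

def radon(img, thetas_deg):
    # One projection per angle: rotate the image, then integrate along columns.
    return np.stack([
        ndimage.rotate(img, t, reshape=False, order=1).sum(axis=0)
        for t in thetas_deg
    ])

dense_views = radon(phantom, np.linspace(0.0, 180.0, 90, endpoint=False))  # pre-training
sparse_views = radon(phantom, np.linspace(0.0, 180.0, 8, endpoint=False))  # few-view test
```

Pre-training on dense views gives the dictionary enough coverage of the sinogram space that a new scan can later be solved from the much smaller `sparse_views`-style measurement set.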

Signed Distance Function

To train the SDF model, first process the dataset with convert_mesh_to_sdf.py, which pre-computes the signed distance values for each instance. Then run the following command to train the dictionary (or fit coefficients with the --finetune flag):

python train_sdf.py --config <config_path> --data_dir <path_to_data> --gpuid <gpu_id> --log_dir <log_dir>

We experiment on ShapeNet dataset. To prepare the data, please follow ShapeNet to download the raw data and R2N2 to acquire the data split.
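For reference, the quantity pre-computed per instance is, for each query point, its signed distance to the surface — negative inside the shape, positive outside. A minimal numpy sketch for a sphere (purely illustrative; the repo computes this from meshes):

```python
import numpy as np

rng = np.random.default_rng(0)
pts = rng.uniform(-1.0, 1.0, size=(1000, 3))  # query points in the unit cube

# SDF of a sphere of radius 0.5 at the origin: distance to center minus radius,
# so values are negative inside the surface and positive outside.
sdf = np.linalg.norm(pts, axis=1) - 0.5

inside = sdf < 0  # boolean occupancy derived from the sign
```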

Citation

If you find this work or our code implementation helpful for your own research, please cite our paper.

@inproceedings{wang2022inrdict,
  title={Neural Implicit Dictionary via Mixture-of-Expert Training},
  author={Wang, Peihao and Fan, Zhiwen and Chen, Tianlong and Wang, Zhangyang},
  booktitle={International Conference on Machine Learning},
  year={2022}
}

