
On Embeddings for Numerical Features in Tabular Deep Learning

This is the official implementation of the paper "On Embeddings for Numerical Features in Tabular Deep Learning" (arXiv).

Check out other projects on tabular Deep Learning: link.

Feel free to report issues and post questions/feedback/ideas.



The main results

See bin/results.ipynb.

Set up the environment

Software

Preliminaries:

  • You may need to change the CUDA-related commands and settings below depending on your setup
  • Make sure that /usr/local/cuda-11.1/bin is always in your PATH environment variable
  • Install conda
export PROJECT_DIR=<ABSOLUTE path to the repository root>
# example: export PROJECT_DIR=/home/myusername/repositories/num-embeddings
git clone https://github.com/Yura52/tabular-dl-num-embeddings $PROJECT_DIR
cd $PROJECT_DIR

conda create -n num-embeddings python=3.9.7
conda activate num-embeddings

pip install torch==1.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt

# if the following commands do not succeed, update conda
conda env config vars set PYTHONPATH=${PYTHONPATH}:${PROJECT_DIR}
conda env config vars set PROJECT_DIR=${PROJECT_DIR}
# the following command appends ":/usr/local/cuda-11.1/lib64" to LD_LIBRARY_PATH;
# if your LD_LIBRARY_PATH already contains a path to some other CUDA, then the content
# after "=" should be "<your LD_LIBRARY_PATH without your cuda path>:/usr/local/cuda-11.1/lib64"
conda env config vars set LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/cuda-11.1/lib64
conda env config vars set CUDA_HOME=/usr/local/cuda-11.1
conda env config vars set CUDA_ROOT=/usr/local/cuda-11.1

# (optional) get a shortcut for toggling the dark mode with cmd+y
conda install nodejs
jupyter labextension install jupyterlab-theme-toggle

conda deactivate
conda activate num-embeddings

Data

LICENSE: by downloading our dataset you accept the licenses of all its components. We do not impose any new restrictions in addition to those licenses. You can find the list of sources in the paper.

cd $PROJECT_DIR
wget "https://www.dropbox.com/s/r0ef3ij3wl049gl/data.tar?dl=1" -O num_embeddings_data.tar
tar -xvf num_embeddings_data.tar

How to reproduce results

The code below reproduces the results for MLP on the California Housing dataset. The pipeline is exactly the same for other algorithms and datasets.

# You must explicitly set CUDA_VISIBLE_DEVICES if you want to use GPU
export CUDA_VISIBLE_DEVICES="0"

# Create a copy of the 'official' config
cp exp/mlp/california/0_tuning.toml exp/mlp/california/1_tuning.toml

# Run tuning (on GPU, it takes ~30-60min)
python bin/tune.py exp/mlp/california/1_tuning.toml

# Evaluate single models with 15 different random seeds
python bin/evaluate.py exp/mlp/california/1_tuning 15

# Evaluate ensembles (by default, three ensembles of size five each)
python bin/ensemble.py exp/mlp/california/1_evaluation

# Then use bin/results.ipynb to view the obtained results
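The evaluation step above produces one metric value per random seed, and bin/results.ipynb summarizes them; such summaries are typically a mean and standard deviation over seeds. A minimal self-contained sketch of that aggregation (the RMSE values below are made up purely for illustration, not taken from the paper):

```python
import statistics

# hypothetical per-seed test RMSE values (15 seeds), for illustration only
rmse = [0.499, 0.502, 0.497, 0.505, 0.500, 0.498, 0.503, 0.501,
        0.496, 0.504, 0.500, 0.499, 0.502, 0.498, 0.501]

mean = statistics.mean(rmse)   # average over seeds
std = statistics.stdev(rmse)   # sample standard deviation over seeds

print(f"RMSE: {mean:.3f} +/- {std:.3f}")
```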

Understanding the repository

Code overview

The code is organized as follows:

  • bin
    • train4.py for neural networks (it implements all the embeddings and backbones from the paper)
    • xgboost_.py for XGBoost
    • catboost_.py for CatBoost
    • tune.py for tuning
    • evaluate.py for evaluation
    • ensemble.py for ensembling
    • results.ipynb for summarizing results
    • datasets.py was used to build the dataset splits
    • synthetic.py for generating the synthetic GBDT-friendly datasets
    • train1_synthetic.py for the experiments with synthetic data
  • lib contains common tools used by programs in bin
  • exp contains experiment configs and results (metrics, tuned configurations, etc.). The names of the nested folders follow the names from the paper (example: exp/mlp-plr corresponds to the MLP-PLR model from the paper).

Technical notes

  • You must explicitly set CUDA_VISIBLE_DEVICES when running scripts
  • For saving and loading configs, use lib.dump_config and lib.load_config instead of bare TOML libraries

Running scripts

The common pattern for running scripts is:

python bin/my_script.py a/b/c.toml

where a/b/c.toml is the input configuration file (config). The output will be located at a/b/c. The config structure usually follows the Config class from bin/my_script.py.
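For illustration, a config is a plain TOML file; a hypothetical fragment (the field names below are invented, not copied from an actual config in exp/ — the authoritative structure is the Config class of the script you run) might look like:

```toml
# hypothetical config fragment, for illustration only;
# real configs live under exp/ and follow the script's Config class
seed = 0

[data]
path = 'data/california'

[model]
d_layers = [256, 256]
dropout = 0.1
```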

There are also scripts that take command line arguments instead of configs (e.g. bin/{evaluate.py,ensemble.py}).

train0.py vs train1.py vs train3.py vs train4.py

All four scripts are needed to reproduce the results, but only train4.py is needed for future work, because:

  • bin/train1.py implements a superset of features from bin/train0.py
  • bin/train3.py implements a superset of features from bin/train1.py
  • bin/train4.py implements a superset of features from bin/train3.py

To see which of the four scripts was used to run a given experiment, check the "program" field of the corresponding tuning config. For example, here is the tuning config for MLP on the California Housing dataset: exp/mlp/california/0_tuning.toml. The config indicates that bin/train0.py was used, which means that the configs in exp/mlp/california/0_evaluation are compatible specifically with bin/train0.py. To verify that, you can copy one of them to a separate location and pass it to bin/train0.py:
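For example, the relevant line of such a tuning config looks roughly like this (all surrounding fields are omitted here):

```toml
# excerpt of a tuning config; only the "program" field is shown.
# It determines which training script the resulting configs target.
program = 'bin/train0.py'
```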

mkdir exp/tmp
cp exp/mlp/california/0_evaluation/0.toml exp/tmp/0.toml
python bin/train0.py exp/tmp/0.toml
ls exp/tmp/0

How to cite

@article{gorishniy2022embeddings,
    title={On Embeddings for Numerical Features in Tabular Deep Learning},
    author={Yury Gorishniy and Ivan Rubachev and Artem Babenko},
    journal={arXiv},
    volume={2203.05556},
    year={2022},
}

