TALENT: A Tabular Analytics and Learning Toolbox

[Paper] [Chinese explanation]


🎉 Introduction

Welcome to TALENT, a benchmark with a comprehensive machine learning toolbox designed to enhance model performance on tabular data. TALENT integrates advanced deep learning models, classical algorithms, and efficient hyperparameter tuning, offering robust preprocessing capabilities to optimize learning from tabular datasets. The toolbox is user-friendly and adaptable, catering to both novice and expert data scientists.

TALENT offers the following advantages:

  • Diverse Methods: Includes various classical methods, tree-based methods, and the latest popular deep learning methods.
  • Extensive Dataset Collection: Equipped with 300 datasets, covering a wide range of task types, size distributions, and dataset domains.
  • Customizability: Easily allows the addition of datasets and methods.
  • Versatile Support: Supports diverse normalization, encoding, and metrics.

📚 Citing TALENT

If you use any content of this repo in your work, please cite the following BibTeX entries:

@article{ye2024closerlookdeeplearning,
         title={A Closer Look at Deep Learning on Tabular Data}, 
         author={Han-Jia Ye and Si-Yang Liu and Hao-Run Cai and Qi-Le Zhou and De-Chuan Zhan},
         journal={arXiv preprint arXiv:2407.00956},
         year={2024}
}

@article{liu2024talenttabularanalyticslearning,
         title={TALENT: A Tabular Analytics and Learning Toolbox}, 
         author={Si-Yang Liu and Hao-Run Cai and Qi-Le Zhou and Han-Jia Ye},
         journal={arXiv preprint arXiv:2407.04057},
         year={2024}
}

📰 What's New

  • [2024-07]🌟 Add RealMLP.
  • [2024-07]🌟 Add ProtoGate (ICML 2024).
  • [2024-07]🌟 Add BiSHop (ICML 2024).
  • [2024-06]🌟 Check out our new baseline ModernNCA, inspired by traditional Neighbor Component Analysis, which outperforms both tree-based and other deep tabular models, while also reducing training time and model size!
  • [2024-06]🌟 Check out our benchmark paper about tabular data, which provides comprehensive evaluations of classical and deep tabular methods based on our toolbox in a fair manner!

🌟 Methods

TALENT integrates an extensive array of 20+ deep learning architectures for tabular data, including but not limited to:

  • MLP: A multi-layer neural network, which is implemented according to RTDL.
  • ResNet: A DNN that uses skip connections across many layers, which is implemented according to RTDL.
  • SNN: An MLP-like architecture utilizing the SELU activation, which facilitates the training of deeper neural networks.
  • DANets: A neural network designed to enhance tabular data processing by grouping correlated features and reducing computational complexity.
  • TabCaps: A capsule network that encapsulates all feature values of a record into vectorial features.
  • DCNv2: Consists of an MLP-like module combined with a feature crossing module, which includes both linear layers and multiplications.
  • NODE: A tree-mimic method that generalizes oblivious decision trees, combining gradient-based optimization with hierarchical representation learning.
  • GrowNet: A gradient boosting framework that uses shallow neural networks as weak learners.
  • TabNet: A tree-mimic method using sequential attention for feature selection, offering interpretability and self-supervised learning capabilities.
  • TabR: A deep learning model that integrates a KNN component to enhance tabular data predictions through an efficient attention-like mechanism.
  • ModernNCA: A deep tabular model inspired by traditional Neighbor Component Analysis, which makes predictions based on the relationships with neighbors in a learned embedding space.
  • DNNR: Enhances KNN by using local gradients and Taylor approximations for more accurate and interpretable predictions.
  • AutoInt: A token-based method that uses a multi-head self-attentive neural network to automatically learn high-order feature interactions.
  • Saint: A token-based method that leverages row and column attention mechanisms for tabular data.
  • TabTransformer: A token-based method that enhances tabular data modeling by transforming categorical features into contextual embeddings.
  • FT-Transformer: A token-based method which transforms features to embeddings and applies a series of attention-based transformations to the embeddings.
  • TANGOS: A regularization-based method for tabular data that uses gradient attributions to encourage neuron specialization and orthogonalization.
  • SwitchTab: A self-supervised method tailored for tabular data that improves representation learning through an asymmetric encoder-decoder framework. Following the original paper, our toolkit uses a supervised learning form, optimizing both reconstruction and supervised loss in each epoch.
  • PTaRL: A regularization-based framework that enhances prediction by constructing and projecting into a prototype-based space.
  • TabPFN: A general model that uses a pre-trained deep neural network and can be applied directly to any tabular task.
  • HyperFast: A meta-trained hypernetwork that generates task-specific neural networks for instant classification of tabular data.
  • TabPTM: A general method for tabular data that standardizes heterogeneous datasets using meta-representations, allowing a pre-trained model to generalize to unseen datasets without additional training.
  • BiSHop: An end-to-end framework for deep tabular learning which leverages a sparse Hopfield model with adaptable sparsity, enhanced by column-wise and row-wise modules.
  • ProtoGate: A prototype-based model for feature selection in HDLSS biomedical data that adapts global and local feature selection to enhance prediction accuracy and interpretability, addressing co-adaptation issues through a non-parametric prototype-based mechanism.
  • RealMLP: An improved multilayer perceptron (MLP).

☄️ How to Use TALENT

🕹️ Clone

Clone this GitHub repository:

git clone https://github.com/qile2000/LAMDA-TALENT
cd LAMDA-TALENT/LAMDA-TALENT

🔑 Run experiment

  1. Edit configs/default/[MODEL_NAME].json and configs/opt_space/[MODEL_NAME].json to set the global settings and hyperparameters.

  2. Run:

    python train_model_deep.py --model_type MODEL_NAME

    for deep methods, or:

    python train_model_classical.py --model_type MODEL_NAME

    for classical methods.
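To run several methods in sequence, the two commands above can be wrapped in a small script. The following is a minimal sketch that only builds and launches those command lines; `build_command` and `run_all` are illustrative helpers, not part of TALENT, and the script is assumed to be started from the directory that contains the training scripts.

```python
import subprocess
import sys


def build_command(model_name, deep=True):
    """Return the training command for one method as an argument list."""
    # Script names and the --model_type flag are taken from the steps above.
    script = "train_model_deep.py" if deep else "train_model_classical.py"
    return [sys.executable, script, "--model_type", model_name]


def run_all(model_names, deep=True):
    """Run each method in turn; stop at the first command that fails."""
    for name in model_names:
        subprocess.run(build_command(name, deep=deep), check=True)
```

For example, `run_all(["MODEL_A", "MODEL_B"])` would train two deep methods one after another, where both names are placeholders for methods that have config files as described in step 1.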

🛠️ How to Add New Methods

For methods such as MLP, where only the model architecture needs to be designed, you only need to:

  • Add the model class in model/models.
  • Inherit from model/methods/base.py and override the construct_model() method in the new class.
  • Add the method name in the get_method function in model/utils.py.
  • Add the parameter settings for the new method in configs/default/[MODEL_NAME].json and configs/opt_space/[MODEL_NAME].json.

For other methods that require changing the training process, partially override functions based on model/methods/base.py. For details, refer to the implementation of other methods in model/methods/.
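The steps for a model-only method can be pictured with the self-contained sketch below. It does not import TALENT: `Method`, `MyModel`, `MyModelMethod`, the `"mymodel"` name, and the `{"model": {...}}` config layout are all hypothetical stand-ins chosen for illustration, and the real base class in model/methods/base.py and the real `get_method` in model/utils.py may look different.

```python
class Method:
    """Hypothetical stand-in for the base class in model/methods/base.py."""

    def __init__(self, config):
        self.config = config
        self.model = None

    def construct_model(self):
        raise NotImplementedError


class MyModel:
    """Hypothetical stand-in for a new model class placed in model/models."""

    def __init__(self, d_in, d_out, d_hidden=64):
        self.d_in, self.d_out, self.d_hidden = d_in, d_out, d_hidden


class MyModelMethod(Method):
    """New method: only model construction differs from the base class."""

    def construct_model(self):
        # Build the model from the hyperparameters stored in the config.
        self.model = MyModel(**self.config["model"])


def get_method(name):
    """Hypothetical name-to-class lookup, mirroring the role of get_method."""
    registry = {"mymodel": MyModelMethod}
    if name not in registry:
        raise NotImplementedError(f"Unknown method: {name}")
    return registry[name]
```

With this layout, `get_method("mymodel")({"model": {"d_in": 10, "d_out": 2}})` returns a method object whose `construct_model()` call creates the model; training, evaluation, and tuning would be inherited from the base class.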

📦 Dependencies

  1. torch

  2. scikit-learn

  3. pandas

  4. tqdm

  5. numpy

  6. scipy

  7. If you want to use TabR, you have to manually install faiss, which is only available on conda:

    conda install faiss-gpu -c pytorch
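Before running experiments, it can help to check which of the packages above are importable. The snippet below is a small sketch using only the standard library; `missing_packages` is an illustrative helper, and the import names are assumed to be the usual ones for these packages (scikit-learn imports as `sklearn`).

```python
import importlib.util


def missing_packages(import_names):
    """Return the top-level import names that cannot be found in this environment."""
    return [name for name in import_names if importlib.util.find_spec(name) is None]


# Import names for the dependencies listed above.
REQUIRED = ["torch", "sklearn", "pandas", "tqdm", "numpy", "scipy"]
OPTIONAL = ["faiss"]  # only needed for TabR

if __name__ == "__main__":
    print("missing required:", missing_packages(REQUIRED))
    print("missing optional:", missing_packages(OPTIONAL))
```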

🗂️ Benchmark Datasets

Datasets are available at Google Drive.

📂 How to Place Datasets

Place datasets inside the project directory, in the folder named by args.dataset_path. For instance, if the project is LAMDA-TALENT, each dataset should be located at LAMDA-TALENT/args.dataset_path/args.dataset.

Each dataset folder args.dataset consists of:

  • Numeric features: N_train.npy, N_val.npy, and N_test.npy (can be omitted if there are no numeric features)

  • Categorical features: C_train.npy, C_val.npy, and C_test.npy (can be omitted if there are no categorical features)

  • Labels: y_train.npy, y_val.npy, and y_test.npy

  • info.json, which must include the following three fields (task_type can be "regression", "multiclass" or "binclass"):

    {
      "task_type": "regression", 
      "n_num_features": 10,
      "n_cat_features": 10
    }
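A dataset folder can be checked against this layout before training. The sketch below uses only the standard library and tests file presence, not array contents. `check_dataset_dir` is an illustrative helper, not a TALENT function, and it assumes that the shorthand above expands to one file per split (for example N_train.npy, N_val.npy, N_test.npy) and that a feature file is required only when the matching count in info.json is positive.

```python
import json
from pathlib import Path

SPLITS = ("train", "val", "test")
TASK_TYPES = {"regression", "multiclass", "binclass"}


def check_dataset_dir(path):
    """Return a list of problems found in one dataset folder (empty if none)."""
    path = Path(path)
    info_file = path / "info.json"
    if not info_file.is_file():
        return ["missing info.json"]
    info = json.loads(info_file.read_text())
    problems = []
    for key in ("task_type", "n_num_features", "n_cat_features"):
        if key not in info:
            problems.append(f"info.json lacks '{key}'")
    if "task_type" in info and info["task_type"] not in TASK_TYPES:
        problems.append(f"unknown task_type: {info['task_type']}")
    # Labels are always required; feature files only if that feature type exists.
    required = ["y"]
    if info.get("n_num_features", 0) > 0:
        required.append("N")
    if info.get("n_cat_features", 0) > 0:
        required.append("C")
    for prefix in required:
        for split in SPLITS:
            name = f"{prefix}_{split}.npy"
            if not (path / name).is_file():
                problems.append(f"missing {name}")
    return problems
```

Calling `check_dataset_dir` on a dataset folder returns an empty list when the folder matches the assumed layout, and otherwise one message per missing file or field.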

📝 Experimental Results

We provide comprehensive and fair evaluations of classical and deep tabular methods based on our toolbox in the figures below. Three tabular prediction tasks are considered: binary classification, multi-class classification, and regression. Each subfigure represents a different task type.

We use Accuracy and RMSE as the metrics for classification tasks and regression tasks, respectively. To calibrate the metrics, we choose the average performance rank to compare all methods, where a lower rank indicates better performance, following Sheskin (2003). Efficiency is calculated by the average training time in seconds, with lower values denoting better time efficiency. The model size is visually indicated by the radius of the circles, offering a quick glance at the trade-off between model complexity and performance.
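The average performance rank can be illustrated with a short sketch. It ranks the methods on each dataset (rank 1 is best), treating higher Accuracy and lower RMSE as better, and then averages the ranks over datasets. The helper names and the tie rule (tied methods share the mean of their positions) are assumptions made for this example and are not taken from the TALENT code; every method is assumed to have a score on every dataset.

```python
def ranks(scores, higher_is_better):
    """Rank methods on one dataset; 1 is best, ties share the mean position."""
    names = sorted(scores, key=lambda m: scores[m], reverse=higher_is_better)
    result = {}
    i = 0
    while i < len(names):
        # Extend j to cover the whole run of methods tied with names[i].
        j = i
        while j + 1 < len(names) and scores[names[j + 1]] == scores[names[i]]:
            j += 1
        mean_rank = ((i + 1) + (j + 1)) / 2
        for k in range(i, j + 1):
            result[names[k]] = mean_rank
        i = j + 1
    return result


def average_rank(results):
    """results: list of (scores, higher_is_better) pairs, one per dataset."""
    totals = {}
    for scores, higher_is_better in results:
        for method, rank in ranks(scores, higher_is_better).items():
            totals[method] = totals.get(method, 0.0) + rank
    return {method: total / len(results) for method, total in totals.items()}
```

A classification dataset would be passed as `(accuracies, True)` and a regression dataset as `(rmses, False)`, so both metric types contribute comparable ranks to the average.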

The SVM provided in TALENT is a linear SVM (LinearSVM) to ensure faster training. We also consider a Dummy baseline, which outputs the majority class label for classification tasks and the average label for regression tasks.
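The idea behind the Dummy baseline can be sketched in a few lines. This is an illustration of the concept using only the standard library, not TALENT's implementation, and `dummy_predict` is a hypothetical name.

```python
from collections import Counter
from statistics import mean


def dummy_predict(y_train, task_type, n_test):
    """Predict one constant for every test row, ignoring the features."""
    if task_type == "regression":
        # Average of the training labels.
        constant = mean(y_train)
    else:
        # "binclass" or "multiclass": most frequent training label.
        constant = Counter(y_train).most_common(1)[0][0]
    return [constant] * n_test
```

Any method that fails to beat this constant prediction on a dataset has not learned anything useful from the features of that dataset.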

  • Binary classification

  • Multiclass classification

  • Regression

  • All tasks

From the comparison, we observe that CatBoost achieves the best average rank in most classification and regression tasks. Among all deep tabular methods, ModernNCA performs the best in most cases while maintaining an acceptable training cost. These results highlight the effectiveness of CatBoost and ModernNCA in handling various tabular prediction tasks, making them suitable choices for practitioners seeking high performance and efficiency.

These visualizations serve as an effective tool for quickly and fairly assessing the strengths and weaknesses of various tabular methods across different task types, enabling researchers and practitioners to make informed decisions when selecting suitable modeling techniques for their specific needs.

👨‍🏫 Acknowledgments

We thank the following repos for providing helpful components/functions in our work:

🤗 Contact

If you have any questions or would like to propose new features, please open an issue or contact the authors: Siyang Liu ([email protected]), Haorun Cai ([email protected]), Qile Zhou ([email protected]), and Han-Jia Ye ([email protected]). Enjoy the code.

Thanks to LAMDA-PILOT and LAMDA-ZhiJian for the template.
