GMDA: Generative Modeling Density Alignment

Python 3.9 | License: MIT

GMDA is a Python package for generative modeling with density alignment. This README provides instructions for installation, usage, and key features of the package.

Table of Contents

  1. Installation
  2. Data Processing
  3. Model Training
  4. Generating Synthetic Data
  5. Metrics
  6. Command-Line Usage

Installation

Disclaimer: Installing this package also installs a specific version of PyTorch, which may not be compatible with your GPU driver. Before installing, check that the bundled PyTorch version works with your driver. If it does not, first create your Python environment with the PyTorch version best suited to your system, then install GMDA into it. The official PyTorch website lists the appropriate installation command for each setup.
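To see which PyTorch build an environment currently has, a short check along these lines can help. It is a minimal sketch, not part of GMDA: describe_torch is a hypothetical helper name, and it only reads standard PyTorch attributes (torch.__version__, torch.version.cuda, torch.cuda.is_available()).

```python
import importlib.util


def describe_torch():
    """Return a small report on the local PyTorch install, if any."""
    if importlib.util.find_spec("torch") is None:
        return {"installed": False}

    import torch

    return {
        "installed": True,
        "torch_version": torch.__version__,
        # CUDA version PyTorch was built against; None for CPU-only builds.
        "built_with_cuda": torch.version.cuda,
        # True only if a usable GPU and a compatible driver are found.
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    for key, value in describe_torch().items():
        print(f"{key}: {value}")
```

If cuda_available is False on a machine with a GPU, the installed PyTorch build and the driver are likely mismatched, which is the situation the disclaimer describes.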

Using a Virtual Environment

python -m venv env_gmda
source env_gmda/bin/activate
pip install .

Using conda

conda create -n env_gmda python=3.9
conda activate env_gmda
pip install .[conda]

For development mode, use:

pip install -e .[conda]

Data Processing

GMDA provides flexible data processing capabilities through the DataProcessor class.

from gmda.data_utils import DataProcessor

# Define custom data loading and processing functions
def custom_data_loader(train: bool = True, **kwargs):
    # Your custom data loading logic here. Should return a tuple of tabular data (X, y).
    pass

def custom_data_processor(data, train: bool = True, **kwargs):
    # Your custom data processing logic here. Should return a tuple of tuples of processed train and test data ((X_train, y_train), (X_test, y_test)).
    pass

# Instantiate the DataProcessor
data_processor = DataProcessor(custom_data_loader, custom_data_processor)

# Create dataloaders
train_loader, val_loader, X, y = data_processor.create_dataloaders(batch_size=64, density=0.1)
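The two placeholder functions above could be filled in roughly as follows. This is a sketch under several assumptions that the README does not confirm: the data is a CSV file with a header row and the label in the last column, the processor receives the (X, y) tuple returned by the loader, and the keyword arguments path_train, path_test, test_fraction, and seed are hypothetical names chosen for illustration.

```python
import numpy as np


def custom_data_loader(train: bool = True, **kwargs):
    """Load tabular data and return (X, y).

    `path_train` and `path_test` are hypothetical keyword arguments, and the
    CSV layout (header row, label in the last column) is assumed.
    """
    path = kwargs["path_train"] if train else kwargs["path_test"]
    table = np.loadtxt(path, delimiter=",", skiprows=1, ndmin=2)
    X = table[:, :-1].astype(np.float32)
    y = table[:, -1].astype(np.int64)
    return X, y


def custom_data_processor(data, train: bool = True, **kwargs):
    """Standardize features and return ((X_train, y_train), (X_test, y_test)).

    Assumes `data` is the (X, y) tuple from the loader; `test_fraction` and
    `seed` are hypothetical keyword arguments.
    """
    X, y = data
    rng = np.random.default_rng(kwargs.get("seed", 0))
    order = rng.permutation(len(X))
    n_test = max(1, int(len(X) * kwargs.get("test_fraction", 0.2)))
    test_idx, train_idx = order[:n_test], order[n_test:]

    # Fit the scaling on the training rows only to avoid leaking test statistics.
    mean = X[train_idx].mean(axis=0)
    std = X[train_idx].std(axis=0)
    std[std == 0] = 1.0  # constant columns: avoid division by zero

    X_train = (X[train_idx] - mean) / std
    X_test = (X[test_idx] - mean) / std
    return (X_train, y[train_idx]), (X_test, y[test_idx])
```

How DataProcessor forwards keyword arguments to these functions is not documented here, so check the class before relying on the names used in this sketch.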

Model Training

To train a GMDA model:

from gmda.models import GMDARunner
from gmda.models.gmda.tools import get_config

# Load configuration
config = get_config('path/to/config.json')

# Initialize and train the model
model = GMDARunner(config)
model.train(train_loader, val_loader, X, config['training'])
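The snippets in this README read config['training'], config['model'], config['model']['device'], and config['training']['nb_nn_for_prec_recall'], so a config file needs at least those entries. The sketch below writes and checks a skeleton with only these keys. The values are placeholders, write_config and check_config are hypothetical helpers, and a real GMDA config very likely requires further settings that are not listed here.

```python
import json
from pathlib import Path

# Only the key names below appear in this README; the values are placeholders.
# A real GMDA config almost certainly needs more entries (architecture,
# optimizer settings, and so on), which are not documented here.
config_sketch = {
    "model": {
        "device": "cuda:0",
    },
    "training": {
        "nb_nn_for_prec_recall": 10,
    },
}


def write_config(path, config):
    """Write a config dict as JSON and return the path."""
    path = Path(path)
    path.write_text(json.dumps(config, indent=2))
    return path


def check_config(path):
    """Load a JSON config and verify the sections this README relies on."""
    config = json.loads(Path(path).read_text())
    for section in ("model", "training"):
        if section not in config:
            raise KeyError(f"config is missing the '{section}' section")
    return config
```

Running check_config before training gives an early, readable error when a section is missing, rather than a failure partway through a run.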

Generating Synthetic Data

From a Trained Model:

X_synthetic, y_synthetic = model.generate(y)
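# Move tensors to the CPU before converting; .cpu() is a no-op for tensors already on the CPU
X_synthetic, y_synthetic = X_synthetic.detach().cpu().numpy(), y_synthetic.detach().cpu().numpy()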

From a Pretrained Model:

from gmda.models import generate_from_pretrained

X_synthetic, y_synthetic = generate_from_pretrained(
    y, 
    config['model'], 
    path_pretrained=model.checkpoint_dir,
    device=config['model']['device'], 
    return_as_array=True
)

Metrics

GMDA provides metrics to evaluate the quality of generated data:

from gmda.metrics import get_corr_error, get_precision_recall
import numpy as np

# Correlation Error
idx = np.random.choice(np.arange(len(X)), size=min(len(X), 1500), replace=False)
corr_error, corr_error_matrix = get_corr_error(X[idx], X_synthetic[idx])

# Precision/Recall
precision, recall = get_precision_recall(X, X_synthetic, nb_nn=config['training']['nb_nn_for_prec_recall'])
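For intuition about the correlation error, one common way to compare real and synthetic data is the mean absolute difference between their feature correlation matrices. The sketch below implements that definition with NumPy. It is an illustration only: correlation_error_sketch is a hypothetical name, and gmda.metrics.get_corr_error may use a different normalization or aggregation.

```python
import numpy as np


def correlation_error_sketch(X_real, X_fake):
    """Mean absolute difference between feature correlation matrices.

    Illustrative definition only; gmda.metrics.get_corr_error may use a
    different normalization or aggregation.
    """
    # rowvar=False treats columns as features and rows as samples.
    corr_real = np.corrcoef(X_real, rowvar=False)
    corr_fake = np.corrcoef(X_fake, rowvar=False)
    diff = np.abs(corr_real - corr_fake)
    return float(diff.mean()), diff
```

The returned matrix has one row and one column per feature, so large entries point to feature pairs whose dependence the generator reproduces poorly.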

Command-Line Usage

GMDA can be run from the command line:

python main.py --dataset '<DATASET>' \
               --path_train '<PATH/TO/TRAIN/CSV>' \
               --path_test '<PATH/TO/TEST/CSV>' \
               --device 'cuda:0' \
               --config '<PATH/TO/CONFIG/JSON>' \
               --output_dir '<PATH/TO/OUTPUT/RESULTS>' \
               --compute_metrics \
               --save_generated

For more details on command-line options, run:

python main.py --help
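When launching several runs from a script, the same command can be assembled in Python. The sketch below uses only the flags shown above and assumes that --compute_metrics and --save_generated are on/off switches, as their use without values suggests. build_gmda_command is a hypothetical helper, and the dataset name and paths in the example are placeholders.

```python
import shlex
import sys


def build_gmda_command(dataset, path_train, path_test, config, output_dir,
                       device="cuda:0", compute_metrics=True, save_generated=True):
    """Assemble the argv list for main.py using the flags shown above."""
    cmd = [
        sys.executable, "main.py",
        "--dataset", dataset,
        "--path_train", path_train,
        "--path_test", path_test,
        "--device", device,
        "--config", config,
        "--output_dir", output_dir,
    ]
    if compute_metrics:
        cmd.append("--compute_metrics")
    if save_generated:
        cmd.append("--save_generated")
    return cmd


if __name__ == "__main__":
    # Hypothetical dataset name and paths, for illustration only.
    cmd = build_gmda_command("my_dataset", "data/train.csv", "data/test.csv",
                             "configs/my_dataset.json", "results/my_dataset")
    print(shlex.join(cmd))
    # To launch the run: subprocess.run(cmd, check=True)
```

Passing the command as a list avoids shell quoting problems with paths that contain spaces.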

Contributing

We welcome contributions! Please contact me for more details.

License

This project is licensed under the MIT License.
