iterative Random Forest

The algorithm details are available at:

Sumanta Basu, Karl Kumbier, James B. Brown, Bin Yu, Iterative random forests to discover predictive and stable high-order interactions, PNAS 115(8), 2018. https://www.pnas.org/content/115/8/1943

The implementation is a joint effort of several people at UC Berkeley; see Authors.md for the complete list. The weighted random forest implementation is based on the random forest source code and API design from scikit-learn. Details can be found in "API design for machine learning software: experiences from the scikit-learn project", Buitinck et al., 2013. The setup file is based on the setup file from skgarden.

Installation

To install, simply run pip install irf. If you run into any issues, see installation help.

A simple demo

To use irf, import it in Python:

import numpy as np
from irf import irf_utils
from irf.ensemble import RandomForestClassifierWithWeights

Generate a simple data set with 10 features. The second feature (index 1) determines the label perfectly; the other nine are noise features with no power to predict the labels:

n_samples = 1000
n_features = 10
X_train = np.random.uniform(low=0, high=1, size=(n_samples, n_features))
y_train = np.random.choice([0, 1], size=(n_samples,), p=[.5, .5])
X_test = np.random.uniform(low=0, high=1, size=(n_samples, n_features))
y_test = np.random.choice([0, 1], size=(n_samples,), p=[.5, .5])
# The second feature (which is indexed by 1) is very important
X_train[:, 1] = X_train[:, 1] + y_train
X_test[:, 1] = X_test[:, 1] + y_test
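
A quick sanity check, which is not part of irf itself, can confirm that the shifted feature separates the two classes. The sketch below rebuilds data in the same way and thresholds feature 1 at 1.0. The helper name `make_demo_data` and the fixed seed are assumptions made for illustration.

```python
import numpy as np


def make_demo_data(n_samples=1000, n_features=10, seed=0):
    """Hypothetical helper: rebuilds the demo data with a fixed seed."""
    rng = np.random.RandomState(seed)
    X = rng.uniform(low=0, high=1, size=(n_samples, n_features))
    y = rng.choice([0, 1], size=(n_samples,), p=[.5, .5])
    # Shift feature 1 by the label: class 0 lies in [0, 1), class 1 in [1, 2)
    X[:, 1] = X[:, 1] + y
    return X, y


X_demo, y_demo = make_demo_data()

# Thresholding feature 1 at 1.0 recovers the labels exactly
pred_from_feature_1 = (X_demo[:, 1] >= 1).astype(int)
accuracy_feature_1 = np.mean(pred_from_feature_1 == y_demo)

# A noise feature (index 0) thresholded at its midpoint stays close to chance
pred_from_feature_0 = (X_demo[:, 0] >= 0.5).astype(int)
accuracy_feature_0 = np.mean(pred_from_feature_0 == y_demo)

print(accuracy_feature_1, accuracy_feature_0)
```

Because a single threshold on feature 1 is enough, a well-behaved run should give that feature a much larger weight than the others.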

Then run iRF:

all_rf_weights, all_K_iter_rf_data, \
    all_rf_bootstrap_output, all_rit_bootstrap_output, \
    stability_score = irf_utils.run_iRF(X_train=X_train,
                                        X_test=X_test,
                                        y_train=y_train,
                                        y_test=y_test,
                                        K=5,                          # number of iterations
                                        rf=RandomForestClassifierWithWeights(n_estimators=20),
                                        B=30,
                                        random_state_classifier=2018, # random seed
                                        propn_n_samples=.2,
                                        bin_class_type=1,
                                        M=20,
                                        max_depth=5,
                                        noisy_split=False,
                                        num_splits=2,
                                        n_estimators_bootstrap=5)

all_rf_weights stores the feature weights from each iteration. With K=5, the weights from the final iteration are stored under the key 'rf_weight5':

print(all_rf_weights['rf_weight5'])
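
To see which features a given iteration favours, the weights can be ranked. The sketch below assumes that each entry of all_rf_weights is a 1-D array with one weight per feature. The values in `example_weights` are made up for illustration and are not output from irf.

```python
import numpy as np

# Made-up weights for 10 features (illustration only, not real irf output)
example_weights = np.array([0.01, 0.85, 0.02, 0.01, 0.03,
                            0.02, 0.01, 0.02, 0.02, 0.01])

# Feature indices sorted from largest to smallest weight
ranked_features = np.argsort(example_weights)[::-1]

top_k = 3
for idx in ranked_features[:top_k]:
    print("feature {}: weight {:.2f}".format(idx, example_weights[idx]))
```

With real output, `example_weights` would be replaced by `all_rf_weights['rf_weight5']`, provided the assumption about its shape holds.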

The proposed feature combinations and their stability scores:

print(stability_score)
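
The scores can be filtered to keep only the most stable combinations. The sketch below assumes that stability_score is a dict mapping a label made of feature indices joined by underscores (for example '1_3') to a score between 0 and 1. That layout, the helper `stable_interactions`, and the example values are assumptions for illustration, not confirmed irf output.

```python
# Made-up scores (illustration only, not real irf output)
example_stability_score = {"1": 1.0, "1_3": 0.4, "1_7": 0.2, "2_5": 0.1}


def stable_interactions(scores, threshold=0.5):
    """Hypothetical helper: keep interactions at or above a score threshold,
    sorted from most to least stable."""
    kept = [(name, s) for name, s in scores.items() if s >= threshold]
    return sorted(kept, key=lambda item: item[1], reverse=True)


for name, score in stable_interactions(example_stability_score, threshold=0.3):
    # Split the label back into integer feature indices
    features = [int(i) for i in name.split("_")]
    print(features, score)
```

The threshold is a judgment call; a higher value keeps fewer combinations, each of which was recovered more consistently.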
