animl

This is the start of a Python machine learning library to augment scikit-learn. At the moment, it provides only decision tree visualization and model interpretation.

(So far, we've only tested this on OS X.) To install (Python >=3.6 only), do this:

pip install animl

You also need the following tools for the decision tree visualizations to work:

brew install poppler
brew install pdf2svg
brew install graphviz --with-librsvg --with-app --with-pango
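
As a quick sanity check after installing, the sketch below reports which command-line tools cannot be found on your PATH. The executable names are assumptions: dot ships with graphviz, pdf2svg with the pdf2svg package, and pdftocairo is one of the binaries that poppler provides. Which executables animl actually invokes is not stated here, so adjust the list as needed. The helper name missing_tools is hypothetical.

```python
import shutil

# Assumed executable names (not taken from animl itself):
# graphviz provides `dot`, pdf2svg provides `pdf2svg`,
# and poppler provides `pdftocairo` among others.
ASSUMED_TOOLS = ["dot", "pdf2svg", "pdftocairo"]


def missing_tools(tools, which=shutil.which):
    """Return the tools from `tools` that cannot be found on the PATH."""
    return [tool for tool in tools if which(tool) is None]


if __name__ == "__main__":
    missing = missing_tools(ASSUMED_TOOLS)
    if missing:
        print("Not found on PATH:", ", ".join(missing))
    else:
        print("All assumed tools were found.")
```

The `which` parameter exists only so the lookup can be swapped out when testing; in normal use the default shutil.which is enough.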

Please email us with notes on making it work on other platforms. Thanks!

Decision tree visualization

By Terence Parr and Prince Grover

See How to visualize decision trees for deeper discussion of our decision tree visualization library and the visual design decisions we made.

Decision trees are the fundamental building block of gradient boosting machines and Random Forests(tm), probably the two most popular machine learning models for structured data. Visualizing decision trees is a tremendous aid when learning how these models work and when interpreting models. Unfortunately, current visualization packages are rudimentary and not immediately helpful to the novice. For example, we couldn't find a library that visualizes how decision nodes split up the feature space. It is also uncommon for libraries to support visualizing a specific feature vector as it weaves down through a tree's decision nodes; we could only find one image showing this.

So, we've created a general package for scikit-learn decision tree visualization and model interpretation, which we'll be using heavily in an upcoming machine learning book (written with Jeremy Howard).

The visualizations are inspired by an educational animation by R2D3, A visual introduction to machine learning. With animl, you can visualize how the feature space is split up at decision nodes, how the training samples get distributed in leaf nodes, and how the tree makes predictions for a specific observation. These operations are critical for understanding how classification or regression decision trees work. If you're not familiar with decision trees, check out fast.ai's Introduction to Machine Learning for Coders MOOC.
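
To make the idea of splitting the feature space concrete, here is a minimal pure-Python sketch of how a single regression decision node could choose its split. It tries a threshold midway between each pair of adjacent feature values and keeps the one that minimizes the total squared error of the two sides. This is an illustration only, not how animl or scikit-learn implement it; the function names sse and best_split and the toy data are hypothetical.

```python
def sse(values):
    """Sum of squared errors around the mean (0.0 for an empty list)."""
    if not values:
        return 0.0
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values)


def best_split(xs, ys):
    """Find the threshold on one feature that minimizes left + right SSE.

    Returns (threshold, error), or (None, sse(ys)) if no split helps.
    """
    pairs = sorted(zip(xs, ys))
    best_threshold, best_error = None, sse(ys)
    for i in range(1, len(pairs)):
        lo, hi = pairs[i - 1][0], pairs[i][0]
        if lo == hi:
            continue  # identical feature values cannot be separated
        threshold = (lo + hi) / 2  # midpoint between adjacent values
        left = [y for x, y in pairs if x <= threshold]
        right = [y for x, y in pairs if x > threshold]
        error = sse(left) + sse(right)
        if error < best_error:
            best_threshold, best_error = threshold, error
    return best_threshold, best_error


# Toy data: the target jumps once x exceeds 3
xs = [1, 2, 3, 4, 5, 6]
ys = [10, 10, 10, 20, 20, 20]
threshold, error = best_split(xs, ys)
print(threshold, error)
```

A full tree applies the same search recursively to the samples on each side of the chosen threshold, which is what carves the feature space into the regions shown in the visualizations.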

Usage

dtreeviz: Main function to create a decision tree visualization. Given a decision tree regressor or classifier, it creates and returns a tree visualization using the graphviz (DOT) language.

  • Required libraries:
    Imports needed to generate the sample visualizations shown in the examples below.
from sklearn.datasets import *
from sklearn import tree
from animl.viz.trees import dtreeviz
from animl.trees import *
import graphviz
import numpy as np  # used to pick a random observation in the prediction path example
  • Regression decision tree:
    The default orientation of the tree is top-down, but you can change it to left-to-right using orientation="LR". Calling view() opens a pop-up window with the rendered graphviz object.
regr = tree.DecisionTreeRegressor(max_depth=2)
boston = load_boston()
regr.fit(boston.data, boston.target)

viz = dtreeviz(regr,
               boston.data,
               boston.target,
               target_name='price',
               feature_names=boston.feature_names)
              
viz.view()              

  • Classification decision tree:
    Classification trees require an additional class_names argument that maps each class value to its class name.
classifier = tree.DecisionTreeClassifier(max_depth=2)  # limit depth of tree
iris = load_iris()
classifier.fit(iris.data, iris.target)

viz = dtreeviz(classifier,
               iris.data,
               iris.target,
               target_name='variety',
               feature_names=iris.feature_names,
               class_names=["setosa", "versicolor", "virginica"])  # need class_names for classifier
              
viz.view() 

  • Prediction path:
    Highlights the decision nodes that a single observation, passed in argument X, falls through on its way to a leaf. It also lists the feature values of the observation and highlights the features the tree uses along that path.
regr = tree.DecisionTreeRegressor(max_depth=2)  # limit depth of tree
diabetes = load_diabetes()
regr.fit(diabetes.data, diabetes.target)
X = diabetes.data[np.random.randint(0, len(diabetes.data)),:]  # random sample from training

viz = dtreeviz(regr,
               diabetes.data, 
               diabetes.target, 
               target_name='value', 
               orientation='LR',  # left-right orientation
               feature_names=diabetes.feature_names,
               X=X)  # need to give single observation for prediction
              
viz.view()  

  • Decision tree without scatterplot or histograms for decision nodes:
    Pass fancy=False to draw a simple tree without histograms or scatterplots in the decision nodes.
classifier = tree.DecisionTreeClassifier(max_depth=4)  # limit depth of tree
cancer = load_breast_cancer()
classifier.fit(cancer.data, cancer.target)

viz = dtreeviz(classifier,
               cancer.data,
               cancer.target,
               target_name='cancer',
               feature_names=cancer.feature_names,
               class_names=["malignant", "benign"],
               fancy=False)  # fancy=False removes histograms/scatterplots from decision nodes
              
viz.view() 
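
The prediction path highlighting described above can be thought of as a simple walk from the root to a leaf. The following sketch traces that walk over a hand-built tree stored as nested dicts. The dict layout, the prediction_path function, and the toy features and values are hypothetical and do not reflect animl's internals. It follows the scikit-learn convention of going left when the feature value is less than or equal to the threshold.

```python
def prediction_path(node, x):
    """Walk from the root to a leaf, recording each decision taken.

    Returns (path, prediction) where path is a list of
    (feature, threshold, direction) tuples.
    """
    path = []
    while "value" not in node:  # decision nodes have no "value" key
        feature, threshold = node["feature"], node["threshold"]
        if x[feature] <= threshold:
            path.append((feature, threshold, "left"))
            node = node["left"]
        else:
            path.append((feature, threshold, "right"))
            node = node["right"]
    return path, node["value"]


# Hypothetical depth-2 regression tree over two made-up features
toy_tree = {
    "feature": "bmi", "threshold": 0.0,
    "left": {"feature": "age", "threshold": 0.05,
             "left": {"value": 100.0}, "right": {"value": 140.0}},
    "right": {"value": 210.0},
}

observation = {"bmi": -0.02, "age": 0.08}
path, prediction = prediction_path(toy_tree, observation)
used_features = {feature for feature, _, _ in path}
print(path, prediction, sorted(used_features))
```

The set used_features corresponds to the features that would be highlighted for this observation; features the walk never tests are left out.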

For more examples and different implementations, please see the Jupyter notebook full of examples.

Implementation guidelines

At least on the Mac, make sure to install the required tools using:

brew install poppler
brew install pdf2svg
brew install graphviz --with-librsvg --with-app --with-pango

Then use setup.py to make sure the library gets installed properly:

python setup.py install -f

This will push the animl library to your local egg cache. E.g., on Terence's box, it adds /Users/parrt/anaconda3/lib/python3.6/site-packages/animl-0.1-py3.6.egg.
