Comments (7)
@blokhin Is it possible to post your model for debugging purposes? If not, can you make a toy example to demonstrate the issue?
from treelite.
@blokhin Any updates?
@hcho3 I figured out that the compiled random forest binary classifier does not output 0 or 1 (as I expected) but a float between 0 and 1. Not a big deal, since this float can simply be rounded to get the class label. That was the actual cause of the discrepancy.
@blokhin Treelite will produce a number between 0 and 1, representing the fraction of the votes among the decision trees for the positive class (e.g. 0.75 means 75% of the trees predicted the positive class).
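To make the vote-fraction description above concrete, here is a toy sketch in plain Python. It does not call Treelite; the names `positive_fraction` and `to_label` and the list of per-tree votes are hypothetical and exist only for illustration.

```python
def positive_fraction(tree_votes):
    """Return the fraction of trees that voted for the positive class."""
    return sum(tree_votes) / len(tree_votes)


def to_label(score, threshold=0.5):
    """Turn a score in [0, 1] into a hard 0/1 label using an explicit threshold."""
    return 1 if score > threshold else 0


# Hypothetical votes from four trees: three say "positive", one says "negative".
votes = [1, 1, 1, 0]
score = positive_fraction(votes)  # 0.75, i.e. 75% of the trees voted positive
label = to_label(score)           # 1

print(score, label)
```

An explicit threshold is used here instead of `round()` because Python 3 rounds 0.5 to 0, which may or may not be the tie-breaking behavior you want.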
Here is sample code to reproduce the issue:
import random
import numpy as np
from sklearn.ensemble import RandomForestClassifier
import treelite.gallery.sklearn
import treelite.runtime
# First, generate sample data
my_range = range(1, 100)
X_data = [[random.choice(my_range) for _ in range(125)] for _ in range(76)]
X_test = X_data.pop()
y_data = [random.choice([0, 1]) for _ in range(75)]
# Second, prepare a classifier and compile it
model_py = RandomForestClassifier(
    n_estimators=100,
    max_features=2,
    max_depth=None,
    min_samples_split=2,
    min_samples_leaf=5,
    bootstrap=True,
    n_jobs=-1
)
model_py.fit(X_data, y_data)
model_file = './compiled.so'
icompiler = treelite.gallery.sklearn.import_model(model_py)
icompiler.export_lib(toolchain='clang', libpath=model_file, verbose=True, params={'parallel_comp': 8})
model_tr = treelite.runtime.Predictor(model_file, verbose=True)
# Third, compare: py vs. tr
result_py = model_py.predict([X_test])[0]
batch = treelite.runtime.Batch.from_npy2d(np.array([X_test]))
result_tr = model_tr.predict(batch)
print(result_py, result_tr)
assert result_py == result_tr # I would expect that, but...
@blokhin Also, try using predict_proba() to get probability values from scikit-learn.
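As a rough sketch of that suggestion, the snippet below uses only scikit-learn and NumPy (no Treelite) to show how `predict()` and `predict_proba()` differ on a model similar to the one in the reproduction script. The random data, the seed, and the variable names are assumptions made for illustration.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Random training data shaped like the reproduction script (75 rows, 125 features).
rng = np.random.default_rng(0)
X = rng.integers(1, 100, size=(75, 125))
y = rng.integers(0, 2, size=75)
y[0], y[1] = 0, 1  # make sure both classes are present

clf = RandomForestClassifier(n_estimators=100, min_samples_leaf=5, random_state=0)
clf.fit(X, y)

X_test = rng.integers(1, 100, size=(5, 125))

labels = clf.predict(X_test)       # hard class labels: 0 or 1
proba = clf.predict_proba(X_test)  # shape (5, 2): one column per class

# The positive-class column is the value to compare against a float score in [0, 1].
positive_score = proba[:, 1]

# Thresholding the score gives hard labels back.
labels_from_score = (positive_score > 0.5).astype(int)

print(labels)
print(positive_score)
print(labels_from_score)
```

When comparing against the compiled model, it is `positive_score` rather than `labels` that should be checked, and with a tolerance (for example `np.allclose`) rather than `==`, since the two implementations may differ in floating-point precision.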
OK, got it!