Comments (8)
We support different post transformations. So maybe setting POST_TRANSFORM to None will remove the normalization step.
from hummingbird.
Hi @louis-huang! Hummingbird supports multiclass for xgboost. You can check a test here.
Hummingbird does not support categorical values for tree models. Can it be that the dataset has categorical features?
Hi @interesaaat, I want to try a multi-label model, not a multi-class one. This is an example from xgb: https://xgboost.readthedocs.io/en/stable/tutorials/multioutput.html
@louis-huang The tree representations for a multi-class model and a multi-label classification model are essentially the same, with a separate tree ensemble for each class. So Hummingbird's multi-class support should work in theory, though some preprocessing may be necessary.
Internally, XGBoost builds one model for each target similar to sklearn meta estimators, with the added benefit of reusing data and other integrated features like SHAP.
Thanks @gorkemozkaya for chiming in. We already have some infra for multi-label for sklearn multioutput regression. Shouldn't be too hard I feel to make it work for xgboost as well. Contributions welcome! 😄
@interesaaat @louis-huang I think the problem is that Hummingbird is normalizing (i.e. applying softmax to) each row of the output so that the probabilities add up to 1. If we could remove that normalization, it would support multi-label. For now, I propose a workaround: separate the n-way multi-label classifier into n separate binary classifiers. With that, the output probabilities match: notebook link
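To make the normalization issue concrete, here is a minimal sketch in plain Python (not Hummingbird code). The raw scores are made-up values standing in for the per-label margins of one row. Softmax couples the labels so they sum to 1, while an independent sigmoid per label lets several labels be likely at once.

```python
import math


def softmax(scores):
    """Normalize a row of raw scores so the outputs sum to 1."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(score):
    """Map one raw score to a probability, independently of other labels."""
    return 1.0 / (1.0 + math.exp(-score))


# Made-up raw margins for one row with three labels (illustrative only).
raw_scores = [2.0, 1.5, -3.0]

softmax_probs = softmax(raw_scores)
sigmoid_probs = [sigmoid(s) for s in raw_scores]

print("softmax:", softmax_probs, "sum =", sum(softmax_probs))
print("sigmoid:", sigmoid_probs, "sum =", sum(sigmoid_probs))
```

With these scores, the first two labels both have positive margins, yet softmax pushes the second one below 0.5 because the row is forced to sum to 1. The per-label sigmoid keeps both above 0.5, which is the behavior a multi-label model expects.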
Thanks @interesaaat, this was helpful! I verified that the outputs match if we override the default post-transform by passing the extra_config = {'post_transform': 'LOGISTIC'} argument, and then take the last n_classes columns of the output.
By default it is using the SOFTMAX post_transform, which is not the right transform for multi-label. The library needs to decide the transform based on the objective attribute of the XGBoost classification model. That is, it should be changed such that multi:softprob maps to SOFTMAX, whereas binary:logistic maps to LOGISTIC. But it needs a slightly different version of LOGISTIC that does not double the number of output columns. For now, we can just take the last n_classes columns.
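The workaround and the proposed objective-based choice could be sketched roughly as follows. This is not Hummingbird's implementation: pick_post_transform, last_n_columns, and multilabel_proba are hypothetical helper names, and the sketch assumes the converted model's predict_proba returns one row per sample with 2 * n_classes columns when LOGISTIC is used, as described above. Only the extra_config value and the two objective strings come from this thread.

```python
# Proposed mapping from the XGBoost objective to a post-transform name.
# Only the two objectives discussed in this thread are covered.
OBJECTIVE_TO_POST_TRANSFORM = {
    "multi:softprob": "SOFTMAX",
    "binary:logistic": "LOGISTIC",
}


def pick_post_transform(objective):
    """Hypothetical helper: choose the post-transform from the objective."""
    try:
        return OBJECTIVE_TO_POST_TRANSFORM[objective]
    except KeyError:
        raise ValueError(f"no post-transform chosen for objective {objective!r}")


def last_n_columns(rows, n_classes):
    """Keep only the trailing n_classes entries of every output row."""
    return [list(row[-n_classes:]) for row in rows]


def multilabel_proba(xgb_model, X, n_classes):
    """Hypothetical wrapper around the workaround described above."""
    # Imported lazily so the helpers above work without Hummingbird installed.
    from hummingbird.ml import convert

    hb_model = convert(
        xgb_model,
        "torch",
        X,  # sample input passed as test_input
        extra_config={"post_transform": "LOGISTIC"},
    )
    # Assumption: LOGISTIC doubles the columns, and the last n_classes
    # columns hold the per-label probabilities.
    return last_n_columns(hb_model.predict_proba(X), n_classes)
```

For example, last_n_columns([[0.1, 0.8, 0.9, 0.2]], 2) keeps [[0.9, 0.2]]. If a LOGISTIC variant that does not double the columns is added to the library, the slicing step would no longer be needed.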
Thanks for pointing out the post_transform, @interesaaat. This helps us a lot! Thank you @gorkemozkaya for verifying this and providing suggestions!