Comments (8)

interesaaat commented on June 4, 2024

We support different post-transformations, so setting POST_TRANSFORM to None may remove the normalization step.
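To make the idea of a post-transform concrete, here is a minimal pure-Python sketch. It assumes the model's raw per-class scores (margins) are available as a list of rows. The names `apply_post_transform`, `"NONE"`, and `"SOFTMAX"` are hypothetical stand-ins for this illustration, not Hummingbird's internal API: with no post-transform the margins pass through unchanged, while softmax forces each row to sum to 1.

```python
import math


def softmax_row(row):
    # Subtract the row max for numerical stability before exponentiating.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def apply_post_transform(margins, kind):
    """Hypothetical dispatcher: map raw per-class margins to final outputs."""
    if kind == "NONE":
        # No normalization: return a copy of the raw margins.
        return [list(row) for row in margins]
    if kind == "SOFTMAX":
        # Normalize each row so the class scores sum to 1.
        return [softmax_row(row) for row in margins]
    raise ValueError(f"unsupported post-transform: {kind}")


margins = [[2.0, 1.5, -3.0]]
print(apply_post_transform(margins, "NONE"))  # [[2.0, 1.5, -3.0]]
print(sum(apply_post_transform(margins, "SOFTMAX")[0]))
```

With the raw margins in hand, a multi-label user could then apply whatever per-label transform they need.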

interesaaat commented on June 4, 2024

Hi @louis-huang! Hummingbird supports multi-class classification for XGBoost. You can check a test here.

Hummingbird does not support categorical values for tree models. Could it be that the dataset has categorical features?

louis-huang commented on June 4, 2024

Hi @interesaaat, I want to try a multi-label model, not a multi-class one. This is an example from XGBoost: https://xgboost.readthedocs.io/en/stable/tutorials/multioutput.html

gorkemozkaya commented on June 4, 2024

@louis-huang The tree representations of a multi-class model and a multi-label classification model are essentially the same, with a separate tree ensemble for each class, so Hummingbird's multi-class support should work in theory, although some preprocessing may be necessary.
Internally, XGBoost builds one model for each target similar to sklearn meta estimators, with the added benefit of reusing data and other integrated features like SHAP.
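As a rough illustration of that shared structure, the sketch below assumes XGBoost's default one-output-per-tree layout, in which each boosting round contributes one tree per class, so tree i belongs to class i % n_classes. The leaf values and the `base_margin` starting value are made-up inputs, and `per_class_margins` is a hypothetical helper. The point is that per-class margins are accumulated the same way for multi-class and multi-label models; only the final transform differs.

```python
def per_class_margins(leaf_values, n_classes, base_margin=0.0):
    """Sum per-tree leaf outputs into one raw margin per class.

    leaf_values: the leaf value each tree produced for a single sample,
    ordered round by round (round 0: class 0..n-1, round 1: class 0..n-1, ...).
    """
    if len(leaf_values) % n_classes != 0:
        raise ValueError("expected one tree per class per boosting round")
    margins = [base_margin] * n_classes
    for i, value in enumerate(leaf_values):
        # Tree i belongs to the ensemble of class i % n_classes.
        margins[i % n_classes] += value
    return margins


# Two boosting rounds, three classes -> six trees (hypothetical leaf values).
leaves = [0.5, -0.25, 0.75, 0.25, -0.5, 0.25]
print(per_class_margins(leaves, n_classes=3))  # [0.75, -0.75, 1.0]
```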

interesaaat commented on June 4, 2024

Thanks @gorkemozkaya for chiming in. We already have some infra for multi-label from the sklearn multioutput regression support. I feel it shouldn't be too hard to make it work for XGBoost as well. Contributions welcome! 😄

gorkemozkaya commented on June 4, 2024

@interesaaat @louis-huang I think the problem is that Hummingbird normalizes each row of the output (i.e., applies softmax to it) so that the probabilities add up to 1. If we could remove that normalization, it would support multi-label. For now, I propose a workaround: split the n-way multi-label classifier into n separate binary classifiers, after which the output probabilities match (notebook link).
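The problem shows up on a single row. In this pure-Python sketch, the margins are hypothetical values for a sample that carries both label 0 and label 1. Softmax makes the labels compete for one unit of probability, while applying a sigmoid to each margin independently (which is what n separate binary classifiers would do) lets both labels clear a 0.5 threshold.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def softmax(row):
    # Subtract the max for numerical stability.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical margins: labels 0 and 1 are both strongly present, label 2 absent.
margins = [3.0, 3.0, -3.0]

multi_label = [sigmoid(m) for m in margins]  # independent per-label probabilities
multi_class = softmax(margins)               # probabilities forced to sum to 1

print([p > 0.5 for p in multi_label])  # [True, True, False]
print([p > 0.5 for p in multi_class])  # [False, False, False]
```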

gorkemozkaya commented on June 4, 2024

Thanks @interesaaat, this was helpful! I verified that the outputs match if we override the default post-transform by passing extra_config = {'post_transform': 'LOGISTIC'} and then take the last n_classes columns of the output.
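The column slicing can be sketched without Hummingbird itself. The helper below assumes, as the comments here imply, that the LOGISTIC post-transform emits `[1 - p_1, ..., 1 - p_n, p_1, ..., p_n]` for n labels, with the complement columns first. `logistic_with_complements` and `last_n_columns` are hypothetical stand-ins, not Hummingbird functions. Keeping the last n_classes columns recovers the per-label probabilities.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def logistic_with_complements(margins):
    """Hypothetical stand-in for the doubled LOGISTIC output.

    Assumed layout per row: [1 - p_1, ..., 1 - p_n, p_1, ..., p_n].
    """
    out = []
    for row in margins:
        probs = [sigmoid(m) for m in row]
        out.append([1.0 - p for p in probs] + probs)
    return out


def last_n_columns(rows, n_classes):
    # Keep only the p_k columns, dropping the 1 - p_k complements.
    return [row[-n_classes:] for row in rows]


margins = [[0.0, 2.0, -1.0], [1.5, -0.5, 0.25]]
doubled = logistic_with_complements(margins)
probs = last_n_columns(doubled, n_classes=3)
print(len(doubled[0]), len(probs[0]))  # 6 3
print(probs[0][0])                     # 0.5
```

With the real libraries, the workaround would presumably amount to converting with `hummingbird.ml.convert(model, "torch", extra_config={"post_transform": "LOGISTIC"})` and slicing `predict_proba(X)[:, -n_classes:]`; that combination is untested here.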

By default it uses the SOFTMAX post_transform, which is not the right transform for multi-label. The library should choose the transform based on the objective attribute of the XGBoost classification model, i.e., multi:softprob should map to SOFTMAX, whereas binary:logistic should map to LOGISTIC. However, this needs a slightly different version of LOGISTIC that does not double the number of output columns. For now, we can just take the last n_classes columns.
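The proposed rule can be sketched as a small lookup. The function name `post_transform_for` is hypothetical, and only the two objectives named above are covered. Any other objective raises an error instead of guessing a transform.

```python
# Mapping proposed in the comment: XGBoost objective -> post-transform name.
_OBJECTIVE_TO_POST_TRANSFORM = {
    "multi:softprob": "SOFTMAX",    # mutually exclusive classes
    "binary:logistic": "LOGISTIC",  # independent labels (binary or multi-label)
}


def post_transform_for(objective):
    """Hypothetical helper: choose a post-transform from the objective string."""
    try:
        return _OBJECTIVE_TO_POST_TRANSFORM[objective]
    except KeyError:
        raise ValueError(f"no post-transform rule for objective {objective!r}") from None


print(post_transform_for("multi:softprob"))   # SOFTMAX
print(post_transform_for("binary:logistic"))  # LOGISTIC
```

Failing loudly on unknown objectives keeps a new objective from silently getting the wrong normalization, which is the bug this thread describes.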

louis-huang commented on June 4, 2024

Thanks for pointing out the post_transform, @interesaaat. This helps us a lot! Thank you @gorkemozkaya for verifying this and providing suggestions!
