

GCBallesteros commented on August 23, 2024

You can find the code I was using on this gist.

At this stage you probably can't change your APIs to make them sklearn-compatible, and even if you wanted to, the flexibility of the underlying neural networks makes it very hard to pin down a single interface. Right now I can think of three approaches:

  1. Upload an example, better than the one I've uploaded, and let people cook up their own solutions from that template.
  2. Take the Keras route: Keras provides wrappers that take a model-building function and return a sklearn-compatible estimator.
  3. If the model is simple enough, like the ones produced by MLPVanilla in torchtuples, it can be wrapped fully. Reducing neural networks to MLPs is also the approach sklearn itself takes.

Of course, any combination of these approaches can co-exist with the existing code without risking breaking anything for current users.
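Approach 2 can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the actual API of Keras, pycox, or torchtuples: `BuildFnEstimator`, `build_fn`, and the toy `MeanModel` are hypothetical names. The wrapper keeps the hyperparameters it forwards to `build_fn` and exposes them through `get_params`/`set_params`, which is how sklearn utilities such as `clone` and `GridSearchCV` read and set hyperparameters. It then delegates `fit` and `score` to a freshly built model.

```python
from typing import Any, Callable


class BuildFnEstimator:
    """Hypothetical sklearn-style wrapper around a user-supplied build function.

    build_fn(**params) must return an object with fit(X, y) and score(X, y).
    """

    def __init__(self, build_fn: Callable[..., Any], **params: Any) -> None:
        self.build_fn = build_fn
        self.params = dict(params)

    def get_params(self, deep: bool = True) -> dict:
        # Expose build_fn alongside the hyperparameters so that an identical,
        # unfitted wrapper can be rebuilt as BuildFnEstimator(**get_params()).
        return {"build_fn": self.build_fn, **self.params}

    def set_params(self, **params: Any) -> "BuildFnEstimator":
        # Hyperparameter searches call set_params with each candidate setting.
        if "build_fn" in params:
            self.build_fn = params.pop("build_fn")
        self.params.update(params)
        return self

    def fit(self, X, y) -> "BuildFnEstimator":
        # Build a new model from the current hyperparameters on every fit.
        self.model_ = self.build_fn(**self.params)
        self.model_.fit(X, y)
        return self

    def score(self, X, y) -> float:
        return self.model_.score(X, y)


class MeanModel:
    """Toy stand-in for a network: predicts mean(y) + shift, scored by -MSE."""

    def __init__(self, shift: float = 0.0) -> None:
        self.shift = shift

    def fit(self, X, y) -> "MeanModel":
        self.mean_ = sum(y) / len(y)
        return self

    def score(self, X, y) -> float:
        pred = self.mean_ + self.shift
        return -sum((v - pred) ** 2 for v in y) / len(y)


est = BuildFnEstimator(MeanModel, shift=0.0).set_params(shift=1.0)
print(est.fit([[0], [1]], [1.0, 3.0]).score([[0]], [2.0]))  # -1.0
```

Overriding `get_params` is needed here because `**params` hides the hyperparameters from sklearn's usual `__init__`-signature introspection. Whether a given scikit-learn version accepts such an object without also inheriting from `sklearn.base.BaseEstimator` is worth checking before relying on it.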

As for what to do with mixed-type y vectors (durations plus event indicators), I'm not sure whether sklearn will accept structured arrays. I will try it and edit this message.

from pycox.

havakv commented on August 23, 2024

Hi Afshin,
Sadly, none of the examples include any form of hyperparameter tuning. The examples only use a single set of hyperparameters, so you have to set up the hyperparameter search yourself. If you use random or grid search, this is quite straightforward, but more advanced hyperparameter searches can be quite complicated.

For random and grid search, you need a function that takes a set of hyperparameters, fits your model, and returns a score (e.g., the concordance index). You can then loop over your candidate hyperparameters to see which configuration works best.
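The loop described above can be sketched with NumPy. This is a minimal sketch, not pycox code: `fit_and_score` is a hypothetical helper, and to keep the example runnable without PyTorch it fits a toy ridge regression of `-log(duration)` on the covariates instead of a pycox model. The score is a simple version of Harrell's concordance index; it ignores pairs with tied durations, so its numbers can differ from library implementations.

```python
import itertools

import numpy as np


def concordance_index(durations, events, risk):
    """Simple Harrell's C: share of comparable pairs ranked correctly by risk.

    A pair (i, j) is comparable when i has an observed event and
    durations[i] < durations[j]; ties in risk count as half-concordant.
    """
    durations = np.asarray(durations, dtype=float)
    events = np.asarray(events, dtype=bool)
    risk = np.asarray(risk, dtype=float)
    concordant, comparable = 0.0, 0
    for i in np.flatnonzero(events):
        later = durations > durations[i]  # subjects still at risk after i's event
        comparable += int(later.sum())
        concordant += (risk[i] > risk[later]).sum() + 0.5 * (risk[i] == risk[later]).sum()
    return concordant / comparable if comparable else float("nan")


def grid_search(fit_and_score, param_grid):
    """Score every combination in param_grid; return (best_params, best_score, results)."""
    names = list(param_grid)
    results = []
    for values in itertools.product(*(param_grid[n] for n in names)):
        params = dict(zip(names, values))
        results.append((params, fit_and_score(**params)))
    best_params, best_score = max(results, key=lambda r: r[1])
    return best_params, best_score, results


# Synthetic data: higher true risk means shorter expected survival.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
true_risk = X @ np.array([1.0, -0.5, 0.0])
durations = rng.exponential(np.exp(-true_risk))
events = rng.random(200) < 0.7
train, val = np.arange(150), np.arange(150, 200)


def fit_and_score(alpha):
    # Toy model: ridge regression on the training split, C-index on validation.
    Xt, yt = X[train], -np.log(durations[train])
    w = np.linalg.solve(Xt.T @ Xt + alpha * np.eye(X.shape[1]), Xt.T @ yt)
    return concordance_index(durations[val], events[val], X[val] @ w)


best_params, best_score, _ = grid_search(fit_and_score, {"alpha": [0.1, 1.0, 10.0]})
print(best_params, round(best_score, 3))
```

For random search, the same `fit_and_score` works unchanged; only the `itertools.product` loop is replaced by sampling a fixed number of parameter settings.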

You can alternatively look at the DeepSurv code by the original author https://github.com/jaredleekatzman/DeepSurv to see if what you're looking for is there.


havakv commented on August 23, 2024

We should probably add an example of how to approach hyperparameter tuning for the various models, or at least a notebook that shows a general approach.

Integration with scikit-learn pipelines would be ideal if it is possible; if not, we should at a minimum show an example approach.


GCBallesteros commented on August 23, 2024

I have implemented this by building a wrapper around DeepSurv so that it interfaces nicely with sklearn. It boils down to following the instructions in Developing scikit-learn estimators: you only need to implement the fit and score methods (inheriting from sklearn's BaseEstimator supplies get_params and set_params).

The only "tricky" part is that pycox models have targets consisting of multiple arrays, which sklearn won't accept. The workaround is to stack them into a matrix. For DeepSurv that would be a matrix of size (n_examples, 2), where the first column is the time-to-event and the second is the event indicator.
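The stacking workaround can be sketched as a pair of helpers. This is a minimal sketch: `stack_targets` and `unstack_targets` are hypothetical names, and the default `float32` dtypes are assumptions rather than something pycox is confirmed to require. Because `np.column_stack` promotes both columns to a single dtype, the event indicator comes back as a float; the unstacking helper casts each column back, which is where model-specific dtype requirements (for example, integer duration indices for a discrete-time model) would be handled.

```python
import numpy as np


def stack_targets(durations, events):
    """Pack (durations, events) into one (n_examples, 2) float matrix for sklearn."""
    return np.column_stack([np.asarray(durations, dtype=float),
                            np.asarray(events, dtype=float)])


def unstack_targets(y, duration_dtype="float32", event_dtype="float32"):
    """Split the stacked matrix back into a (durations, events) tuple.

    column_stack promoted both columns to one dtype, so cast each column
    back to what the model expects (the defaults here are assumptions).
    """
    y = np.asarray(y)
    durations = y[:, 0].astype(duration_dtype)
    events = y[:, 1].astype(event_dtype)
    return durations, events


y = stack_targets([5.0, 3.5, 8.25], [1, 0, 1])
fold = y[[0, 2]]  # row indexing, as done by sklearn splitters, keeps each pair together
durations, events = unstack_targets(fold, event_dtype="bool")
print(durations, events)
```

A wrapper's `fit(X, y)` would call `unstack_targets(y)` first and pass the resulting tuple to the underlying model.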


havakv commented on August 23, 2024

Nice @GCBallesteros! I hope you will be able to share this with us!

Concerning the targets, we sometimes need specific data types, which might cause some issues when simply stacking them into a matrix, but we can maybe cross that bridge when we get there.

Alternatively we could take the same approach as scikit-survival and represent the targets with structured arrays (like this). As scikit-survival works very well with scikit-learn, they probably have put some work into this.
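A structured-array target can be sketched with plain NumPy. scikit-survival provides `sksurv.util.Surv.from_arrays` for building these; the version below avoids that dependency, and the field names `event` and `time` plus the hypothetical helper `to_model_targets` are assumptions for illustration. Unlike the stacked matrix, each field keeps its own dtype, and integer indexing by a cross-validation splitter keeps every record intact.

```python
import numpy as np

durations = np.array([5.0, 3.5, 8.25])
events = np.array([True, False, True])

# One record per subject: a boolean event field and a float time field.
y = np.empty(len(durations), dtype=[("event", "?"), ("time", "<f8")])
y["event"] = events
y["time"] = durations

fold = y[np.array([0, 2])]  # integer indexing keeps each (event, time) record whole


def to_model_targets(y_struct):
    """Hypothetical helper: convert structured targets to a (durations, events) tuple."""
    return y_struct["time"].astype("float32"), y_struct["event"].astype("float32")


print(fold)
print(to_model_targets(fold))
```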

