Comments (12)

rohanshad commented on July 21, 2024

#66 Here you go ^

from pycox.

havakv commented on July 21, 2024

I think the Cox models will be a bit harder to make work (though CoxPH is likely the simplest). Currently, the non-parametric baseline hazards are computed inside the CoxPH class, so that computation probably needs to be factored out, similar to what I did for the logistic-hazard. But you are of course more than welcome to give it a go!
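
A rough idea of what factoring this out could look like: a standalone Breslow estimator that takes the network's log-risk outputs g(x) on the training set plus the training durations and events, and returns the cumulative baseline hazard H0(t). This is a sketch, not pycox code; the function name is hypothetical, and events sharing a time are pooled per unique time, as in the Breslow estimator.

```python
import numpy as np
import pandas as pd


def breslow_cumulative_baseline_hazard(log_risk, durations, events):
    """Hypothetical helper: Breslow estimate of the cumulative baseline hazard H0(t).

    log_risk: network outputs g(x_i) on the training set, shape (n,)
    durations: observed times, shape (n,)
    events: 1 for an event, 0 for censoring, shape (n,)
    Returns a pandas Series indexed by the unique observed times (ascending).
    """
    df = pd.DataFrame({
        "expg": np.exp(np.asarray(log_risk, dtype=float).reshape(-1)),
        "duration": np.asarray(durations, dtype=float).reshape(-1),
        "event": np.asarray(events, dtype=float).reshape(-1),
    })
    # Sum exp(g) and count events at each unique time (groupby sorts ascending).
    grouped = df.groupby("duration")[["expg", "event"]].sum()
    # Risk set at time t = everyone with duration >= t, so accumulate from the end.
    risk_set = grouped["expg"][::-1].cumsum()[::-1]
    baseline_hazard = grouped["event"] / risk_set
    return baseline_hazard.cumsum().rename("cum_baseline_hazard")
```

Keeping this outside the model class means any training loop (vanilla PyTorch or Lightning) only has to collect g(x) on the training data once after fitting.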

Right now I have too much to do between work and revisions, so I can't prioritise this (I imagine it's quite a lot of work), but I'll get started as soon as I have the time.

havakv commented on July 21, 2024

Thank you for the kind words!
I think pycox is way too dependent on torchtuples (which is quite limited), and this has been discussed in #25.
Pycox could really benefit from working with pytorch-lightning, but I don't think there needs to be any special integration, just a way to decouple pycox from torchtuples. One can then give examples of how to fit models with pytorch-lightning.

In relation to #25, I made this example in this branch where I propose a change to the LogisticHazard model so that it can be fitted with just vanilla PyTorch (no torchtuples stuff). Could you take a look at that and see if you're able to make it work with pytorch-lightning?
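
For comparison, a model like LogisticHazard can be fitted with nothing but a loss function and a standard training loop. The sketch below writes the discrete-time logistic-hazard negative log-likelihood by hand and trains on synthetic data; the names `logistic_hazard_nll` and `fit` are hypothetical and are not the API proposed in the branch.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def logistic_hazard_nll(phi, idx_durations, events):
    """Mean negative log-likelihood of a discrete-time logistic-hazard model.

    phi: hazard logits per interval, shape (n, m)
    idx_durations: interval index of each observed time (int64), shape (n,)
    events: 1.0 for an event, 0.0 for censoring, shape (n,)
    """
    idx = idx_durations.view(-1, 1)
    # Target is 1 only in the interval where an event was observed, 0 elsewhere.
    target = torch.zeros_like(phi).scatter(1, idx, events.view(-1, 1).to(phi.dtype))
    bce = F.binary_cross_entropy_with_logits(phi, target, reduction="none")
    # Sum the per-interval terms up to and including each individual's own interval.
    return bce.cumsum(1).gather(1, idx).mean()


def fit(net, loader, epochs=50, lr=1e-2):
    """Plain PyTorch training loop; returns the mean training loss of each epoch."""
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    history = []
    for _ in range(epochs):
        net.train()
        total, count = 0.0, 0
        for x, idx_durations, events in loader:
            optimizer.zero_grad()
            loss = logistic_hazard_nll(net(x), idx_durations, events)
            loss.backward()
            optimizer.step()
            total += loss.item() * len(x)
            count += len(x)
        history.append(total / count)
    return history


# Usage on synthetic data with 10 discrete time intervals.
torch.manual_seed(0)
n, num_features, num_durations = 256, 5, 10
x = torch.randn(n, num_features)
idx_durations = torch.randint(0, num_durations, (n,))
events = torch.randint(0, 2, (n,)).float()
loader = DataLoader(TensorDataset(x, idx_durations, events), batch_size=64, shuffle=True)
net = nn.Sequential(nn.Linear(num_features, 32), nn.ReLU(), nn.Linear(32, num_durations))
history = fit(net, loader)
```

Because the loss is an ordinary function of tensors, the same `logistic_hazard_nll` could be called from a LightningModule's `training_step` without changes.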

It would really help to get some feedback on these changes before I start refactoring all the models :)

rohanshad commented on July 21, 2024

Perfect!

I'll take a crack at this soon (hopefully within this week) and circle back. Thanks 👍

rohanshad commented on July 21, 2024

Got it all working in a new conda environment, and I've successfully ported the example to PyTorch Lightning. I set up the dataset within a LightningDataModule and packaged the preprocessing functions there too. Since each model may require slightly different preprocessing steps, it might make sense to define all those preprocessing functions within the data module itself. I set up the train / test split there as well.
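
A rough outline of that split, with the model-specific preprocessing living next to the data, might look like the sketch below. It uses plain PyTorch so it runs without Lightning; in practice the class would subclass `pl.LightningDataModule`, which uses the same `setup` / `train_dataloader` / `val_dataloader` hook names. The equidistant discretization is a simplified stand-in for pycox's label transforms, and all class and function names here are hypothetical.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def fit_cuts(durations, num_durations):
    """Equidistant grid from 0 to the largest training duration (simplified label transform)."""
    return np.linspace(0.0, float(np.max(durations)), num_durations)


def to_interval_index(durations, cuts):
    """Map each duration to the index of the first cut >= it (the right edge of its interval)."""
    idx = np.searchsorted(cuts, durations, side="left")
    # Durations beyond the last training cut are clipped into the final interval.
    return np.clip(idx, 0, len(cuts) - 1).astype("int64")


class SurvDataModule:
    """Hypothetical data module that owns the model-specific preprocessing."""

    def __init__(self, x, durations, events, num_durations=10, val_frac=0.2,
                 batch_size=64, seed=0):
        self.x = np.asarray(x, dtype="float32")
        self.durations = np.asarray(durations, dtype="float64")
        self.events = np.asarray(events, dtype="float32")
        self.num_durations = num_durations
        self.val_frac = val_frac
        self.batch_size = batch_size
        self.seed = seed

    def setup(self, stage=None):
        rng = np.random.default_rng(self.seed)
        perm = rng.permutation(len(self.x))
        n_val = int(len(self.x) * self.val_frac)
        val_rows, train_rows = perm[:n_val], perm[n_val:]
        # Fit the label transform on the training split only, then reuse it for validation.
        self.cuts = fit_cuts(self.durations[train_rows], self.num_durations)
        self.train_ds = self._dataset(train_rows)
        self.val_ds = self._dataset(val_rows)

    def _dataset(self, rows):
        return TensorDataset(
            torch.from_numpy(self.x[rows]),
            torch.from_numpy(to_interval_index(self.durations[rows], self.cuts)),
            torch.from_numpy(self.events[rows]),
        )

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)
```

Swapping the two discretization helpers for a different model's transform (for example one that also returns an interval fraction) would leave the rest of the module unchanged.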

The model, training logic, metrics, loss and optimizer all fit in a surv_model LightningModule. The trainer trains it all, shows a progress bar, tracks experiment versions, and can dump logs to CSV / TensorBoard as required:

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type                 | Params
---------------------------------------------------
0 | net       | Sequential           | 1 K   
1 | loss_func | NLLLogistiHazardLoss | 0     
Epoch 19: 100%|███████| 6/6 [00:01<00:00,  4.23it/s, loss=2.269, v_num=34, loss_step=2.14, loss_epoch=2.26]
Running in Evaluation Mode...
Concordance: 0.6252826147316648

The only thing I keep in vanilla PyTorch is the testing phase, since the metrics are calculated directly on a pandas DataFrame, which removes the need for a DataLoader. Let me know if you'd like me to open a PR on that branch so you can see what this looks like.
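
A test phase like this needs only one forward pass over the test set. As a sketch, assuming the whole test set fits in memory and the net outputs logistic-hazard logits, the hypothetical helper below builds a survival DataFrame with the time grid as index and one column per individual, using the discrete-time formula S(t_j) = prod_{k<=j} (1 - h_k).

```python
import pandas as pd
import torch


@torch.no_grad()
def logistic_hazard_surv_df(net, x, cuts):
    """Hypothetical helper: survival curves for a logistic-hazard net, one column per row of x."""
    net.eval()
    hazard = torch.sigmoid(net(torch.as_tensor(x, dtype=torch.float32)))
    # Discrete-time survival: S(t_j) = prod_{k <= j} (1 - h_k).
    surv = torch.cumprod(1.0 - hazard, dim=1)
    # Transpose so rows are time points and columns are individuals.
    return pd.DataFrame(surv.cpu().numpy().T, index=cuts)
```

The resulting frame could then be passed to pycox's `EvalSurv` together with the original test durations and events, e.g. `EvalSurv(surv_df, durations_test, events_test, censor_surv='km')` as in the pycox example notebooks.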

havakv commented on July 21, 2024

Great work @rohanshad!
Sure, you can open a PR! It's much simpler to discuss when we have some concrete examples.

rohanshad commented on July 21, 2024

Let me know if you'd like to create an example for CoxPH; its workings and estimators seem to be a bit different from the logistic-hazard models. I can carry on from there and attempt to make a flexible-ish LightningModule that works with CoxPH too.
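
On the training side, the main difference is that CoxPH maximises the Cox partial likelihood, where each event is compared against everyone in the batch who is still at risk, instead of a per-individual likelihood over time intervals. A minimal torch sketch of that loss follows; `cox_ph_nll` is a hypothetical name, it averages over the events in the batch, and tied times are not handled exactly (their risk sets depend on the sort order).

```python
import torch


def cox_ph_nll(log_h, durations, events, eps=1e-7):
    """Negative Cox partial log-likelihood for one batch (sketch).

    log_h: network output g(x_i), shape (n,) or (n, 1)
    durations: observed times, shape (n,)
    events: 1.0 for an event, 0.0 for censoring, shape (n,)
    """
    log_h = log_h.view(-1)
    events = events.view(-1).float()
    # Sort by descending duration so each prefix is the risk set {j : T_j >= T_i}.
    order = torch.argsort(durations.view(-1), descending=True)
    log_h, events = log_h[order], events[order]
    # log sum_{j in risk set} exp(g_j), computed stably.
    log_risk_set = torch.logcumsumexp(log_h, dim=0)
    # Only observed events contribute; eps avoids division by zero in all-censored batches.
    return -((log_h - log_risk_set) * events).sum() / (events.sum() + eps)
```

Since the risk sets are formed within each batch, larger batches give a closer approximation to the full-data partial likelihood.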

havakv commented on July 21, 2024

I guess this can stay open until all of pycox can be used with pytorch-lightning.

havakv commented on July 21, 2024

@rohanshad If you want to take a crack at CoxPH in pytorch-lightning, I've now made a refactored version of CoxPH in https://github.com/havakv/pycox/blob/refactor_out_torchtuples/pycox/models/coxph.py that should be straightforward to use. Some docs and tests are still missing, but I'll add those later.

By using compute_cumulative_baseline_hazards and output2surv, I think it shouldn't be too much work. Let me know if you run into any issues!
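
Conceptually, the two steps amount to estimating H0(t) on the training data and then computing S(t | x) = exp(-H0(t) exp(g(x))) for new inputs. The sketch below covers only the second step with a hypothetical helper; it does not reproduce the actual signatures of `compute_cumulative_baseline_hazards` or `output2surv` in the branch.

```python
import numpy as np
import pandas as pd


def cox_surv_df(log_risk, cum_baseline_hazard):
    """Hypothetical helper: S(t | x) = exp(-H0(t) * exp(g(x))).

    log_risk: network outputs g(x) for the individuals to predict, shape (n,)
    cum_baseline_hazard: pandas Series H0(t) indexed by time
    Returns a DataFrame with times as the index and one column per individual.
    """
    risk = np.exp(np.asarray(log_risk, dtype=float).reshape(-1))
    # Outer product: row = time point, column = individual.
    cum_hazard = np.outer(cum_baseline_hazard.to_numpy(dtype=float), risk)
    return pd.DataFrame(np.exp(-cum_hazard), index=cum_baseline_hazard.index)
```

The output has the same layout (times as rows, individuals as columns) that `EvalSurv` expects for its survival input.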

yorickvanzweeden commented on July 21, 2024

@havakv Thank you for the refactorings of CoxPH and LogisticHazard. Do you have any plans to refactor PC-Hazard? Or should I be able to use it like the other two?

This is what I'm currently doing. However, compared with LogisticHazard and CoxPH, I'm not getting great performance.

import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn.functional as F

from pycox.evaluation import EvalSurv
from pycox.models.loss import nll_pc_hazard_loss
from pycox.models.utils import make_subgrid, pad_col


class DummyModel(pl.LightningModule):
    def __init__(self, net, lr=1e-3, duration_index=None, sub=10):
        super().__init__()
        self.net = net  # any nn.Module with one output per time interval
        self.lr = lr
        self.sub = sub  # sub-steps per interval when predicting survival
        self.loss_func = nll_pc_hazard_loss
        self.duration_index = duration_index  # e.g. labtrans.cuts

    def forward(self, x):
        return self.net(x)

    def common_step(self, batch, batch_idx, stage):
        # (input, idx_durations, events, interval_frac), as expected by nll_pc_hazard_loss
        x, duration, event, interval = batch
        preds = self(x)
        loss = self.loss_func(preds, duration, event, interval)

        if stage == "train":
            return {"loss": loss}
        return {"loss": loss, "preds": preds, "event": event, "duration": duration}

    def training_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, "train")

    # The *_epoch_end hooks below are Lightning 1.x API (removed in Lightning 2.0).
    def training_epoch_end(self, outs):
        mean_loss = torch.stack([x["loss"] for x in outs]).mean()
        self.logger.experiment.add_scalar("loss/train", mean_loss, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, "val")

    def validation_epoch_end(self, outs):
        mean_loss = torch.stack([x["loss"] for x in outs]).mean()
        self.logger.experiment.add_scalar("loss/val", mean_loss, self.current_epoch)

        # torch.cat handles a smaller last batch; vstack on 1-D tensors of unequal length fails.
        predictions = torch.cat([x["preds"] for x in outs])
        durations = torch.cat([x["duration"].view(-1) for x in outs])
        events = torch.cat([x["event"].view(-1) for x in outs])

        surv_df = self.predict_surv_df(predictions, sub=self.sub, duration_index=self.duration_index)
        # EvalSurv compares durations against surv_df's time index. If `duration` holds the
        # interval indices from the label transform, pass the original times here instead.
        ev = EvalSurv(surv_df, durations.cpu().numpy(), events.cpu().numpy())
        self.logger.experiment.add_scalar("val_auroc", ev.concordance_td(), self.current_epoch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def predict_surv_df(self, preds, sub, duration_index):
        # Piecewise-constant hazard -> survival, with each interval split into `sub` steps.
        n = preds.shape[0]
        hazard = F.softplus(preds).view(-1, 1).repeat(1, sub).view(n, -1).div(sub)
        hazard = pad_col(hazard, where="start")
        surv = hazard.cumsum(1).mul(-1).exp()
        surv = surv.detach().cpu().numpy()

        index = None
        if duration_index is not None:
            index = make_subgrid(duration_index, sub)
        return pd.DataFrame(surv.transpose(), index)

The DataLoader follows the PC-Hazard notebook, so duration_index corresponds to labtrans.cuts, where labtrans = PCHazard.label_transform(num_durations) has been fitted on the training targets.
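
Under the PC-Hazard model each interval has a constant hazard rate h_k = softplus(phi_k), so an individual observed in interval j, a fraction f into that interval and with event indicator d, contributes -(d log h_j - f h_j - sum_{k<j} h_k) to the loss. The sketch below writes that formula out in torch as a way to sanity-check the batch layout (idx_durations, events, interval_frac). It is derived from the model definition rather than copied from `nll_pc_hazard_loss`, assumes 0-based interval indices and unit-width intervals in index space, and `pc_hazard_nll` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def pc_hazard_nll(phi, idx_durations, events, interval_frac):
    """Mean negative log-likelihood under a piecewise-constant hazard (sketch).

    phi: raw network outputs, shape (n, m); hazard rate h = softplus(phi)
    idx_durations: 0-based interval index of each observed time (int64), shape (n,)
    events: 1.0 for an event, 0.0 for censoring, shape (n,)
    interval_frac: fraction of interval idx_durations that was survived, shape (n,)
    """
    idx = idx_durations.view(-1, 1)
    haz = F.softplus(phi)
    h_j = haz.gather(1, idx).view(-1)
    # Cumulative hazard of all completed intervals before j: sum_{k<j} h_k.
    past = torch.cumsum(haz, dim=1).gather(1, idx).view(-1) - h_j
    # Clamp keeps the log finite if softplus underflows to 0.
    log_h_j = torch.log(h_j.clamp_min(1e-7))
    loglik = events.view(-1) * log_h_j - interval_frac.view(-1) * h_j - past
    return -loglik.mean()
```

If `interval_frac` or `idx_durations` are off by one relative to what the loss expects, the survival curves will be shifted by an interval, which is one thing worth ruling out when performance is unexpectedly poor.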

havakv commented on July 21, 2024

Hi @yorickvanzweeden.
I should refactor PCHazard in the same way as for LogisticHazard (just struggling to find the time).

From what I can see, your code should work, so I don't really know why you're not getting the results you want. Have you tried comparing these results with

model = PCHazard(net, optimizer, duration_index=labtrans.cuts)
model.fit(...)
model.predict_surv_df(...)

to check whether it is just PCHazard that doesn't perform well, or whether there is something wrong with your implementation?

yorickvanzweeden commented on July 21, 2024

Thanks @havakv for your reply. I suspect it is due to the difficulty of the problem in combination with hyperparameters that have yet to be optimized.
