#66 Here you go ^
I think the Cox models will be a bit harder to make work (though CoxPH is likely the simplest). Currently, the computation of the non-parametric baseline hazards is part of the CoxPH class, so it probably needs to be factored out in a similar way to what I did for LogisticHazard. But you are of course more than welcome to give it a go!
Right now I have too much to do between work and revisions, so I can't prioritise this (I imagine it's quite some work), but I'll get started as soon as I have the time.
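To illustrate what factoring the baseline-hazard computation out of the class could look like, here is a standalone Breslow estimator of the cumulative baseline hazard H0(t): at each observed time, the number of events is divided by the sum of exp(g(x)) over everyone still at risk. This is a minimal NumPy sketch; the function name and signature are hypothetical and not pycox's API.

```python
import numpy as np


def breslow_cumulative_baseline_hazard(log_h, durations, events):
    """Breslow estimate of the cumulative baseline hazard H0(t).

    log_h: model outputs g(x) (log partial hazards), shape (n,)
    durations, events: observed times and event indicators (1 = event), shape (n,)
    Returns (times, H0) evaluated at every unique observed duration.
    """
    log_h = np.asarray(log_h, dtype=float)
    durations = np.asarray(durations, dtype=float)
    events = np.asarray(events, dtype=float)
    times = np.unique(durations)  # sorted ascending
    exp_g = np.exp(log_h)
    # Number of events at each time t
    d = np.array([events[durations == t].sum() for t in times])
    # Sum of exp(g(x)) over the risk set {j : T_j >= t}; never empty for an observed t
    risk = np.array([exp_g[durations >= t].sum() for t in times])
    return times, np.cumsum(d / risk)
```

Because it only takes outputs, durations and events, a function like this can be called after training regardless of which framework (torchtuples, plain PyTorch or Lightning) fitted the network.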
Thank you for the kind words!
I think pycox is way too dependent on torchtuples (which is quite limited), and this has been discussed in #25.
Pycox could really benefit from working with pytorch-lightning, but I don't think it needs any special integration, just a way to decouple pycox from torchtuples. One can then give examples of how to fit models with pytorch-lightning.
In relation to #25, I made this example in this branch, where I propose a change to the LogisticHazard model so that it can be fitted with just vanilla PyTorch (no torchtuples). Could you take a look at it and see if you're able to make it work with pytorch-lightning?
Some feedback on these changes would really help before I start refactoring all the models :)
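For context, once the loss is a standalone function, a logistic-hazard model can be trained with nothing but plain PyTorch. The sketch below writes the discrete-time logistic-hazard negative log-likelihood as binary cross-entropy over the intervals up to the observed one, and uses an ordinary training loop. The function names, the (x, idx_durations, events) batch layout and the hyperparameters are assumptions for illustration, not the code from the branch.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def nll_logistic_hazard(phi, idx_durations, events):
    """Mean negative log-likelihood of a discrete-time logistic-hazard model.

    phi: hazard logits per interval, shape (batch, n_intervals)
    idx_durations: interval index of each observed duration (int64), shape (batch,)
    events: 1.0 for an observed event, 0.0 for censoring, shape (batch,)
    """
    idx = idx_durations.view(-1, 1)
    # Target 0 for every interval survived; `events` in the interval of the observed duration
    y = torch.zeros_like(phi).scatter(1, idx, events.view(-1, 1).to(phi.dtype))
    bce = F.binary_cross_entropy_with_logits(phi, y, reduction="none")
    # Sum the per-interval terms up to and including the observed interval
    return bce.cumsum(1).gather(1, idx).view(-1).mean()


def fit(net, loader, epochs=5, lr=1e-3):
    """Plain PyTorch training loop: no torchtuples, no Lightning."""
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    net.train()
    for _ in range(epochs):
        for x, idx_durations, events in loader:
            optimizer.zero_grad()
            loss = nll_logistic_hazard(net(x), idx_durations, events)
            loss.backward()
            optimizer.step()
    return net


# Toy usage with synthetic data: 3 features, 10 discrete time intervals
torch.manual_seed(0)
x = torch.randn(64, 3)
idx_durations = torch.randint(0, 10, (64,))
events = (torch.rand(64) < 0.7).float()
loader = DataLoader(TensorDataset(x, idx_durations, events), batch_size=16, shuffle=True)
net = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 10))
fit(net, loader, epochs=2)
```

The same loss function can be dropped unchanged into a LightningModule's training_step, which is the point of decoupling it from torchtuples.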
Perfect!
I'll take a crack at this soon (hopefully within this week) and circle back. Thanks!
Got it all working in a new conda environment, and I've successfully ported the example to PyTorch Lightning. I set up the dataset within a Lightning DataModule and packaged the preprocessing functions there too. Since each model may require slightly different preprocessing steps, it might make sense to define all of those preprocessing functions within the data module itself. I set up the train/test split there as well.
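A framework-agnostic sketch of that structure: the model-specific preprocessing is passed in as a label transform, and the class builds the train/validation split and the DataLoaders. In Lightning this would subclass pl.LightningDataModule, whose setup method also receives a stage argument. The class name, the label_transform callable, the random split and the equal-width discretization are all assumptions, not the code from that branch.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


class SurvDataModule:
    """Hypothetical stand-in for a LightningDataModule.

    `label_transform` is a callable mapping (durations, events) to a tuple of
    target arrays, so each model can plug in its own preprocessing.
    """

    def __init__(self, x, durations, events, label_transform, batch_size=32, val_frac=0.2, seed=0):
        self.x = np.asarray(x, dtype="float32")
        self.durations = np.asarray(durations, dtype="float64")
        self.events = np.asarray(events, dtype="float32")
        self.label_transform = label_transform
        self.batch_size = batch_size
        self.val_frac = val_frac
        self.seed = seed

    def setup(self):
        # Random train/validation split
        perm = np.random.default_rng(self.seed).permutation(len(self.x))
        n_val = int(len(self.x) * self.val_frac)
        self.val_ds = self._make_dataset(perm[:n_val])
        self.train_ds = self._make_dataset(perm[n_val:])

    def _make_dataset(self, idx):
        # Apply the model-specific label transform, then wrap features and targets as tensors
        targets = self.label_transform(self.durations[idx], self.events[idx])
        tensors = [torch.from_numpy(self.x[idx])] + [torch.as_tensor(t) for t in targets]
        return TensorDataset(*tensors)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)


# Example label transform for a discrete-time model: 10 equal-width intervals on [0, 10]
cuts = np.linspace(0.0, 10.0, 11)


def discretize(durations, events):
    idx = np.digitize(durations, cuts[1:-1]).astype("int64")
    return idx, events.astype("float32")


rng = np.random.default_rng(1)
dm = SurvDataModule(rng.normal(size=(100, 3)), rng.uniform(0, 10, 100), rng.integers(0, 2, 100), discretize)
dm.setup()
```

Swapping in a different transform (for example one that also returns interval fractions for PC-Hazard) changes the batch layout without touching the split or loader code.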
The model, training logic, metrics, loss and optimizer all live in a surv_model LightningModule. The Trainer trains it all, shows a progress bar, tracks experiment versions, and can dump logs to CSV/TensorBoard as required:
```
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
  | Name      | Type                 | Params
---------------------------------------------------
0 | net       | Sequential           | 1 K
1 | loss_func | NLLLogistiHazardLoss | 0
Epoch 19: 100%|███████| 6/6 [00:01<00:00, 4.23it/s, loss=2.269, v_num=34, loss_step=2.14, loss_epoch=2.26]
Running in Evaluation Mode...
Concordance: 0.6252826147316648
```
The only part I kept in vanilla PyTorch is the testing phase, since the metrics are calculated directly on a pandas DataFrame, which removes the need for a DataLoader. Let me know if you'd like me to open a PR on that branch so you can see what this looks like.
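Since predictions and labels are plain arrays at test time, a metric can be computed without any DataLoader. As an illustration, here is Harrell's C-index in NumPy. Note that this is a simpler metric than the time-dependent concordance computed by pycox's EvalSurv.concordance_td: it takes one risk score per individual rather than full survival curves. The function name is hypothetical.

```python
import numpy as np


def harrell_c_index(risk, durations, events):
    """Harrell's concordance index.

    A pair (i, j) is comparable when individual i had an event and
    durations[i] < durations[j]. The pair is concordant when i also has the
    higher predicted risk; ties in risk count as 0.5.
    """
    risk = np.asarray(risk, dtype=float)
    durations = np.asarray(durations, dtype=float)
    events = np.asarray(events, dtype=bool)
    concordant, comparable = 0.0, 0
    for i in np.flatnonzero(events):
        later = durations > durations[i]  # individuals still event-free after T_i
        comparable += later.sum()
        concordant += (risk[i] > risk[later]).sum() + 0.5 * (risk[i] == risk[later]).sum()
    return concordant / comparable  # undefined if there are no comparable pairs
```

Higher risk should mean earlier failure, so a perfectly ranked model scores 1.0 and a random one about 0.5.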
Great work @rohanshad!
Sure, you can open a PR! It's much simpler to discuss when we have some concrete examples.
Let me know if you'd like to create an example for CoxPH; its workings and estimators seem to be a bit different from the LogisticHazard model. I can carry on from there and attempt to make a flexible-ish Lightning module that works with CoxPH too.
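One concrete way CoxPH differs from the discrete-time models: it is trained with the partial likelihood, where each event's term depends on everyone in the batch who is still at risk, so the loss cannot be evaluated sample by sample. A minimal PyTorch sketch, assuming one scalar output per individual; the function name is hypothetical, and like a simple cumulative-sum implementation it handles tied durations only approximately.

```python
import torch


def cox_ph_loss(log_h, durations, events, eps=1e-7):
    """Negative Cox partial log-likelihood, averaged over the events in the batch.

    log_h: network outputs g(x), shape (batch,) or (batch, 1)
    durations, events: observed times and event indicators, shape (batch,)
    """
    # Sort by descending duration so the running sum over the first k rows covers
    # the risk set {j : T_j >= T_k} (exact only when there are no tied durations)
    order = torch.argsort(durations.view(-1), descending=True)
    log_h = log_h.view(-1)[order]
    events = events.view(-1)[order].float()
    log_risk = torch.logcumsumexp(log_h, dim=0)  # log sum_{j in risk set} exp(g(x_j))
    return -((log_h - log_risk) * events).sum() / (events.sum() + eps)
```

Because the risk sets are formed within a batch, the batch size matters more here than for the logistic-hazard loss.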
I guess this can stay open until all of pycox can be used with pytorch-lightning.
@rohanshad If you want to take a crack at CoxPH in pytorch-lightning, I've now made a refactored version of CoxPH in https://github.com/havakv/pycox/blob/refactor_out_torchtuples/pycox/models/coxph.py that should be straightforward to use. Some docs and tests are still missing, but I'll add them later.
By using compute_cumulative_baseline_hazards and output2surv, I think it shouldn't be too much work. Let me know if you run into any issues!
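The last step of that pipeline is the standard Cox relation S(t | x) = exp(-H0(t) exp(g(x))). The NumPy/pandas sketch below shows it given an already-computed cumulative baseline hazard on a time grid. It is a hypothetical stand-in and does not reproduce the signatures of compute_cumulative_baseline_hazards or output2surv on the refactor branch.

```python
import numpy as np
import pandas as pd


def cox_output_to_surv_df(log_h, times, cum_baseline_hazard):
    """Survival curves S(t | x) = exp(-H0(t) * exp(g(x))).

    log_h: outputs g(x) for n individuals, shape (n,)
    times, cum_baseline_hazard: H0 evaluated on a time grid, shape (m,)
    Returns a DataFrame with one row per time and one column per individual,
    the layout used by pycox's EvalSurv.
    """
    log_h = np.asarray(log_h, dtype=float)
    H0 = np.asarray(cum_baseline_hazard, dtype=float)
    surv = np.exp(-np.outer(H0, np.exp(log_h)))  # shape (m, n)
    return pd.DataFrame(surv, index=np.asarray(times, dtype=float))
```

Since the baseline hazard is estimated once on training data, this conversion is cheap to run on each validation or test batch.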
@havakv Thank you for the refactorings of CoxPH and LogisticHazard. Do you have any plans to refactor PC-Hazard? Or should I be able to use it like the other two?
I am currently doing this, but compared with LogisticHazard and CoxPH, I am not getting great performance.
```python
import pytorch_lightning as pl
import pandas as pd
import torch
import torch.nn.functional as F
from pycox.models.loss import nll_pc_hazard_loss
from pycox.evaluation import EvalSurv
from pycox.models.utils import pad_col, make_subgrid


class DummyModel(pl.LightningModule):
    # Uses the pre-2.0 Lightning hooks (training_epoch_end / validation_epoch_end);
    # Lightning 2.0 removed them in favor of on_train_epoch_end / on_validation_epoch_end.
    def __init__(self, net, lr, duration_index=None):
        super().__init__()
        self.net = net  # any nn.Module with one output per interval
        self.lr = lr
        self.loss_func = nll_pc_hazard_loss
        self.duration_index = duration_index  # cuts from the PC-Hazard label transform

    def forward(self, x):
        return self.net(x)

    def common_step(self, batch, batch_idx, stage):
        # (features, interval index, event indicator, fraction of interval survived)
        x, duration, event, interval = batch
        preds = self(x)
        loss = self.loss_func(preds, duration, event, interval)
        if stage == "train":
            return {"loss": loss}
        return {"loss": loss, "preds": preds, "event": event, "duration": duration}

    def training_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, "train")

    def training_epoch_end(self, outs):
        # Assumes a TensorBoardLogger, whose .experiment is a SummaryWriter
        mean_loss = torch.stack([x["loss"] for x in outs]).mean()
        self.logger.experiment.add_scalar("loss/train", mean_loss, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, "val")

    def validation_epoch_end(self, outs):
        mean_loss = torch.stack([x["loss"] for x in outs]).mean()
        self.logger.experiment.add_scalar("loss/val", mean_loss, self.current_epoch)
        # torch.cat rather than vstack: durations/events are 1-D and the last batch may be smaller
        predictions = torch.cat([x["preds"] for x in outs])
        durations = torch.cat([x["duration"] for x in outs])
        events = torch.cat([x["event"] for x in outs])
        surv_df = self.predict_surv_df(predictions, sub=10, duration_index=self.duration_index)
        # EvalSurv compares durations against the time grid in surv_df's index, so they
        # should be on that time scale; if `duration` holds interval indices from the
        # label transform, the raw durations are needed here instead.
        ev = EvalSurv(surv_df, durations.cpu().numpy().reshape(-1), events.cpu().numpy().reshape(-1))
        # Logs the time-dependent concordance (despite the tag name)
        self.logger.experiment.add_scalar("val_auroc", ev.concordance_td(), self.current_epoch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def predict_surv_df(self, preds, sub, duration_index):
        # Split each interval's hazard evenly over `sub` sub-steps, then S(t) = exp(-cumulative hazard)
        n = preds.shape[0]
        hazard = F.softplus(preds).view(-1, 1).repeat(1, sub).view(n, -1).div(sub)
        hazard = pad_col(hazard, where="start")
        surv = hazard.cumsum(1).mul(-1).exp()
        surv = surv.detach().cpu().numpy()
        index = None
        if duration_index is not None:
            index = make_subgrid(duration_index, sub)
        return pd.DataFrame(surv.transpose(), index)
```
The DataLoader is aligned with the PC-Hazard notebook, so the duration_index corresponds to PCHazard.label_transform(num_durations).cuts.
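For readers less familiar with the PC-Hazard targets that the loss consumes: each duration is mapped to the interval between two consecutive cuts that contains it, together with the fraction of that interval that was survived. A simplified NumPy sketch, assuming sorted cuts that cover all durations and one network output per interval; pycox's own label transform handles more edge cases and may assign boundary points differently, and the function name here is hypothetical.

```python
import numpy as np


def pc_hazard_labels(durations, events, cuts):
    """Simplified piecewise-constant-hazard targets.

    For each duration t, find the interval j with cuts[j] < t <= cuts[j + 1]
    and the fraction (t - cuts[j]) / (cuts[j + 1] - cuts[j]) of it survived.
    """
    durations = np.asarray(durations, dtype=float)
    cuts = np.asarray(cuts, dtype=float)
    idx = np.searchsorted(cuts, durations, side="left") - 1
    idx = np.clip(idx, 0, len(cuts) - 2)  # t == cuts[0] falls into the first interval
    interval_frac = (durations - cuts[idx]) / (cuts[idx + 1] - cuts[idx])
    return idx.astype("int64"), np.asarray(events, dtype="float32"), interval_frac.astype("float32")
```

Checking a few hand-computed labels like this against the batches coming out of the DataLoader is a cheap way to rule out a preprocessing mismatch before tuning hyperparameters.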
Hi @yorickvanzweeden.
I should refactor PCHazard in the same way as for LogisticHazard (just struggling to find the time).
From what I can see, your code should work, so I don't really know why you're not getting the results you want. Have you tried comparing these results with
```python
model = PCHazard(net, optimizer, duration_index=labtrans.cuts)
model.fit(...)
model.predict_surv_df(...)
```
to check whether it is PCHazard itself that doesn't perform well, or whether something is off in your implementation?
Thanks @havakv for your reply. I suspect it is due to the difficulty of the problem in combination with hyperparameters that have yet to be optimized.