Comments (13)
I'm using pycox for time-to-event prediction with medical imaging. Thank you for providing such a well-documented and easy-to-use library! I'm using PyTorch Lightning + the refactor_out_torchtuples branch. I wanted to add a vote for this issue; it would be great if the refactored implementation could eventually be merged into master.
@petteriTeikari in the refactor_out_torchtuples branch I have a minimal example of how to run the LogisticHazard model without torchtuples. However, you need to check out the branch to run the example, as it required some refactoring of the logistic_hazard code.
Does this work for you, or do you think there should be a different design?
Thank you for addressing this @annawoodard. It's easier to prioritize an issue when more people are interested in it. It takes some time to get it right, but I'll try to spend some time on this going forward. In #102 I'm working on refactoring the PMF, MTLR and DeepHit.
Thank you for the feedback and the kind words, and I completely agree with you.
We have in some sense sacrificed some of the ease of further researching the models for a (hopefully) simpler user experience. By using torchtuples, it is simpler to access callbacks such as learning rate schedulers and model checkpointing, while enabling the user to work with torch.tensors, numpy.arrays or lists/tuples of tensors or arrays.
Basing pycox on torchtuples makes it harder to understand and extend the models (especially CoxCC and CoxTime, which are quite convoluted).
torchtuples essentially started as a simple training loop and prediction function (as we were tired of rewriting them), but ended up being a quite significant part of pycox. That might have been a poor choice, but it made it easier for us to conduct larger experiments while researching the models.
While torchtuples' role in pycox is still, essentially, a training loop and a prediction function (the latter just turns off gradients and randomness), making torchtuples as flexible as it is has made the code harder to understand.
However, with this knowledge in mind, it should be possible to write training/prediction loops for the pycox models without relying on torchtuples (except for CoxCC and CoxTime, which are more dependent on torchtuples). All you need is the right loss function, found in pycox.models.loss (these are written as regular PyTorch losses, such as MSELoss). For survival predictions, one can either use the model class (e.g., LogisticHazard) with your pretrained nn.Module net, or copy the logic of the model's predict_surv. Here I agree that it could be beneficial to have a standalone predict_surv function outside the class that contains the prediction logic and relies only on the output tensor of the network.
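To make the loss concrete: if I read pycox.models.loss correctly, nll_logistic_hazard takes the network output phi (one logit per discrete time interval), integer interval indices and event indicators. It then sums a binary cross-entropy over every interval up to and including the observed one. The sketch below restates that likelihood in plain Python, for one individual and for a batch. The function names are hypothetical and this is not the pycox implementation.

```python
import math
from typing import Sequence


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy on a raw logit, in a numerically stable form.

    Equals -[y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))] for z = logit, y = target.
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def nll_logistic_hazard_single(phi: Sequence[float], idx_duration: int, event: int) -> float:
    """Logistic-hazard negative log-likelihood for one individual (hypothetical helper).

    phi[j] is the network's logit for discrete time interval j, idx_duration is the
    interval where the individual had the event or was censored, event is 1 or 0.
    """
    loss = 0.0
    for j in range(idx_duration + 1):
        # Survived every earlier interval (target 0); in the last observed
        # interval the target is the event indicator.
        target = float(event) if j == idx_duration else 0.0
        loss += bce_with_logits(phi[j], target)
    return loss


def nll_logistic_hazard_batch(phis, idx_durations, events) -> float:
    """Mean over a batch, i.e. the usual 'mean' reduction (hypothetical helper)."""
    losses = [nll_logistic_hazard_single(p, i, e)
              for p, i, e in zip(phis, idx_durations, events)]
    return sum(losses) / len(losses)
```

With all logits at zero every hazard is 0.5, so an event in interval 1 costs -log(0.5 * 0.5) = 2 log 2.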
So I think that by writing standalone prediction functions (not part of the classes), it should be straightforward to use the pycox models without any code from torchtuples. Maybe I should write an example notebook that illustrates this... Do you think that would suffice?
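A standalone prediction function along those lines can be very small, as the sketch below shows. It assumes the network outputs one logit per discrete time interval (the LogisticHazard parameterization) and computes S(t_j) as the product of (1 - h_k) for k <= j, where h_k = sigmoid(phi_k). It uses plain Python lists instead of tensors to stay framework-agnostic. `predict_surv` here is a hypothetical free function, not the pycox method.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    # Numerically stable logistic function for both signs of z
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def predict_surv(phi: Sequence[Sequence[float]]) -> List[List[float]]:
    """Survival curves from logistic-hazard logits, independent of any model class.

    phi has one row per individual and one logit per discrete time interval.
    """
    curves = []
    for row in phi:
        surv, s = [], 1.0
        for logit in row:
            # Probability of surviving interval k, given survival up to interval k
            s *= 1.0 - sigmoid(logit)
            surv.append(s)
        curves.append(surv)
    return curves
```

With tensors the same logic becomes a sigmoid followed by a cumulative product over the time dimension. A log-space cumulative sum is a common, numerically stabler alternative.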
I could potentially also create examples of CoxCC and CoxTime that do not depend on torchtuples, just for illustrative purposes, although that would require some more work.
torchtuples was never intended to limit the pycox models, but rather to simplify their usage. However, creating a simple-to-use API without restricting extensibility too much, while also keeping the code easy to understand, wasn't as easy as I thought...
I am still reading the code. Actually, many of the protected methods, like model.predict, already take care of dataloader input, and it is mostly just the highest level of the API that might need a few tweaks.
There aren't many fancy PyTorch-based survival analysis packages like yours out there, and I will be very excited when this library can be applied in many more pipelines. :)
Feel free to propose any suggested changes. All contributions are welcome :)
> So I think that by writing standalone prediction functions (not part of the classes), it should be straightforward to use the pycox models without any code from torchtuples. Maybe I should write an example notebook that illustrates this... Do you think that would suffice?
Thanks a lot for the nice repo @havakv. I was curious whether you have had any time for a non-torchtuples example, as torchtuples complicates things a bit when trying to integrate your approach into an existing codebase. I have been using MONAI (https://github.com/Project-MONAI/MONAI) for our volumetric medical images (e.g. the DRTOP), and wanted to see whether the example 04_mnist_dataloaders_cnn.ipynb would work for our data.
The basic mechanics work, and I get an initial model trained just fine. However, things like using MONAI dataloaders, MONAI augmentations, extra augmentation libraries, or accessing activations from different layers now require tweaking. MONAI itself does not really add any complexity to the "normal" PyTorch training logic, which has the following structure:
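```python
for epoch in epoch_idxs:
    # TRAIN
    for batch_data in train_loader:
        ...  # "training_script": the training step for one batch
    # VALIDATION
    for batch_data in val_loader:
        ...  # "val_script": the validation step for one batch
# TEST
for batch_data in test_loader:
    ...  # "test_script": predictions on the held-out test set
```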
I understand though if you are busy with other stuff :)
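For reference, here is the same epoch/train/validation/test structure applied to the logistic-hazard loss, as a dependency-free sketch. To keep it runnable without PyTorch, it makes strong simplifying assumptions. The "model" is intercept-only: one free logit per discrete time interval, with no covariates and no network. Gradients are derived by hand from dBCE/dz = sigmoid(z) - y instead of by autograd. `batches` is a hypothetical stand-in for a DataLoader. All names are illustrative, not pycox API.

```python
import math
import random
from typing import Iterator, List, Optional, Sequence, Tuple

Sample = Tuple[int, int]  # (index of the discrete time interval, event indicator 1/0)


def sigmoid(z: float) -> float:
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def bce_with_logits(z: float, y: float) -> float:
    # Stable binary cross-entropy on a logit z with target y
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


def batch_loss(phi: Sequence[float], batch: Sequence[Sample]) -> float:
    """Mean logistic-hazard NLL; all samples share the logits phi (no covariates)."""
    total = 0.0
    for idx, event in batch:
        for j in range(idx + 1):
            total += bce_with_logits(phi[j], float(event) if j == idx else 0.0)
    return total / len(batch)


def batch_grad(phi: Sequence[float], batch: Sequence[Sample]) -> List[float]:
    """Gradient of batch_loss w.r.t. phi, using dBCE/dz = sigmoid(z) - y."""
    grad = [0.0] * len(phi)
    for idx, event in batch:
        for j in range(idx + 1):
            grad[j] += sigmoid(phi[j]) - (float(event) if j == idx else 0.0)
    return [g / len(batch) for g in grad]


def batches(data: Sequence[Sample], batch_size: int,
            rng: Optional[random.Random] = None) -> Iterator[List[Sample]]:
    """Stand-in for a DataLoader: yields (optionally shuffled) mini-batches."""
    order = list(range(len(data)))
    if rng is not None:
        rng.shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [data[i] for i in order[start:start + batch_size]]


def fit(train, val, n_intervals, epochs=500, lr=2.0, batch_size=None, seed=0):
    rng = random.Random(seed)
    batch_size = batch_size or len(train)  # default: full-batch gradient descent
    phi = [0.0] * n_intervals  # "model parameters": one logit per interval
    val_history = []
    for epoch in range(epochs):
        # TRAIN
        for batch in batches(train, batch_size, rng):
            grad = batch_grad(phi, batch)
            phi = [p - lr * g for p, g in zip(phi, grad)]
        # VALIDATION (no parameter updates)
        val_history.append(batch_loss(phi, val))
    return phi, val_history


# Toy data: (interval index, event indicator)
train = [(0, 1), (0, 0), (1, 1), (1, 0), (1, 1), (2, 1), (2, 0), (2, 1)]
val = [(0, 0), (1, 1), (2, 1), (2, 0)]
test = [(1, 0), (2, 1)]

phi, val_history = fit(train, val, n_intervals=3)
# TEST
test_loss = batch_loss(phi, test)
hazards = [sigmoid(p) for p in phi]
```

With no covariates, the fitted hazards should approach the life-table estimates d_j/n_j (events in interval j divided by the number still at risk), here 1/8, 2/6 and 2/3. That makes a convenient sanity check before swapping in a real network and `loss.backward()`.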
Thank you for the kind words @petteriTeikari! I agree that torchtuples complicates the use of external frameworks, so pycox could really benefit from being usable with vanilla PyTorch training loops. I'm not sure how much work this would be, but it's probably time for me to just get started on it. It probably requires some refactoring of the individual models, so are there any models in particular that you're interested in?
For your problem in particular, would it suffice to create a version of 01_introduction.ipynb that works with a standard pytorch training loop?
@havakv Sounds good if you have the time for that. I was initially interested in the basic LogisticHazard(), and I pretty much followed the 04_mnist_dataloaders_cnn.ipynb example, as I have 3D images with survival labels.
I already have dense segmentation labels and classification labels for the same dataset. It would therefore be a lot easier to integrate your nice work on survival models and losses into that pipeline than the other way around. I assume a lot of other researchers are in a similar situation?
@havakv And, if it is possible to make a request with some urgency: is there an easy way to extract the CNN activations from the model? I mean the feature vector for each of my input volumes, which could then be used, for example, in a tree model together with other non-imaging features.
https://doi.org/10.1038/s41598-020-69106-8
@petteriTeikari I'm not sure I understand your question on the extraction of CNN activations. But if I understand you correctly and you just want the output of a different layer in the network, you can just add a new method to the network that computes it. The example below is from this tutorial, and I've added a method conv2_output that can be called to look at the output of the second convolution.
```python
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def conv2_output(self, x):
        # Added method: same first step as forward(), but returns the raw
        # output of the second convolution (before ReLU and pooling)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.conv2(x)
        return x
```
Does this answer your question?
@havakv Yes, I think that is more or less what I meant: https://discuss.pytorch.org/t/how-can-i-extract-intermediate-layer-output-from-loaded-cnn-model/77301. Once the model is trained, I would like to get the intermediate-layer activations instead of the final output. In your example net that would be the output of fc2, so that I get the 84-value vector for each volume in my dataset in addition to the actual survival curve output.
Thanks for the minimal example; I will take a closer look at it later :)
Yes, that's probably simpler, as you don't have to rewrite the forward function for each layer you want to investigate.