Comments (13)

annawoodard commented on August 23, 2024

I'm using pycox for time-to-event prediction with medical imaging. Thank you for providing such a well-documented and easy-to-use library! I'm using pytorch lightning + the refactor_out_torchtuples branch. I wanted to add a vote for this issue-- it would be great if the refactored implementation could eventually be merged into master.

from pycox.

havakv commented on August 23, 2024

@petteriTeikari In the branch refactor_out_torchtuples I have a minimal example of how to run the LogisticHazard model without torchtuples. However, you need to check out the branch to run the example, as it required some refactoring of the logistic_hazard code.

Does this work for you, or do you think there should be a different design?

havakv commented on August 23, 2024

Thank you for addressing this @annawoodard. It's easier to prioritize an issue when more people are interested in it. It takes some time to get it right, but I'll try to spend some time on this going forward. In #102 I'm working on refactoring the PMF, MTLR and DeepHit.

havakv commented on August 23, 2024

Thank you for the feedback and the kind words, and I completely agree with you.
We have in some sense sacrificed some of the ease of further researching the models for a (hopefully) simpler user experience. By using torchtuples, it is simpler to access callbacks such as learning rate schedulers and model checkpointing, while enabling the user to work with torch.tensors, numpy.arrays or lists/tuples of tensors or arrays.

Basing pycox on torchtuples makes it harder to understand and extend the models (especially CoxCC and CoxTime which are quite convoluted).
torchtuples essentially started as a simple training loop and prediction function (as we were tired of rewriting them), but ended up being quite a significant part of pycox. That might have been a poor choice, but it allowed us to more easily conduct larger experiments while researching the models.

While torchtuples' role in pycox is still, essentially, a training loop and a prediction function (which just turns off gradients and randomness), making torchtuples as flexible as it is has made the code harder to understand.
However, with this knowledge in mind, it should be possible to write training/prediction loops for the pycox models without relying on torchtuples (except for CoxCC and CoxTime, which are more dependent on torchtuples). What you need is just the right loss function from pycox.models.loss (these are written as regular pytorch losses, such as MSELoss). For survival predictions, one can either use the model class (e.g., LogisticHazard) with your pretrained nn.Module net, or copy the logic of the model's predict_surv. Here I agree that it could be beneficial to have a standalone function predict_surv outside the class that contains the prediction logic and relies only on the output tensor of the network.
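
As a rough illustration of the standalone prediction function described above, here is a minimal sketch for a discrete-time logistic-hazard network. The function name predict_surv_from_logits, the eps argument, and the stand-in nn.Linear network are hypothetical and not part of pycox; the sketch assumes each output column of the network is the logit of the conditional hazard for one time interval.

```python
import torch


def predict_surv_from_logits(phi: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Turn raw network outputs into discrete-time survival probabilities.

    phi has shape (batch, n_intervals); column j is assumed to hold the
    logit of the conditional hazard for interval j (hypothetical helper,
    not a pycox function).
    """
    hazard = torch.sigmoid(phi)
    # S(t_j) = prod_{k <= j} (1 - h_k), computed in log space for stability.
    return (1.0 - hazard).clamp_min(eps).log().cumsum(dim=1).exp()


torch.manual_seed(0)
net = torch.nn.Linear(5, 10)  # stand-in for a pretrained nn.Module
net.eval()                    # turn off randomness such as dropout
with torch.no_grad():         # turn off gradients
    surv = predict_surv_from_logits(net(torch.randn(4, 5)))
```

Each row of surv is a non-increasing survival curve over the 10 intervals. Because the function depends only on the output tensor, it can be called from any prediction loop.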

So I think that by writing standalone prediction functions (not part of the classes), it should be straightforward to use the pycox models without any code from torchtuples. Maybe I should write an example notebook that illustrates this... Do you think that would suffice?
I could potentially also create examples of CoxCC and CoxTime that do not depend on torchtuples, just for illustrative purposes, although that would require some more work.

torchtuples was never intended to limit the pycox models, but rather to simplify their usage. However, attempting to create a simple-to-use API without restricting too much of the extensibility, while keeping the code easy to understand, wasn't as easy as I thought...

CielAl commented on August 23, 2024

I am still reading the code. Actually, many of the protected methods like model.predict already take care of the dataloader-input case, so it is mostly just the highest level of the API that might need a few tweaks.
There aren't many fancy pytorch-based survival analysis packages like yours out there, and I will be very excited when this lib can be applied to many more pipelines. :)

havakv commented on August 23, 2024

Feel free to propose any suggested changes. All contributions are welcome :)

petteriTeikari commented on August 23, 2024

So I think that by writing standalone prediction functions (not part of the classes), it should be straight forward to use the pycox models without any code from torchtuples. Maybe I should write an example notebook that illustrates this... Do you think that would suffice?

Thanks a lot for the nice repo @havakv. I was curious whether you have had any time for a non-torchtuples example, as torchtuples complicates things a bit when trying to integrate your approach into an existing codebase. I have been using https://github.com/Project-MONAI/MONAI for our volumetric medical images (e.g. the DRTOP), and wanted to see if the example 04_mnist_dataloaders_cnn.ipynb would work for our data.

The basic mechanics go through just fine and I get an initial model trained, but using MONAI dataloaders, MONAI augmentations, extra augmentation libraries, accessing activations from different layers, etc., now requires tweaking. MONAI itself does not really add any complexity to the "normal" PyTorch model training logic, which has the following structure:

for epoch in epoch_idxs:

    # TRAIN
    for batch_data in train_loader:
        "training_script"

    # VALIDATION
    for batch_data in val_loader:
        "val_script"

    # TEST
    for batch_data in test_loader:
        "test_script"

I understand though if you are busy with other stuff :)
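
For reference, the loop structure above could be filled in for a logistic-hazard model roughly as follows. This is only a sketch under stated assumptions: the loss is a hand-written stand-in so the example is self-contained (pycox provides its own losses in pycox.models.loss), and the function name logistic_hazard_nll, the synthetic data, and the small MLP are all hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def logistic_hazard_nll(phi, idx_durations, events):
    """Hand-written discrete-time logistic-hazard negative log-likelihood.

    phi: (batch, n_intervals) logits; idx_durations: (batch,) int64 index of
    the interval with the event or censoring; events: (batch,) 1.0 = event.
    Hypothetical stand-in, not imported from pycox.
    """
    idx = idx_durations.view(-1, 1)
    # Target is 1 only in the event interval of uncensored subjects.
    target = torch.zeros_like(phi).scatter(1, idx, events.view(-1, 1).to(phi.dtype))
    bce = F.binary_cross_entropy_with_logits(phi, target, reduction="none")
    # Only intervals up to and including idx contribute to the likelihood.
    return bce.cumsum(1).gather(1, idx).mean()


# Synthetic data (random, for illustration only).
torch.manual_seed(0)
n_intervals = 10
x = torch.randn(256, 8)
idx_durations = torch.randint(0, n_intervals, (256,))
events = torch.randint(0, 2, (256,)).float()
train_loader = DataLoader(
    TensorDataset(x[:192], idx_durations[:192], events[:192]), batch_size=32, shuffle=True
)
val_loader = DataLoader(TensorDataset(x[192:], idx_durations[192:], events[192:]), batch_size=64)

net = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, n_intervals))
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)

history = []
for epoch in range(5):
    # TRAIN
    net.train()
    for xb, idx_b, ev_b in train_loader:
        optimizer.zero_grad()
        loss = logistic_hazard_nll(net(xb), idx_b, ev_b)
        loss.backward()
        optimizer.step()

    # VALIDATION
    net.eval()
    with torch.no_grad():
        val_loss = sum(logistic_hazard_nll(net(xb), i, e).item() for xb, i, e in val_loader)
    history.append(val_loss)
```

Since the loss is an ordinary function of tensors, the same loop should work with other dataloaders and augmentations, as long as each batch yields the inputs, the discretized duration index, and the event indicator.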

havakv commented on August 23, 2024

Thank you for the kind words @petteriTeikari! I agree that torchtuples complicates the use of external frameworks, so pycox could really benefit from being usable with vanilla pytorch training loops. I'm not sure how much work this would be, but it's probably time for me to just get started on it. It probably requires some refactoring of the individual models, so are there any models in particular that you're interested in?

For your problem in particular, would it suffice to create a version of 01_introduction.ipynb that works with a standard pytorch training loop?

petteriTeikari commented on August 23, 2024

@havakv Sounds good if you have the time for that. I was initially interested in the basic LogisticHazard(), and I pretty much followed the 04_mnist_dataloaders_cnn.ipynb example, as I have 3D images with survival labels as well.

And since I already have dense segmentation labels and classification labels for the same dataset, it would be a lot easier to integrate your nice work on survival models and losses into that pipeline than the other way around. I assume a lot of other researchers are in a similar situation?

petteriTeikari commented on August 23, 2024

@havakv If I may add a request with some urgency: is there an easy method that allows one to extract the CNN activations from the model?

i.e.

image

that is, the feature vector for my input volumes, which is then used, for example, in a tree together with other non-imaging features?

image
https://doi.org/10.1038/s41598-020-69106-8

havakv commented on August 23, 2024

@petteriTeikari I'm not sure I understand your question on the extraction of CNN activations. But if I understand you correctly and you just want the output of a different layer in the network, then you can just add a new function to the network that does this. The example below is from this tutorial, and I've added a function conv2_output that can be called to look at the output of the second convolution.

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def conv2_output(self, x):
        # Repeat the first stage of forward, then return the raw output of
        # the second convolution (before ReLU and pooling).
        x = self.pool(F.relu(self.conv1(x)))
        x = self.conv2(x)
        return x

Does this answer your question?

petteriTeikari commented on August 23, 2024

@havakv Yes, I think that is more or less what I meant (https://discuss.pytorch.org/t/how-can-i-extract-intermediate-layer-output-from-loaded-cnn-model/77301). Once I have trained the model, I would be interested in the intermediate layer activations instead of the final output. So in your example net, I would want the output of fc2, so that I get that vector of 84 values for each volume in my dataset, in addition to the actual survival curve output.

Thanks for the minimal example, I will have a better look of that later :)

havakv commented on August 23, 2024

Yes, that's probably simpler, as you don't have to rewrite the forward function for each layer you want to investigate.
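
One way to do this without touching forward is a forward hook, sketched below on the example Net from earlier in the thread. The names captured and save_fc2 and the random 32x32 input are hypothetical. Note that a hook on fc2 records the layer's raw output, i.e. before the ReLU that forward applies afterwards.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


net = Net().eval()
captured = {}


def save_fc2(module, inputs, output):
    # Called by PyTorch after fc2 runs; store a detached copy of its output.
    captured["fc2"] = output.detach()


handle = net.fc2.register_forward_hook(save_fc2)
with torch.no_grad():
    logits = net(torch.randn(4, 3, 32, 32))  # normal forward pass
handle.remove()  # unregister the hook once the features are collected

features = captured["fc2"]  # one 84-dimensional vector per input
```

The same forward pass yields both the final network output and the 84-dimensional feature vectors, which could then be passed on to another model alongside non-imaging features.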
