
censored's Introduction



censored is a parsnip extension package which provides engines for various models for censored regression and survival analysis.

Installation

You can install the released version of censored from CRAN with:

install.packages("censored")

And the development version from GitHub with:

# install.packages("pak")
pak::pak("tidymodels/censored")

Available models, engines, and prediction types

censored provides engines for the models in the following table. For examples, please see Fitting and Predicting with censored.

The time to event can be predicted with type = "time", the survival probability with type = "survival", the linear predictor with type = "linear_pred", the quantiles of the event time distribution with type = "quantile", the hazard with type = "hazard", and the raw engine output with type = "raw".

model engine time survival linear_pred raw quantile hazard
bag_tree rpart
boost_tree mboost
decision_tree rpart
decision_tree partykit
proportional_hazards survival
proportional_hazards glmnet
rand_forest partykit
rand_forest aorsf
survival_reg survival
survival_reg flexsurv
survival_reg flexsurvspline

Contributing

This project is released with a Contributor Code of Conduct. By contributing to this project, you agree to abide by its terms.

censored's People

Contributors

bcjaeger, davisvaughan, emilhvitfeldt, hfrick, juliasilge, mattwarkentin, simonpcouch, topepo

censored's Issues

Improve function documentation

Most of the work so far has been focused on making sure the code runs as it should. The next step is to make sure that all of the documentation is up to standard.

Require strata info for predictions from stratified model

Predictions from a stratified model should error informatively if new_data does not contain the strata variables.

library(survival)
library(censored)
#> Loading required package: parsnip

cox_model <- proportional_hazards() %>% 
  set_engine("survival")
cox_fit_strata <- cox_model %>% 
  fit(Surv(time, status) ~ age + sex + wt.loss + strata(inst), data = lung)

# no strata info given
new_data_0 <- data.frame(age = c(50, 60), sex = 1, wt.loss = 5)

# predictions should error informatively
predict(cox_fit_strata, new_data = new_data_0, type = "time") 
#> # A tibble: 36 x 1
#>    .pred_time
#>         <dbl>
#>  1       418.
#>  2       302.
#>  3       455.
#>  4       584.
#>  5       465.
#>  6       355.
#>  7       479.
#>  8       388.
#>  9       471.
#> 10       430.
#> # … with 26 more rows
predict(cox_fit_strata, new_data = new_data_0, type = "survival", .time = 200)
#> Error in rep(1:n, x$strata): invalid 'times' argument

Created on 2021-03-30 by the reprex package (v1.0.0)
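A minimal sketch of the kind of check this issue asks for, assuming the strata variables can be read off the model formula. check_strata_info() is a hypothetical helper, not part of censored; it only inspects the formula, so survival does not need to be loaded for it to run.

```r
# Hypothetical helper (not part of censored): error informatively when
# new_data lacks a variable used inside strata() in the model formula.
check_strata_info <- function(formula, new_data) {
  term_labels <- attr(stats::terms(formula), "term.labels")
  strata_terms <- grep("^strata\\(", term_labels, value = TRUE)
  # all.vars() pulls the variable names out of each strata(...) term
  needed <- as.character(unique(unlist(
    lapply(strata_terms, function(s) all.vars(str2lang(s)))
  )))
  missing_vars <- setdiff(needed, names(new_data))
  if (length(missing_vars) > 0) {
    stop(
      "`new_data` must contain the strata variable(s) used to fit the model: ",
      paste(missing_vars, collapse = ", "),
      call. = FALSE
    )
  }
  invisible(new_data)
}

f <- Surv(time, status) ~ age + sex + wt.loss + strata(inst)
new_data_0 <- data.frame(age = c(50, 60), sex = 1, wt.loss = 5)
try(check_strata_info(f, new_data_0))  # errors, naming the missing `inst`
```

Checking up front means the error names the missing variable instead of surfacing later as "invalid 'times' argument" or as one prediction per stratum.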

Survival probabilities for stratified Cox model

Something is wrong with the survival probabilities for the stratified Cox model (with both the "survival" and the "glmnet" engines): the prediction at time = 20 for the first observation is missing.

library(censored)
#> Loading required package: parsnip
library(survival)
library(tidyr)

# survival engine
spec_survival <- proportional_hazards() %>%
  set_mode("censored regression") %>%
  set_engine("survival")

set.seed(14)
fit_survival <- fit(spec_survival,
                    Surv(stop, event) ~ rx + size + number + strata(enum),
                    data = bladder)
#> Warning in .recacheSubclasses(def@className, def, env): undefined subclass
#> "numericVector" of class "Mnumeric"; definition not updated

pred_survival <- predict(fit_survival, new_data = bladder[1:2, ], 
                         type = "survival", time = c(10, 20))
pred_survival
#> # A tibble: 2 x 1
#>   .pred               
#>   <list>              
#> 1 <tibble[,2] [1 × 2]>
#> 2 <tibble[,2] [2 × 2]>
unnest(pred_survival, cols = c(".pred"))
#> # A tibble: 3 x 2
#>   .time .pred_survival
#>   <dbl>          <dbl>
#> 1    10          0.636
#> 2    10          0.934
#> 3    20          0.726

# glmnet engine
spec_glmnet <- proportional_hazards(penalty = 0.123) %>%
  set_mode("censored regression") %>%
  set_engine("glmnet")

set.seed(14)
fit_glmnet <- fit(spec_glmnet,
                  Surv(stop, event) ~ rx + size + number + strata(enum),
                  data = bladder)

pred_glmnet <- predict(fit_glmnet, new_data = bladder[1:2, ], type = "survival",
                       time = c(10, 20), penalty = 0.1)
pred_glmnet
#> # A tibble: 2 x 1
#>   .pred               
#>   <list>              
#> 1 <tibble[,2] [1 × 2]>
#> 2 <tibble[,2] [2 × 2]>
unnest(pred_glmnet, cols = c(".pred"))
#> # A tibble: 3 x 2
#>   .time .pred_survival
#>   <dbl>          <dbl>
#> 1    10          0.623
#> 2    10          0.929
#> 3    20          0.713

Created on 2021-06-08 by the reprex package (v2.0.0)

decision_tree() models

  • rpart time

  • rpart survival

  • party time

  • party survival

  • rpart parameters

  • party parameters

rpart

library(rpart)
library(survival)
library(pec)
#> Loading required package: prodlim
rp_mod <- rpart(Surv(time, status) ~ age + ph.ecog, data = lung)

## time
str(predict(rp_mod, lung))
#>  Named num [1:228] 0.973 1.159 0.464 0.973 0.464 ...
#>  - attr(*, "names")= chr [1:228] "1" "2" "3" "4" ...

## survival
pec_rpart_mod <- pecRpart(Surv(time, status) ~ age + ph.ecog, data = lung)
str(predictSurvProb(pec_rpart_mod, newdata = lung, times = c(100, 200)))
#>  num [1:228, 1:2] 0.891 0.833 0.951 0.891 0.951 ...
#>  - attr(*, "dimnames")=List of 2
#>   ..$ : NULL
#>   ..$ : NULL

Created on 2020-08-06 by the reprex package (v0.3.0)

party

library(party)
#> Loading required package: grid
#> Loading required package: mvtnorm
#> Loading required package: modeltools
#> Loading required package: stats4
#> Loading required package: strucchange
#> Loading required package: zoo
#> 
#> Attaching package: 'zoo'
#> The following objects are masked from 'package:base':
#> 
#>     as.Date, as.Date.numeric
#> Loading required package: sandwich
library(survival)
library(pec)
#> Loading required package: prodlim
ctree_mod <- ctree(Surv(time, status) ~ age + ph.ecog, data = lung)

# time
str(predict(ctree_mod, lung))
#>  num [1:228] 353 353 353 353 353 353 183 183 353 183 ...

# survival
pec_ctree_mod <- pecCtree(Surv(time, status) ~ age + ph.ecog, data = lung)
str(predictSurvProb(pec_ctree_mod, newdata = lung, times = c(100, 200)))
#>  num [1:228, 1:2] 0.893 0.893 0.893 0.893 0.893 ...

Created on 2020-08-06 by the reprex package (v0.3.0)

convert documentation to the new parsnip format

engines in censored

  • bag_tree_ipred
  • boost_tree_mboost
  • decision_tree_party
  • decision_tree_rpart
  • proportional_hazards_glmnet
  • proportional_hazards_survival
  • rand_forest_party
  • survival_reg_flexsurv
  • survival_reg_survival

survival probabilities for glmnet proportional hazards model

We want to get survival probabilities for glmnet objects that use family = "cox". This page has details on the Cox model with glmnet and describes the special approach that is required for this model implementation.

There are a few complications that arise when building a wrapper around this model:

No formulas

glmnet() does not have a formula method, so something simple like Surv(time, event) ~ x_1 + x_2 isn't possible. In the glmnet examples, the outcome is passed to glmnet() as a matrix with the contents of the Surv() object (or as an object of class Surv). I believe that the censored package already deals with this.

The main consequence of not having a formula method is that, for stratification, we cannot use the canonical formula (e.g., Surv(time, event) ~ x_1 + x_2 + strata(var)). Instead, a separate function is applied to the outcome object to store the stratification variable. The syntax looks like stratifySurv(surv_object, strata_var).

Retain the training set

As with the survival package, survival probabilities for the Cox model are best computed with the survfit() method for the model object. While glmnet does have a survfit() method for its Cox fits, that method requires the original training set data to work.

We'll have to attach x and y data to the fitted glmnet object when the model is fit.

Predictions over penalty values

Like other glmnet objects, we can make predictions over many values of lambda for the same model object. For survival probabilities, this is also the case. When making such predictions, we also need to specify the time points. As a result, the standard nested tibble that censored produces will have a row for each combination of .time and lambda and would look something like:

# A tibble: 4 x 3
  .time .pred_survival penalty
  <dbl>          <dbl>   <dbl>
1     1          0.966    0.01
2    10          0.448    0.01
3     1          0.951    0.1 
4    10          0.421    0.1 

The initial work here is to make a function similar to censored::cph_survival_prob() to get the data in this format. It looks like glmnet's survfit() method for coxnet objects produces a list of survfit objects, one for each value of lambda. It may not be too complex to get these predictions and then reformat them (we already have code to do this for the results of survival:::survfit.coxph()).

Note: recall that, when a glmnet model is fit via parsnip, we require a single penalty value (even though the model produces all of the coefficients for the entire path of penalty values). For this reason, predict.model_fit() will only produce predictions for a single penalty value. This function should produce the above output without the penalty column. The multi_predict() method will have an argument for the penalty values and its results will look like the tibble above.
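A minimal sketch of the reformatting step, assuming one numeric vector of survival probabilities per penalty value (for a single observation), all evaluated at the same time points. stack_by_penalty() is hypothetical, and the numbers are the illustrative ones from the table above.

```r
# Hypothetical helper: stack per-penalty survival probabilities into the long
# layout with one row per combination of .time and penalty.
stack_by_penalty <- function(probs_by_penalty, times, penalties) {
  stopifnot(length(probs_by_penalty) == length(penalties))
  pieces <- Map(
    function(probs, pen) {
      data.frame(.time = times, .pred_survival = probs, penalty = pen)
    },
    probs_by_penalty, penalties
  )
  do.call(rbind, pieces)
}

stack_by_penalty(
  probs_by_penalty = list(c(0.966, 0.448), c(0.951, 0.421)),
  times = c(1, 10),
  penalties = c(0.01, 0.1)
)
```

For predict() with a single penalty value, the same layout would simply drop the penalty column.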

linear predictors should increase with time

We might have to change the sign of the results when predict(type = "linear_pred") is used. A consistent direction is critical so that the performance metrics all work. I've already added a note in ?parsnip::predict.model_fit about this.

PH models and probability estimates

There is currently some code in aaa_survival_prop.R to estimate survival probabilities for Cox PH models. It might be better to use survfit(coxph_object) or basehaz(coxph_object) to get these since they take into account any stratification variables for that model. If that's the case, we might want to save the table produced by those functions in the model object and interpolate the step functions within each stratum. Dealing with the strata is a pain, because of how survival encodes them, but some of that work is already done for the parametric model engine.

Also, I don't see any canonical way of estimating the baseline hazard so that we can get type = "hazard". There are some methods that layer another model on top of the PH model results. It would be good to get community feedback on this.

We might also consider having a type = "cumulative hazard" method.

@dincerti
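A minimal sketch of the "save the table and interpolate within each stratum" idea. It uses a made-up table in the shape that survival::basehaz() is assumed to return for a stratified model (columns hazard, time, and strata; worth checking against the installed survival version). basehaz() centers by default, so any linear predictor combined with it must use the same centering. The helper names are hypothetical.

```r
# Made-up table shaped like basehaz() output for a stratified model:
# cumulative baseline hazard, event time, and stratum label.
bh <- data.frame(
  hazard = c(0.10, 0.25, 0.40, 0.05, 0.20),
  time   = c(5, 12, 30, 8, 20),
  strata = c("sex=1", "sex=1", "sex=1", "sex=2", "sex=2")
)

# Step-function lookup of the cumulative baseline hazard within one stratum:
# carry the last value forward, 0 before the stratum's first event time.
cumhaz_in_stratum <- function(bh, stratum, times) {
  s <- bh[bh$strata == stratum, ]
  s <- s[order(s$time), ]
  c(0, s$hazard)[findInterval(times, s$time) + 1]
}

# Proportional hazards: S(t | x) = exp(-H0(t) * exp(lp)), where lp must use
# the same centering as the baseline hazard table
surv_in_stratum <- function(bh, stratum, times, lp) {
  exp(-cumhaz_in_stratum(bh, stratum, times) * exp(lp))
}

cumhaz_in_stratum(bh, "sex=1", c(0, 12, 20))  # 0.00 0.25 0.25
surv_in_stratum(bh, "sex=2", c(10, 25), lp = 0.3)
```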

argument for flipping the linear predictor sign

We can write specific methods for proportional hazards engines that have an increasing argument that defaults to TRUE.

  • If TRUE, we would reverse the sign of the linear predictor (so that it is increasing with time).
  • If FALSE, do nothing.

There will have to be predict_linear_pred() methods for each engine. For example,

> proportional_hazards() %>%
+     set_engine("survival") %>%
+     fit(Surv(time, status) ~ x, data = aml) %>% 
+     class()
[1] "_coxph"    "model_fit"

so there would need to be predict_linear_pred._coxph() method (and a _coxnet method too).
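A minimal sketch of such a method, assuming the engine fit is stored in object$fit (as in parsnip model_fit objects). The generic name predict_linear_pred_sketch() is hypothetical and chosen to avoid clashing with parsnip's own functions; the model_fit object is built by hand purely for illustration.

```r
library(survival)

# Hypothetical generic; the name is illustrative, not an existing function.
predict_linear_pred_sketch <- function(object, ...) {
  UseMethod("predict_linear_pred_sketch")
}

# Method for fits of class "_coxph": the engine fit lives in object$fit.
`predict_linear_pred_sketch._coxph` <- function(object, new_data,
                                                increasing = TRUE, ...) {
  lp <- predict(object$fit, newdata = new_data, type = "lp")
  # coxph's linear predictor grows with risk (shorter survival); flipping the
  # sign makes it increase with time
  if (increasing) -lp else lp
}

# Stand-in for a parsnip model_fit, built by hand for illustration
cox_fit <- coxph(Surv(time, status) ~ x, data = aml)
obj <- structure(list(fit = cox_fit), class = c("_coxph", "model_fit"))
head(predict_linear_pred_sketch(obj, new_data = aml))
```

A _coxnet method would follow the same pattern with its own prediction call.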

For glmnet models, fail if no penalty value is provided

Similar to tidymodels/parsnip#481

library(censored)
#> Loading required package: parsnip
library(survival)

lung2 <- lung[-14, ]

# These should both fail in translate()

proportional_hazards(penalty = c(0.01, 0.1)) %>% 
  set_engine("glmnet") %>% 
  fit(Surv(time, status) ~ age + ph.ecog, data = lung2)
#> parsnip model object
#> 
#> Fit time:  13ms 
#> 
#> Call:  glmnet::glmnet(x = x, y = y, family = "cox", alpha = alpha, lambda = lambda) 
#> 
#>    Df %Dev   Lambda
#> 1   0 0.00 0.225300
#> 2   1 0.21 0.205300
#> 3   1 0.39 0.187100
#> 4   1 0.53 0.170500
#> 5   1 0.65 0.155300
#> 6   1 0.75 0.141500
#> 7   1 0.84 0.128900
#> 8   1 0.90 0.117500
#> 9   1 0.96 0.107000
#> 10  1 1.01 0.097540
#> 11  1 1.05 0.088870
#> 12  2 1.08 0.080980
#> 13  2 1.13 0.073790
#> 14  2 1.16 0.067230
#> 15  2 1.19 0.061260
#> 16  2 1.22 0.055820
#> 17  2 1.24 0.050860
#> 18  2 1.26 0.046340
#> 19  2 1.27 0.042220
#> 20  2 1.28 0.038470
#> 21  2 1.29 0.035050
#> 22  2 1.30 0.031940
#> 23  2 1.31 0.029100
#> 24  2 1.31 0.026520
#> 25  2 1.32 0.024160
#> 26  2 1.32 0.022010
#> 27  2 1.33 0.020060
#> 28  2 1.33 0.018280
#> 29  2 1.33 0.016650
#> 30  2 1.33 0.015170
#> 31  2 1.33 0.013830
#> 32  2 1.33 0.012600
#> 33  2 1.34 0.011480
#> 34  2 1.34 0.010460
#> 35  2 1.34 0.009530
#> 36  2 1.34 0.008683
#> 37  2 1.34 0.007912
#> 38  2 1.34 0.007209
#> 39  2 1.34 0.006568
#> 40  2 1.34 0.005985
#> 41  2 1.34 0.005453
#> The training data has been saved for prediction.

proportional_hazards() %>% 
  set_engine("glmnet") %>% 
  fit(Surv(time, status) ~ age + ph.ecog, data = lung2)
#> parsnip model object
#> 
#> Fit time:  9ms 
#> 
#> Call:  glmnet::glmnet(x = x, y = y, family = "cox", alpha = alpha, lambda = lambda) 
#> 
#>    Df %Dev   Lambda
#> 1   0 0.00 0.225300
#> 2   1 0.21 0.205300
#> 3   1 0.39 0.187100
#> 4   1 0.53 0.170500
#> 5   1 0.65 0.155300
#> 6   1 0.75 0.141500
#> 7   1 0.84 0.128900
#> 8   1 0.90 0.117500
#> 9   1 0.96 0.107000
#> 10  1 1.01 0.097540
#> 11  1 1.05 0.088870
#> 12  2 1.08 0.080980
#> 13  2 1.13 0.073790
#> 14  2 1.16 0.067230
#> 15  2 1.19 0.061260
#> 16  2 1.22 0.055820
#> 17  2 1.24 0.050860
#> 18  2 1.26 0.046340
#> 19  2 1.27 0.042220
#> 20  2 1.28 0.038470
#> 21  2 1.29 0.035050
#> 22  2 1.30 0.031940
#> 23  2 1.31 0.029100
#> 24  2 1.31 0.026520
#> 25  2 1.32 0.024160
#> 26  2 1.32 0.022010
#> 27  2 1.33 0.020060
#> 28  2 1.33 0.018280
#> 29  2 1.33 0.016650
#> 30  2 1.33 0.015170
#> 31  2 1.33 0.013830
#> 32  2 1.33 0.012600
#> 33  2 1.34 0.011480
#> 34  2 1.34 0.010460
#> 35  2 1.34 0.009530
#> 36  2 1.34 0.008683
#> 37  2 1.34 0.007912
#> 38  2 1.34 0.007209
#> 39  2 1.34 0.006568
#> 40  2 1.34 0.005985
#> 41  2 1.34 0.005453
#> The training data has been saved for prediction.

Created on 2021-05-12 by the reprex package (v2.0.0)
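A minimal sketch of the check, assuming it receives the already evaluated penalty argument. check_single_penalty() is hypothetical and is not the actual censored or parsnip code.

```r
# Hypothetical validation: the glmnet engine fits the whole path, but the
# parsnip model needs exactly one penalty value.
check_single_penalty <- function(penalty) {
  if (is.null(penalty)) {
    stop("For the glmnet engine, a single `penalty` value must be supplied.",
         call. = FALSE)
  }
  if (!is.numeric(penalty) || length(penalty) != 1L) {
    stop("For the glmnet engine, `penalty` must be a single number, not a ",
         "vector of length ", length(penalty), ".", call. = FALSE)
  }
  invisible(penalty)
}

try(check_single_penalty(c(0.01, 0.1)))  # errors: vector of length 2
try(check_single_penalty(NULL))          # errors: no penalty supplied
check_single_penalty(0.01)               # passes silently
```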

Multiple survival curves for single observation from stratified survival model

For a stratified Cox model, survival::survfit() should return a single survival curve per observation if the strata information is given in the newdata argument. If no strata information is given, it returns a curve for each stratum. This works as intended when newdata contains multiple observations. However, if newdata contains only a single observation, survfit() returns a curve per stratum, even if the stratum is given in newdata.

library(survival)

cfit <- coxph(Surv(time, status) ~ age + sex + wt.loss + strata(inst),
              data = lung)

# no strata info --> survival curves for all strata
new_data_0 <- expand.grid(age = c(50, 60), sex = 1, wt.loss = 5)
csurv_0 <- survfit(cfit, newdata = new_data_0)
dim(csurv_0)
#> strata   data
#>     18      2

# add strata info --> survival curves for only those
new_data_n <- data.frame(age = c(50, 60), sex = 1:2, wt.loss = 5, inst = c(6, 11))
csurv_n <- survfit(cfit, newdata = new_data_n)
dim(csurv_n)
#> strata
#>      2

# single observation with strata info --> should be one curve but gives all
new_data_1 <- data.frame(age = 60, sex = 2, wt.loss = 5, inst = 11)
csurv_1 <- survfit(cfit, newdata = new_data_1)
dim(csurv_1)
#> strata
#>     18

Created on 2021-03-29 by the reprex package (v1.0.0)

This has been reported as an issue for the {survival} repo.

Note to us: once this is fixed in {survival} check back here that this translates nicely to {censored}.

library(parsnip)
library(survival)
library(censored)

cox_model <- proportional_hazards() %>%
  set_engine("survival")
cox_fit_strata <- cox_model %>%
  fit(Surv(time, status) ~ age + sex + wt.loss + strata(inst), data = lung)

# strata info for n observations
new_data_n <- data.frame(age = c(50, 60), sex = 1:2, wt.loss = 5, inst = c(6, 11))
predict(cox_fit_strata, new_data = new_data_n, type = "time") # 2 rows = 2 observations
#> # A tibble: 2 x 1
#>   .pred_time
#>        <dbl>
#> 1       339.
#> 2       500.

# stratum info for single observation
new_data_1 <- data.frame(age = 60, sex = 2, wt.loss = 5, inst = 11)
predict(cox_fit_strata, new_data = new_data_1, type = "time") # gives 18 values (for all strata) when it should only give 1
#> # A tibble: 18 x 1
#>    .pred_time
#>         <dbl>
#>  1       497.
#>  2       364.
#>  3       532.
#>  4       651.
#>  5       555.
#>  6       421.
#>  7       558.
#>  8       449.
#>  9       540.
#> 10       511.
#> 11       612.
#> 12       624.
#> 13       583.
#> 14       364.
#> 15       589.
#> 16       818.
#> 17       792.
#> 18      1022

Created on 2021-03-29 by the reprex package (v1.0.0)

model names for mboost glm and gam

What should the model type of an {mboost} GLM with family = coxPH() be? What about a boosted GAM?

  • boost_glm()
  • boost_gam()

Or should the GLM be available through linear_reg()?

linear_reg() %>%
  set_mode("censored regression") %>%
  set_engine("mboost")

bag_tree() models

  • ipred survival
  • ipred time
  • ipred parameters
library(ipred)
library(survival)
library(pec)
#> Loading required package: prodlim

bag_mod <- bagging(Surv(time, status) ~ age + ph.ecog, data = lung)

bag_preds <- predict(bag_mod, lung)
bag_preds[[1]]
#> Call: survfit(formula = Surv(agglsample[[j]], aggcens[[j]]) ~ 1)
#> 
#>       n  events  median 0.95LCL 0.95UCL 
#>    1400    1068     345     306     363

# survival
summary(bag_preds[[1]], times = c(100, 200))
#> Call: survfit(formula = Surv(agglsample[[j]], aggcens[[j]]) ~ 1)
#> 
#>  time n.risk n.event survival std.err lower 95% CI upper 95% CI
#>   100   1270     130    0.907 0.00776        0.892        0.922
#>   200    976     234    0.738 0.01180        0.716        0.762

# Median survival
library(purrr)
map_dbl(bag_preds, ~ quantile(.x, probs = .5)$quantile)
#>   [1] 345 350 450 345 477 345 310 207 305 177 345 310 310 348 345 301 345 288
#>  [19] 177 345 301 433 305 473 353 345 477 288 305 183 345 163 181 177 177 177
#>  [37] 310 301 183 345 345 163 473 177 310 183 353 345 353 305 329 455 350 473
#>  [55] 363 455 310 177 363 345 201 305 329 293 345 310 310 450 350 301 455 305
#>  [73] 183 305 305 337 433 345 301 353 473 337 305 305 305 337 345 450 305 163
#>  [91] 455 345 301 345 353 163 345 337 345 353 473 345 345 345 177 301 337 301
#> [109] 337 201 353 305 284 310 337 183 180 310 310 284 310 177 345 163 353 301
#> [127] 305 345 183 450 363 305 450 337 337 183 363 177 337 450 163 345 183 293
#> [145] 345 305 353 337 301 473 345 477 455 305 177 310 310 177 345 345 455 305
#> [163] 310 345 345 450 345 450 345 310 345 473 345 550 305 363 177 337 305 353
#> [181] 477 550 310 293 433 450 163 473 345 433 329 180 305 345 337 353 345 345
#> [199] 433 345 345 329 455 433 455 177 433 180 345 363 477 301 345 310 293 293
#> [217] 305 183 288 450 301 337 285 285 550 201 301 363

Created on 2020-08-06 by the reprex package (v0.3.0)

Survnip + glmnet/coxnet?

Dear Emil, I'm new to GitHub so I apologize in advance. I wonder if survnip can be used for regularized Cox regression (LASSO/ridge/elastic net)? Kind regards, David

Make sure NAs are handled correctly in prediction

library(survnip)
#> Loading required package: parsnip
library(survival)

cox_mod <-
  cox_reg() %>%
  set_engine("survival") %>%
  fit(Surv(time, status) ~ age + ph.ecog, data = lung)

dim(predict(cox_mod, new_data = lung, .time = 200))
#> [1] 227   1

dim(lung)
#> [1] 228  10

Created on 2020-08-03 by the reprex package (v0.3.0)
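A minimal sketch of padding predictions back to one value per row of new_data, the same idea behind na.exclude(). pad_to_new_data() is hypothetical, and it assumes the model dropped exactly the rows with missing predictors.

```r
# Hypothetical helper: predictions come back only for rows without missing
# predictors, so place them into a vector with one slot per row of new_data.
pad_to_new_data <- function(pred, new_data, vars) {
  complete <- stats::complete.cases(new_data[, vars, drop = FALSE])
  if (length(pred) != sum(complete)) {
    stop("`pred` must have one value per complete row of `new_data`.",
         call. = FALSE)
  }
  out <- rep(NA_real_, nrow(new_data))
  out[complete] <- pred
  out
}

new_data <- data.frame(age = c(74, 68, 56), ph.ecog = c(1, NA, 0))
pad_to_new_data(c(0.81, 0.64), new_data, vars = c("age", "ph.ecog"))
# returns c(0.81, NA, 0.64)
```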

Move `master` branch to `main`

The master branch of this repository will soon be renamed to main, as part of a coordinated change across several GitHub organizations (including, but not limited to: tidyverse, r-lib, tidymodels, and sol-eng). We anticipate this will happen by the end of September 2021.

That will be preceded by a release of the usethis package, which will gain some functionality around detecting and adapting to a renamed default branch. There will also be a blog post at the time of this master --> main change.

The purpose of this issue is to:

  • Help us firm up the list of targeted repositories
  • Make sure all maintainers are aware of what's coming
  • Give us an issue to close when the job is done
  • Give us a place to put advice for collaborators re: how to adapt

message id: euphoric_snowdog

time points converted to character

For some models, the .time column of the survival predictions comes back as character instead of numeric:

library(tidymodels)
library(censored)
library(survival)
library(splines)
f <- Surv(time, status) ~ ph.ecog + ns(age, df = 4) + strata(sex)
f2 <- Surv(time, status) ~ ph.ecog + age + sex

pred_times <- 1:10
pred_row <- c(57, 2, 3)
tree_fit <-
  decision_tree() %>% 
  set_engine("party") %>% 
  set_mode("censored regression") %>%
  fit(f2, data = lung)
tree_pred <-
  predict(tree_fit, lung[pred_row, ], type = "survival", .time = pred_times) %>%
  setNames(".pred")
tree_pred$.pred[[1]]
#> # A tibble: 10 x 2
#>    .time .pred_survival
#>    <chr>          <dbl>
#>  1 1              1    
#>  2 2              1    
#>  3 3              1    
#>  4 4              1    
#>  5 5              0.986
#>  6 6              0.986
#>  7 7              0.986
#>  8 8              0.986
#>  9 9              0.986
#> 10 10             0.986

rf_fit <-
  rand_forest(mtry = 1) %>% 
  set_engine("party") %>% 
  set_mode("censored regression") %>%
  fit(f2, data = lung)

rf_pred <-
  predict(rf_fit, lung[pred_row, ], type = "survival", .time = pred_times) %>%
  setNames(".pred")
rf_pred$.pred[[1]]
#> # A tibble: 10 x 2
#>    .time .pred_survival
#>    <chr>          <dbl>
#>  1 1              1    
#>  2 2              1    
#>  3 3              1    
#>  4 4              1    
#>  5 5              0.949
#>  6 6              0.949
#>  7 7              0.949
#>  8 8              0.949
#>  9 9              0.949
#> 10 10             0.949

Created on 2021-03-10 by the reprex package (v1.0.0.9000)
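One plausible cause, an assumption not confirmed in this issue, is that the time points are recovered from names (for example, column names of a probability matrix, which are always character). A minimal sketch that avoids this by taking .time from the numeric input; nest_survival_probs() is hypothetical and the probabilities are made up.

```r
# Hypothetical helper: turn an n x t matrix of survival probabilities (one row
# per observation, one column per time point) into a list of per-row data
# frames, taking .time from the numeric input rather than from any names.
nest_survival_probs <- function(prob_matrix, times) {
  stopifnot(is.numeric(times), ncol(prob_matrix) == length(times))
  lapply(seq_len(nrow(prob_matrix)), function(i) {
    data.frame(.time = times, .pred_survival = unname(prob_matrix[i, ]))
  })
}

probs <- matrix(c(1, 1, 0.986, 0.949), nrow = 2)  # made-up: 2 rows, 2 times
nested <- nest_survival_probs(probs, times = c(4, 5))
class(nested[[1]]$.time)  # "numeric"
```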

clarify error message for `remove_strata()`

f <- Surv(time, event) ~ (x + z) + strata(s)
censored:::remove_strata(f)
#> Surv(time, event) ~ (x + z)

f <- Surv(time, event) ~ x + (z + strata(s))
censored:::remove_strata(f)
#> Error: Stratification needs to be specified via `+ strata()`.

Created on 2021-06-17 by the reprex package (v2.0.0)

Incorrect interpolation of KM values

The steps in the Kaplan-Meier curve are left-inclusive: the survival probability drops at an event time and then stays constant until the next event time.

In our interpolation helper function km_with_cuts(), we currently cut with right = TRUE, which corresponds to estimating the survival probability as the one at the next event time rather than the previous one.

library(censored)
#> Loading required package: parsnip
library(survival)

cox_fit <- coxph(Surv(stop, event) ~ rx + size + number + strata(enum),
                   data = bladder, x = TRUE)

# predict survival prob at time 20
new_data <- bladder[1:3, ]
times <- 20

# from survival_prob_cph()
y <- survival::survfit(cox_fit, newdata = new_data, conf.int = 0.95)

stack_sf <- censored:::stack_survfit(y, nrow(new_data)) %>%
  # simplify to observation 1
  dplyr::filter(.row == 1) %>%
  dplyr::bind_rows(censored:::prob_template, .)
# the two time points around 20
stack_sf %>%
  dplyr::filter(18 <= .time & .time <= 22)
#> # A tibble: 2 x 6
#>   .time .pred_survival .pred_survival_l… .pred_survival_… .pred_hazard_cu…  .row
#>   <dbl>          <dbl>             <dbl>            <dbl>            <dbl> <int>
#> 1    18          0.535             0.418            0.685            0.625     1
#> 2    22          0.519             0.401            0.672            0.655     1

# from interpolate_km_values_ungrouped()
stack_sf <- censored:::km_with_cuts(stack_sf)
stack_sf %>%
  dplyr::filter(18 <= .time & .time <= 22) %>%
  dplyr::select(.time, .pred_survival, .row, .cuts)
#> # A tibble: 2 x 4
#>   .time .pred_survival  .row .cuts  
#>   <dbl>          <dbl> <int> <fct>  
#> 1    18          0.535     1 (17,18]
#> 2    22          0.519     1 (18,22]
# ^^ this should be cut with left-inclusive intervals:
# the estimated survival probability drops at the event time and is then
# constant until the next event,
# meaning that for time 20 the estimate should be 0.535.
# however, our join gives 0.519
tibble::tibble(.time = times) %>%
  censored:::km_with_cuts(times = stack_sf$.time) %>%
  dplyr::rename(.tmp = .time) %>%
  dplyr::left_join(stack_sf, by = ".cuts") %>%
  dplyr::select(.time = .tmp, .cuts, .pred_survival)
#> # A tibble: 1 x 3
#>   .time .cuts   .pred_survival
#>   <dbl> <fct>            <dbl>
#> 1    20 (18,22]          0.519


# this also estimates the survival probability as 0.535
pec::predictSurvProb(cox_fit, new_data[1,], times = 20)
#> Warning in .recacheSubclasses(def@className, def, env): undefined subclass
#> "numericVector" of class "Mnumeric"; definition not updated
#>           [,1]
#> [1,] 0.5349973

Created on 2021-06-16 by the reprex package (v2.0.0)
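A minimal sketch of the left-inclusive lookup, using the two event times and estimates around t = 20 from the reprex above (earlier event times are omitted). km_lookup() is a hypothetical helper built on base R's findInterval(); base R's stepfun() with its default right = FALSE encodes the same convention.

```r
# Event times and Kaplan-Meier estimates for observation 1 (from the reprex)
event_times <- c(18, 22)
surv_probs  <- c(0.535, 0.519)

# Left-inclusive steps: at time t use the estimate of the last event time
# <= t. findInterval() counts how many event times are <= t.
km_lookup <- function(t, event_times, surv_probs) {
  idx <- findInterval(t, event_times)
  # idx == 0 means t precedes the first listed event time; the KM estimate is
  # 1 there only if the table holds every event time
  c(1, surv_probs)[idx + 1]
}

km_lookup(c(18, 20, 22), event_times, surv_probs)  # 0.535 0.535 0.519

# Equivalent interval view: cut(..., right = FALSE) puts t = 20 in [18,22)
cut(20, breaks = c(18, 22, Inf), right = FALSE)
```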

rand_forest() models

  • party time
  • party survival
  • randomForestSRC time
  • party parameters
  • randomForestSRC parameters
library(party)
#> Loading required package: grid
#> Loading required package: mvtnorm
#> Loading required package: modeltools
#> Loading required package: stats4
#> Loading required package: strucchange
#> Loading required package: zoo
#> 
#> Attaching package: 'zoo'
#> The following objects are masked from 'package:base':
#> 
#>     as.Date, as.Date.numeric
#> Loading required package: sandwich
library(survival)
library(pec)
#> Loading required package: prodlim

set.seed(342)
cforest_mod <-
  cforest(Surv(time, status) ~ age + ph.ecog, data = lung,
          controls = cforest_unbiased(ntree = 100, mtry = 1))

# time
predict(cforest_mod, newdata = lung)
#>   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19  20 
#> 306 353 428 306 558 306 288 222 286 177 306 288 371 305 306 283 345 222 180 306 
#>  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36  37  38  39  40 
#> 283 371 305 433 350 345 558 239 286 208 371 199 199 177 177 177 285 283 208 345 
#>  41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57  58  59  60 
#> 345 199 433 177 371 201 350 345 350 305 301 477 353 433 345 455 350 177 348 345 
#>  61  62  63  64  65  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80 
#> 208 305 285 283 371 288 285 433 353 283 477 305 208 353 286 310 371 363 303 320 
#>  81  82  83  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99 100 
#> 433 363 353 353 353 310 310 455 353 201 477 345 283 306 353 201 371 310 345 353 
#> 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 
#> 433 371 363 345 177 283 363 283 363 208 353 353 283 348 363 201 199 288 285 283 
#> 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 
#> 348 177 371 201 353 283 305 345 201 371 345 286 371 363 350 208 348 180 363 428 
#> 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 
#> 199 306 201 283 306 286 353 363 303 433 345 558 455 286 180 288 371 177 345 363 
#> 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 
#> 455 353 288 345 345 394 301 371 301 350 301 433 363 433 286 345 180 363 286 350 
#> 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 
#> 558 433 350 283 371 428 201 433 345 371 301 180 305 345 310 353 345 345 371 301 
#> 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 
#> 345 285 477 371 477 177 394 199 371 345 558 283 371 285 283 283 353 201 239 394 
#> 221 222 223 224 225 226 227 228 
#> 283 310 283 283 433 208 283 348

# survival
set.seed(342)
pec_cforest_mod <-
  pecCforest(Surv(time, status) ~ age + ph.ecog, data = lung,
             controls = cforest_unbiased(ntree = 100, mtry = 1))

str(predictSurvProb(pec_cforest_mod, newdata = lung, times = c(100, 200)))
#>  num [1:228, 1:2] 0.869 0.839 0.936 0.895 0.93 ...

Created on 2020-08-06 by the reprex package (v0.3.0)

randomForestSRC

library(randomForestSRC)
#> 
#>  randomForestSRC 2.9.3 
#>  
#>  Type rfsrc.news() to see new features, changes, and bug fixes. 
#> 
library(survival)
library(pec)
#> Loading required package: prodlim

rfsrce_mod <- rfsrc(Surv(time, status) ~ age + ph.ecog, data = lung, ntree = 1000)
# The predict function generates an object with classes "rfsrc", "predict", and "surv". It is unclear what is being predicted:

# time?
str(predict(rfsrce_mod, lung)$predicted)
#>  num [1:227, 1:2] 163 145 344 170 328 ...
#>  - attr(*, "dimnames")=List of 2
#>   ..$ : NULL
#>   ..$ : chr [1:2] "event.1" "event.2"

Created on 2020-08-06 by the reprex package (v0.3.0)

improve `multi_predict()`

Rather than relying on the name/.id behavior in map_dfr(), you might consider doing map2_dfr(y, penalty, ~stack_survfit(.x, n = nrow(new_data), penalty = .y)).

You could have this signature stack_survfit(x, n, penalty = NULL) and just unconditionally add the penalty column to the front of the tibbles you create. tibble() will automatically drop NULL columns, and will recycle penalty of size 1 as required.

library(tibble)
tibble(penalty = NULL, y = 1:2)
#> # A tibble: 2 x 1
#>       y
#>   <int>
#> 1     1
#> 2     2
tibble(penalty = 0.01, y = 1:2)
#> # A tibble: 2 x 2
#>   penalty     y
#>     <dbl> <int>
#> 1    0.01     1
#> 2    0.01     2

Originally posted by @DavisVaughan in #70 (comment)

Missing values when getting survival probabilities for coxnet models

This issue is about how the survfit() method for coxnet models handles missing values. There is a corresponding issue #40 for coxph models.

We would like the results of survfit() to be padded for observations with missing values. However, na.exclude() does not do this consistently: whether the observation with missing values is kept depends on the penalty.

library(survival)
library(glmnet)
#> Loading required package: Matrix
#> Loaded glmnet 4.1-1

# prep and fit non-stratified and stratified model
lung2 <- lung[-14,] # drop row 14 (NA in ph.ecog); glmnet doesn't fit with NA
lung_x <- as.matrix(lung2[, c("age", "ph.ecog")])
lung_y <- Surv(lung2$time, lung2$status)
fit_glmnet <- glmnet(lung_x, lung_y, family = "cox")
lung_y_s <- stratifySurv(lung_y, lung2$sex)
fit_glmnet_s <- glmnet(lung_x, lung_y_s, family = "cox")
#> Warning: cox.fit: algorithm did not converge

# non-stratified
new_x <- as.matrix(lung[14:16, c("age", "ph.ecog")])
sf <- survfit(fit_glmnet, newx = new_x, x = lung_x, y = lung_y,
              na.action = na.exclude)

# includes all 3 observations
sf[[1]] 
#> Call: survfit.coxnet(formula = fit_glmnet, newx = new_x, x = lung_x, 
#>     y = lung_y, na.action = na.exclude)
#> 
#>      n events median
#> 14 227    164    310
#> 15 227    164    310
#> 16 227    164    310

# includes only 2
sf[[2]] 
#> Call: survfit.coxnet(formula = fit_glmnet, newx = new_x, x = lung_x, 
#>     y = lung_y, na.action = na.exclude)
#> 
#>      n events median
#> 15 227    164    310
#> 16 227    164    310

# stratified
new_x <- as.matrix(lung[14:16, c("age", "ph.ecog")])
new_strata <- as.matrix(lung[14:16, "sex", drop = FALSE])
sf_s <- survfit(fit_glmnet_s, newx = new_x, newstrata = new_strata, 
              x = lung_x, y = lung_y_s, na.action = na.exclude)

# includes all 3 observations
sf_s[[1]] 
#> Call: survfit.coxnet(formula = fit_glmnet_s, newx = new_x, newstrata = new_strata, 
#>     x = lung_x, y = lung_y_s, na.action = na.exclude)
#> 
#>      n events median
#> 14 137    111    270
#> 15 137    111    270
#> 16 137    111    270

# includes only 2
sf_s[[2]] 
#> Call: survfit.coxnet(formula = fit_glmnet_s, newx = new_x, newstrata = new_strata, 
#>     x = lung_x, y = lung_y_s, na.action = na.exclude)
#> 
#>      n events median
#> 15 137    111    270
#> 16 137    111    270

Created on 2021-06-15 by the reprex package (v2.0.0)
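A caller could make the per-penalty results consistent by padding each survival matrix back to the full set of new observations, matched by name. The following is a hypothetical helper, not part of glmnet or censored. It assumes each per-penalty matrix has one column per kept observation, labeled with the row names of `new_x` (the 14/15/16 labels printed above suggest this). The values are toys.

```r
# Hypothetical helper (not part of glmnet or censored): pad a survival matrix
# whose columns are named after the kept observations so that it has one
# column per expected observation, with NA for any that were dropped.
pad_by_name <- function(surv, expected) {
  out <- matrix(NA_real_, nrow = nrow(surv), ncol = length(expected),
                dimnames = list(NULL, expected))
  out[, colnames(surv)] <- surv
  out
}

# Toy per-penalty results: the first penalty kept all rows, the second dropped "14"
by_penalty <- list(
  matrix(c(0.9, 0.8, 0.92, 0.82, 0.95, 0.85), nrow = 2,
         dimnames = list(NULL, c("14", "15", "16"))),
  matrix(c(0.92, 0.82, 0.95, 0.85), nrow = 2,
         dimnames = list(NULL, c("15", "16")))
)
padded <- lapply(by_penalty, pad_by_name, expected = c("14", "15", "16"))
```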

multi-predict for some prediction types

With glmnet (and maybe mboost), we can get predictions on many values of some tuning parameters without multiple model fits. parsnip saves those as nested tibbles, one per row of new_data.

This is also how we store the predictions of hazard and survival probabilities.

We might be able to accommodate both by having a single nested tibble for both. For example, for glmnet, we would have the full grid of penalty and .time values, with an additional column for the prediction.
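Here is a toy sketch of that shape. It has one nested data frame per row of new_data, each holding every penalty/.time combination plus a prediction column. `.time` follows the issue text. `.row`, `.pred`, and `.pred_survival` are assumed names. The survival formula is a placeholder, not output from a fitted glmnet model.

```r
library(tibble)

# Illustrative tuning and evaluation grids
penalty <- c(0.01, 0.1)
eval_time <- c(100, 200, 300)
n_new <- 2  # number of rows in new_data

# One data frame per row of new_data: every penalty/.time combination plus
# a prediction. The formula below is a placeholder, not a fitted model.
make_grid <- function(i) {
  grid <- expand.grid(penalty = penalty, .time = eval_time)
  grid$.pred_survival <- exp(-grid$.time / (500 * i * (1 + grid$penalty)))
  grid
}

nested <- tibble(.row = seq_len(n_new),
                 .pred = lapply(seq_len(n_new), make_grid))
```

With this layout, a single-penalty prediction is just the special case where the grid has one penalty value.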

missing values in data when getting survival probabilities with coxph()

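It looks like na.exclude() doesn't pad the results of survfit.coxph() with NA values: the resulting survival matrix has one column per complete observation rather than one per row of new data (14 instead of 15 below). This happens with or without strata.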

library(survival)
# survival 3.2-7 (2020-09-28, CRAN, R 4.0.3)

mod <- coxph(Surv(time, status) ~ age + ph.ecog, data = lung,
             na.action = na.exclude)
# lung$ph.ecog[14] is NA
new_x <- lung[1:15, c("ph.ecog", "age")]

length(predict(mod, new_x, na.action = na.exclude))
#> [1] 15
surv_estimates <- survfit(mod, newdata = new_x,
                          na.action = na.exclude)
dim(surv_estimates$surv)
#> [1] 185  14

Created on 2021-03-09 by the reprex package (v1.0.0.9000)

The survival maintainer will not fix this

For operations that return a value per subject, na.action plays a role. But that is not what a survival curve is.

What is the compelling use case that would justify months of work? You would need to start by finding all the downstream routines that would be affected, map out what they should do, then decide whether they now need an na.omit option, document it, and test it. The use case would need to be very strong.
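Without a fix upstream, the padding would have to happen on the caller's side. Here is a minimal sketch using a hypothetical helper. It assumes the `surv` matrix has one column per complete row of the new data, in order, as in the 185 x 14 result above. The numbers are illustrative.

```r
# Hypothetical helper: `surv` has one column per complete row of new_data;
# expand it to one column per row, filling NA where predictors were missing.
pad_surv_columns <- function(surv, complete) {
  stopifnot(ncol(surv) == sum(complete))
  out <- matrix(NA_real_, nrow = nrow(surv), ncol = length(complete))
  out[, complete] <- surv
  out
}

# Toy example: the second row of new_x has a missing ph.ecog
new_x <- data.frame(ph.ecog = c(1, NA, 0), age = c(74, 68, 56))
complete <- stats::complete.cases(new_x)
surv <- matrix(c(0.9, 0.8, 0.95, 0.85), nrow = 2)  # illustrative values
padded <- pad_surv_columns(surv, complete)
```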

boost_tree() models

  • mboost - survival

  • mboost - linear_pred

  • mboost - parameters

library(mboost)
#> Loading required package: parallel
#> Loading required package: stabs
#> This is mboost 2.9-3. See 'package?mboost' and 'news(package  = "mboost")'
#> for a complete list of changes.
library(survival)
glmb_mod <-
  blackboost(Surv(time, status) ~ age + ph.ecog, 
             data = lung,
             family = CoxPH())

## Survival probabilities
str(survFit(glmb_mod, newdata = lung)$surv)
#>  num [1:138, 1:228] 0.996 0.984 0.98 0.972 0.968 ...
#>  - attr(*, "dimnames")=List of 2
#>   ..$ : NULL
#>   ..$ : chr [1:228] "1" "2" "3" "4" ...

## linear predictor
str(predict(glmb_mod, lung))
#>  num [1:228, 1] 0.00999 0.1597 -0.72502 -0.06093 -0.73367 ...
#>  - attr(*, "dimnames")=List of 2
#>   ..$ : chr [1:228] "4" "3" "3" "4" ...
#>   ..$ : NULL

Created on 2020-08-06 by the reprex package (v0.3.0)

move some packages to suggests

If we don't call a package's functions directly in this package's code (e.g. via pkg::func()), we should move that dependency to Suggests and use rlang::check_installed() when it is needed (which I think happens automatically in parsnip).
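Here is a sketch of the pattern for a suggested dependency. It assumes mboost and survival are listed under Suggests. The wrapper name `fit_blackboost_cox()` is hypothetical and not censored's engine code. The check runs only when the function is called. It errors with an informative message if a package is missing, or offers to install it in interactive sessions.

```r
# Hypothetical engine wrapper: mboost and survival are assumed to be in
# Suggests, so their availability is checked at call time, not load time.
fit_blackboost_cox <- function(formula, data) {
  rlang::check_installed(c("mboost", "survival"))
  mboost::blackboost(formula, data = data, family = mboost::CoxPH())
}

# Usage (requires mboost and survival):
# fit_blackboost_cox(survival::Surv(time, status) ~ age + ph.ecog,
#                    data = survival::lung)
```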
