tidyaml's Introduction

tidyAML

CRAN_Status_Badge Lifecycle: experimental PRs Welcome

Introduction

Welcome to {tidyAML} which is a new R package that makes it easy to use the tidymodels ecosystem to perform automated machine learning (AutoML). This package provides a simple and intuitive interface that allows users to quickly generate machine learning models without worrying about the underlying details. It also includes a safety mechanism that ensures that the package will fail gracefully if any required extension packages are not installed on the user’s machine. With {tidyAML}, users can easily build high-quality machine learning models in just a few lines of code. Whether you are a beginner or an experienced machine learning practitioner, {tidyAML} has something to offer.

One idea is that we should be able to generate regression models on the fly without having to go through the process of building the specification, especially for a non-tuning model, meaning we are not planning on tuning hyperparameters like penalty and cost.

The idea is not to rewrite the excellent work the tidymodels team has done (because it’s not possible) but rather to provide an enhanced, easy-to-use set of functions that do what they say and can generate many models and predictions at once.

This is similar to the great h2o package, but {tidyAML} does not require Java to be set up properly like h2o does, because {tidyAML} is built on tidymodels.

Thanks

Thank you Garrick Aden-Buie for the easy name change suggestion.

Installation

You can install {tidyAML} like so:

#install.packages("tidyAML")

Or the development version from GitHub

# install.packages("devtools")
#devtools::install_github("spsanderson/tidyAML")

Examples

Part of the reason to use {tidyAML} is so that you can generate many models of your data set. One way of modeling a data set is using regression for some numeric output. There is a convenient function in tidyAML that will generate a set of non-tuning models for fast regression. Let’s take a look below.

First let’s load the library

library(tidyAML)
#> Loading required package: parsnip
#> 
#> == Welcome to tidyAML ===========================================================================
#> If you find this package useful, please leave a star: 
#>    https://github.com/spsanderson/tidyAML'
#> 
#> If you encounter a bug or want to request an enhancement please file an issue at:
#>    https://github.com/spsanderson/tidyAML/issues
#> 
#> It is suggested that you run tidymodels::tidymodel_prefer() to set the defaults for your session.
#> 
#> Thank you for using tidyAML!

Now let’s see the function in action.

fast_regression_parsnip_spec_tbl(.parsnip_fns = "linear_reg")
#> # A tibble: 11 × 5
#>    .model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec
#>        <int> <chr>           <chr>         <chr>        <list>    
#>  1         1 lm              regression    linear_reg   <spec[+]> 
#>  2         2 brulee          regression    linear_reg   <spec[+]> 
#>  3         3 gee             regression    linear_reg   <spec[+]> 
#>  4         4 glm             regression    linear_reg   <spec[+]> 
#>  5         5 glmer           regression    linear_reg   <spec[+]> 
#>  6         6 glmnet          regression    linear_reg   <spec[+]> 
#>  7         7 gls             regression    linear_reg   <spec[+]> 
#>  8         8 lme             regression    linear_reg   <spec[+]> 
#>  9         9 lmer            regression    linear_reg   <spec[+]> 
#> 10        10 stan            regression    linear_reg   <spec[+]> 
#> 11        11 stan_glmer      regression    linear_reg   <spec[+]>
fast_regression_parsnip_spec_tbl(.parsnip_eng = c("lm","glm"))
#> # A tibble: 3 × 5
#>   .model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec
#>       <int> <chr>           <chr>         <chr>        <list>    
#> 1         1 lm              regression    linear_reg   <spec[+]> 
#> 2         2 glm             regression    linear_reg   <spec[+]> 
#> 3         3 glm             regression    poisson_reg  <spec[+]>
fast_regression_parsnip_spec_tbl(.parsnip_eng = c("lm","glm","gee"), 
                                 .parsnip_fns = "linear_reg")
#> # A tibble: 3 × 5
#>   .model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec
#>       <int> <chr>           <chr>         <chr>        <list>    
#> 1         1 lm              regression    linear_reg   <spec[+]> 
#> 2         2 gee             regression    linear_reg   <spec[+]> 
#> 3         3 glm             regression    linear_reg   <spec[+]>

As shown, we can easily select the models we want either by choosing the supported parsnip function, like linear_reg(), or by choosing the desired engine. You can also use both in conjunction with each other!

This function also adds classes to the output. Let’s see them.
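The filtering behaviour can be pictured with a small base R sketch. The `base_tbl` data frame and the `filter_specs()` helper are hypothetical stand-ins, not the package's internals; they mirror only a few engine/function pairs to show why filtering on engine alone can return an extra row (glm is registered for both linear_reg and poisson_reg).

```r
# Hypothetical miniature of the engine/function lookup table.
base_tbl <- data.frame(
  .parsnip_engine = c("lm", "gee", "glm", "glm", "gee"),
  .parsnip_fns    = c("linear_reg", "linear_reg", "linear_reg",
                      "poisson_reg", "poisson_reg"),
  stringsAsFactors = FALSE
)

# Hypothetical helper: "all" means "do not filter on this column".
filter_specs <- function(tbl, .parsnip_fns = "all", .parsnip_eng = "all") {
  if (!"all" %in% .parsnip_eng) {
    tbl <- tbl[tbl$.parsnip_engine %in% .parsnip_eng, , drop = FALSE]
  }
  if (!"all" %in% .parsnip_fns) {
    tbl <- tbl[tbl$.parsnip_fns %in% .parsnip_fns, , drop = FALSE]
  }
  tbl
}

# Engine-only filter keeps every function registered for those engines.
by_engine <- filter_specs(base_tbl, .parsnip_eng = c("lm", "glm"))

# Adding the function filter narrows the result to linear_reg rows only.
by_both <- filter_specs(base_tbl,
                        .parsnip_eng = c("lm", "glm"),
                        .parsnip_fns = "linear_reg")
```

Under these assumptions `by_engine` keeps the glm/poisson_reg pair while `by_both` drops it, which matches the pattern in the outputs above.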

class(fast_regression_parsnip_spec_tbl())
#> [1] "tidyaml_mod_spec_tbl" "fst_reg_spec_tbl"     "tidyaml_base_tbl"    
#> [4] "tbl_df"               "tbl"                  "data.frame"

We see that the function adds three classes: fst_reg_spec_tbl, because this creates a set of non-tuning regression models; tidyaml_mod_spec_tbl, because this is a model specification tibble built with {tidyAML}; and tidyaml_base_tbl, which marks the underlying base tibble.

Now, what if you want to create a non-tuning model spec without using the fast_regression_parsnip_spec_tbl() function? Well, you can. The function is called create_model_spec().
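As a rough illustration of how such class tags can be attached and checked, here is a base R sketch. The `spec_tbl` object is a hypothetical stand-in (a plain data frame rather than a real spec tibble); only the class names come from the output above.

```r
# Hypothetical stand-in for a spec table; a plain data frame is enough here.
spec_tbl <- data.frame(.parsnip_engine = "lm", .parsnip_fns = "linear_reg")

# Prepend the custom classes while keeping the existing ones.
class(spec_tbl) <- c("tidyaml_mod_spec_tbl", "fst_reg_spec_tbl", class(spec_tbl))

# Downstream code can check the class before doing any work.
is_spec_tbl <- inherits(spec_tbl, "tidyaml_mod_spec_tbl")

# The object still behaves as a data frame.
still_df <- inherits(spec_tbl, "data.frame")
```

Prepending rather than replacing keeps the usual data frame methods available while letting functions reject inputs that did not come from the expected constructor.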

create_model_spec(
 .parsnip_eng = list("lm","glm","glmnet","cubist"),
 .parsnip_fns = list(
      "linear_reg",
      "linear_reg",
      "linear_reg",
      "cubist_rules"
     )
 )
#> # A tibble: 4 × 4
#>   .parsnip_engine .parsnip_mode .parsnip_fns .model_spec
#>   <chr>           <chr>         <chr>        <list>     
#> 1 lm              regression    linear_reg   <spec[+]>  
#> 2 glm             regression    linear_reg   <spec[+]>  
#> 3 glmnet          regression    linear_reg   <spec[+]>  
#> 4 cubist          regression    cubist_rules <spec[+]>

create_model_spec(
 .parsnip_eng = list("lm","glm","glmnet","cubist"),
 .parsnip_fns = list(
      "linear_reg",
      "linear_reg",
      "linear_reg",
      "cubist_rules"
     ),
 .return_tibble = FALSE
 )
#> $.parsnip_engine
#> $.parsnip_engine[[1]]
#> [1] "lm"
#> 
#> $.parsnip_engine[[2]]
#> [1] "glm"
#> 
#> $.parsnip_engine[[3]]
#> [1] "glmnet"
#> 
#> $.parsnip_engine[[4]]
#> [1] "cubist"
#> 
#> 
#> $.parsnip_mode
#> $.parsnip_mode[[1]]
#> [1] "regression"
#> 
#> 
#> $.parsnip_fns
#> $.parsnip_fns[[1]]
#> [1] "linear_reg"
#> 
#> $.parsnip_fns[[2]]
#> [1] "linear_reg"
#> 
#> $.parsnip_fns[[3]]
#> [1] "linear_reg"
#> 
#> $.parsnip_fns[[4]]
#> [1] "cubist_rules"
#> 
#> 
#> $.model_spec
#> $.model_spec[[1]]
#> Linear Regression Model Specification (regression)
#> 
#> Computational engine: lm 
#> 
#> 
#> $.model_spec[[2]]
#> Linear Regression Model Specification (regression)
#> 
#> Computational engine: glm 
#> 
#> 
#> $.model_spec[[3]]
#> Linear Regression Model Specification (regression)
#> 
#> Computational engine: glmnet 
#> 
#> 
#> $.model_spec[[4]]
#> Cubist Model Specification (regression)
#> 
#> Computational engine: cubist

Now the reason we are here. Let’s take a look at the first function for modeling with {tidyAML}, fast_regression().

library(recipes)
library(dplyr)

rec_obj <- recipe(mpg ~ ., data = mtcars)
frt_tbl <- fast_regression(
  .data = mtcars, 
  .rec_obj = rec_obj, 
  .parsnip_eng = c("lm","glm","gee"),
  .parsnip_fns = "linear_reg",
  .drop_na = FALSE
)

glimpse(frt_tbl)
#> Rows: 3
#> Columns: 8
#> $ .model_id       <int> 1, 2, 3
#> $ .parsnip_engine <chr> "lm", "gee", "glm"
#> $ .parsnip_mode   <chr> "regression", "regression", "regression"
#> $ .parsnip_fns    <chr> "linear_reg", "linear_reg", "linear_reg"
#> $ model_spec      <list> [~NULL, ~NULL, NULL, regression, TRUE, NULL, lm, TRUE]…
#> $ wflw            <list> [cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb, mp…
#> $ fitted_wflw     <list> [cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb, mp…
#> $ pred_wflw       <list> [<tbl_df[64 x 3]>], <NULL>, [<tbl_df[64 x 3]>]

As we see above, one of the models has gracefully failed, thanks in part to the function purrr::safely(), which was used to make what I call safe_make functions.
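The graceful-failure idea can be sketched in base R. This is not the package's code: `safely_sketch()` imitates the result/error list that purrr::safely() returns by using tryCatch(), and `fit_stub()` is a hypothetical fit function that fails for one engine.

```r
# Hypothetical wrapper imitating the result/error shape of purrr::safely().
safely_sketch <- function(.f, otherwise = NULL) {
  function(...) {
    tryCatch(
      list(result = .f(...), error = NULL),
      error = function(e) list(result = otherwise, error = e)
    )
  }
}

# Hypothetical fit function that fails for one engine.
fit_stub <- function(engine) {
  if (engine == "gee") stop("engine package not loaded")
  paste("fitted", engine)
}

safe_fit <- safely_sketch(fit_stub)

# One failure does not stop the loop; its slot simply holds NULL.
fits <- lapply(c("lm", "gee", "glm"), function(eng) safe_fit(eng)$result)
```

The failing engine leaves a NULL in its slot instead of aborting the whole run, which is the same shape as the pred_wflw column shown above.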

Let’s look at the fitted workflow predictions.

frt_tbl$pred_wflw
#> [[1]]
#> # A tibble: 64 × 3
#>    .data_category .data_type .value
#>    <chr>          <chr>       <dbl>
#>  1 actual         actual       15.2
#>  2 actual         actual       19.2
#>  3 actual         actual       22.8
#>  4 actual         actual       33.9
#>  5 actual         actual       26  
#>  6 actual         actual       19.2
#>  7 actual         actual       15  
#>  8 actual         actual       27.3
#>  9 actual         actual       24.4
#> 10 actual         actual       17.3
#> # ℹ 54 more rows
#> 
#> [[2]]
#> NULL
#> 
#> [[3]]
#> # A tibble: 64 × 3
#>    .data_category .data_type .value
#>    <chr>          <chr>       <dbl>
#>  1 actual         actual       15.2
#>  2 actual         actual       19.2
#>  3 actual         actual       22.8
#>  4 actual         actual       33.9
#>  5 actual         actual       26  
#>  6 actual         actual       19.2
#>  7 actual         actual       15  
#>  8 actual         actual       27.3
#>  9 actual         actual       24.4
#> 10 actual         actual       17.3
#> # ℹ 54 more rows

Now let’s load the multilevelmod library so that we can run the gee linear regression.

library(multilevelmod)

rec_obj <- recipe(mpg ~ ., data = mtcars)
frt_tbl <- fast_regression(
  .data = mtcars, 
  .rec_obj = rec_obj, 
  .parsnip_eng = c("lm","glm","gee"),
  .parsnip_fns = "linear_reg"
)

extract_wflw_pred(frt_tbl, 1:3)
#> # A tibble: 192 × 4
#>    .model_type     .data_category .data_type .value
#>    <chr>           <chr>          <chr>       <dbl>
#>  1 lm - linear_reg actual         actual       32.4
#>  2 lm - linear_reg actual         actual       14.3
#>  3 lm - linear_reg actual         actual       15.8
#>  4 lm - linear_reg actual         actual       30.4
#>  5 lm - linear_reg actual         actual       24.4
#>  6 lm - linear_reg actual         actual       15  
#>  7 lm - linear_reg actual         actual       33.9
#>  8 lm - linear_reg actual         actual       22.8
#>  9 lm - linear_reg actual         actual       19.2
#> 10 lm - linear_reg actual         actual       21.4
#> # ℹ 182 more rows

Getting Regression Residuals

Getting residuals is easy with {tidyAML}. Let’s take a look.

extract_regression_residuals(frt_tbl)
#> [[1]]
#> # A tibble: 32 × 4
#>    .model_type     .actual .predicted .resid
#>    <chr>             <dbl>      <dbl>  <dbl>
#>  1 lm - linear_reg    32.4       27.5  4.94 
#>  2 lm - linear_reg    14.3       14.2  0.121
#>  3 lm - linear_reg    15.8       18.5 -2.71 
#>  4 lm - linear_reg    30.4       30.6 -0.178
#>  5 lm - linear_reg    24.4       22.6  1.82 
#>  6 lm - linear_reg    15         13.3  1.69 
#>  7 lm - linear_reg    33.9       29.3  4.64 
#>  8 lm - linear_reg    22.8       25.3 -2.53 
#>  9 lm - linear_reg    19.2       17.6  1.62 
#> 10 lm - linear_reg    21.4       21.2  0.162
#> # ℹ 22 more rows
#> 
#> [[2]]
#> # A tibble: 32 × 4
#>    .model_type      .actual .predicted  .resid
#>    <chr>              <dbl>      <dbl>   <dbl>
#>  1 gee - linear_reg    32.4       27.5  4.95  
#>  2 gee - linear_reg    14.3       14.2  0.0928
#>  3 gee - linear_reg    15.8       18.5 -2.67  
#>  4 gee - linear_reg    30.4       30.5 -0.147 
#>  5 gee - linear_reg    24.4       22.6  1.83  
#>  6 gee - linear_reg    15         13.3  1.68  
#>  7 gee - linear_reg    33.9       29.3  4.65  
#>  8 gee - linear_reg    22.8       25.3 -2.53  
#>  9 gee - linear_reg    19.2       17.6  1.60  
#> 10 gee - linear_reg    21.4       21.2  0.165 
#> # ℹ 22 more rows
#> 
#> [[3]]
#> # A tibble: 32 × 4
#>    .model_type      .actual .predicted .resid
#>    <chr>              <dbl>      <dbl>  <dbl>
#>  1 glm - linear_reg    32.4       27.5  4.94 
#>  2 glm - linear_reg    14.3       14.2  0.121
#>  3 glm - linear_reg    15.8       18.5 -2.71 
#>  4 glm - linear_reg    30.4       30.6 -0.178
#>  5 glm - linear_reg    24.4       22.6  1.82 
#>  6 glm - linear_reg    15         13.3  1.69 
#>  7 glm - linear_reg    33.9       29.3  4.64 
#>  8 glm - linear_reg    22.8       25.3 -2.53 
#>  9 glm - linear_reg    19.2       17.6  1.62 
#> 10 glm - linear_reg    21.4       21.2  0.162
#> # ℹ 22 more rows

You can also pivot them into a long format making plotting easy with ggplot2.

extract_regression_residuals(frt_tbl, .pivot_long = TRUE)
#> [[1]]
#> # A tibble: 96 × 3
#>    .model_type     name        value
#>    <chr>           <chr>       <dbl>
#>  1 lm - linear_reg .actual    32.4  
#>  2 lm - linear_reg .predicted 27.5  
#>  3 lm - linear_reg .resid      4.94 
#>  4 lm - linear_reg .actual    14.3  
#>  5 lm - linear_reg .predicted 14.2  
#>  6 lm - linear_reg .resid      0.121
#>  7 lm - linear_reg .actual    15.8  
#>  8 lm - linear_reg .predicted 18.5  
#>  9 lm - linear_reg .resid     -2.71 
#> 10 lm - linear_reg .actual    30.4  
#> # ℹ 86 more rows
#> 
#> [[2]]
#> # A tibble: 96 × 3
#>    .model_type      name         value
#>    <chr>            <chr>        <dbl>
#>  1 gee - linear_reg .actual    32.4   
#>  2 gee - linear_reg .predicted 27.5   
#>  3 gee - linear_reg .resid      4.95  
#>  4 gee - linear_reg .actual    14.3   
#>  5 gee - linear_reg .predicted 14.2   
#>  6 gee - linear_reg .resid      0.0928
#>  7 gee - linear_reg .actual    15.8   
#>  8 gee - linear_reg .predicted 18.5   
#>  9 gee - linear_reg .resid     -2.67  
#> 10 gee - linear_reg .actual    30.4   
#> # ℹ 86 more rows
#> 
#> [[3]]
#> # A tibble: 96 × 3
#>    .model_type      name        value
#>    <chr>            <chr>       <dbl>
#>  1 glm - linear_reg .actual    32.4  
#>  2 glm - linear_reg .predicted 27.5  
#>  3 glm - linear_reg .resid      4.94 
#>  4 glm - linear_reg .actual    14.3  
#>  5 glm - linear_reg .predicted 14.2  
#>  6 glm - linear_reg .resid      0.121
#>  7 glm - linear_reg .actual    15.8  
#>  8 glm - linear_reg .predicted 18.5  
#>  9 glm - linear_reg .resid     -2.71 
#> 10 glm - linear_reg .actual    30.4  
#> # ℹ 86 more rows

tidyaml's People

Contributors

olivroy, spsanderson

tidyaml's Issues

May have to drop support for `mixOmics` since I can't get it installed on my machine.

#' Internals Make Base Regression Tibble
#'
#' @family Internals
#'
#' @author Steven P. Sanderson II, MPH
#'
#' @details Creates a base tibble to create parsnip regression model specifications.
#'
#' @description Creates a base tibble to create parsnip regression model specifications.
#'
#' @examples
#' make_regression_base_tbl()
#'
#' @return
#' A tibble
#'
#' @name make_regression_base_tbl
NULL

#' @export
#' @rdname make_regression_base_tbl

make_regression_base_tbl <- function(){

  # Make the regression tribble
  mod_tbl <- tibble::tribble(
    ~.parsnip_engine, ~.parsnip_mode, ~.parsnip_fns,
    "lm", "regression", "linear_reg",
    "brulee", "regression", "linear_reg",
    "gee", "regression", "linear_reg",
    "glm","regression","linear_reg",
    "glmer","regression","linear_reg",
    "glmnet","regression","linear_reg",
    "gls","regression","linear_reg",
    "lme","regression","linear_reg",
    "lmer","regression","linear_reg",
    "stan","regression","linear_reg",
    "stan_glmer","regression","linear_reg",
    "Cubist","regression","cubist_rules",
    "glm","regression","poisson_reg",
    "gee","regression","poisson_reg",
    "glmer","regression","poisson_reg",
    "glmnet","regression","poisson_reg",
    "hurdle","regression","poisson_reg",
    "stan","regression","poisson_reg",
    "stan_glmer","regression","poisson_reg",
    "zeroinfl","regression","poisson_reg",
    "earth","regression","bag_mars",
    "rpart","regression","bag_tree",
    "dbarts","regression","bart",
    "xgboost","regression","boost_tree",
    "lightgbm","regression","boost_tree",
    "rpart","regression","decision_tree",
    "partykit","regression","decision_tree",
    "mgcv","regression","gen_additive_mod",
    "earth","regression","mars",
    "nnet","regression","mlp",
    "brulee","regression","mlp",
    "kknn","regression","nearest_neighbor",
    "ranger","regression","rand_forest",
    "randomForest","regression","rand_forest",
    "xrf","regression","rule_fit",
    "LiblineaR","regression","svm_linear",
    "kernlab","regression","svm_linear",
    "kernlab","regression","svm_poly",
    "kernlab","regression","svm_rbf"
  )

  # Return
  class(mod_tbl) <- c("tidyaml_base_tbl", class(mod_tbl))
  return(mod_tbl)
}

#' Internals Make Base Classification Tibble
#'
#' @family Internals
#'
#' @author Steven P. Sanderson II, MPH
#'
#' @details Creates a base tibble to create parsnip classification model specifications.
#'
#' @description Creates a base tibble to create parsnip classification model specifications.
#'
#' @examples
#' make_classification_base_tbl()
#'
#' @return
#' A tibble
#'
#' @name make_classification_base_tbl
NULL

#' @export
#' @rdname make_classification_base_tbl

make_classification_base_tbl <- function(){

  # Make the classification tribble
  mod_tbl <- tibble::tribble(
    ~.parsnip_engine, ~.parsnip_mode, ~.parsnip_fns,
    "earth","classification","bag_mars",
    "earth","classification","discrim_flexible",
    "dbarts","classification","bart",
    "MASS","classification","discrim_linear",
    "mda","classification","discrim_linear",
    "sda","classification","discrim_linear",
    "sparsediscrim","classification","discrim_linear",
    "MASS","classification","discrim_quad",
    "sparsediscrim","classification","discrim_quad",
    "klaR","classification","discrim_regularized",
    "mgcv","classification","gen_additive_mod",
    "brulee","classification","logistic_reg",
    "gee","classification","logistic_reg",
    "glm","classification","logistic_reg",
    "glmer","classification","logistic_reg",
    "glmnet","classification","logistic_reg",
    "LiblineaR","classification","logistic_reg",
    "earth","classification","mars",
    "brulee","classification","mlp",
    "nnet","classification","mlp",
    "brulee","classification","multinom_reg",
    "glmnet","classification","multinom_reg",
    "nnet","classification","multinom_reg",
    "klaR","classification","naive_Bayes",
    "kknn","classification","nearest_neighbor",
    "xrf","classification","rule_fit",
    "kernlab","classification","svm_linear",
    "LiblineaR","classification","svm_linear",
    "kernlab","classification","svm_poly",
    "kernlab","classification","svm_rbf",
    "liquidSVM","classification","svm_rbf"
  )

  # Return
  class(mod_tbl) <- c("tidyaml_base_tbl", class(mod_tbl))
  return(mod_tbl)
}

Update `fast_regression()` to use new internal functions

Function:

#' Generate Model Specification calls to `parsnip`
#'
#' @family Model_Generator
#'
#' @author Steven P. Sanderson II, MPH
#'
#' @details With this function you can generate a tibble output of any regression
#' model specification and its fitted `workflow` object.
#'
#' @description Creates a list/tibble of parsnip model specifications.
#'
#' @param .data The data being passed to the function for the regression problem
#' @param .rec_obj The recipe object being passed.
#' @param .parsnip_fns The default is 'all' which will create all possible
#' regression model specifications supported.
#' @param .parsnip_eng The default is 'all' which will create all possible
#' regression model specifications supported.
#' @param .split_type The default is 'initial_split', you can pass any type of
#' split supported by `rsample`
#' @param .split_args The default is NULL, when NULL then the default parameters
#' of the split type will be executed for the rsample split type.
#'
#' @examples
#' library(recipes, quietly = TRUE)
#' library(dplyr, quietly = TRUE)
#'
#' rec_obj <- recipe(mpg ~ ., data = mtcars)
#' frt_tbl <- fast_regression(mtcars, rec_obj, .parsnip_eng = c("lm","glm"),
#' .parsnip_fns = "linear_reg")
#' glimpse(frt_tbl)
#'
#' @return
#' A list or a tibble.
#'
#' @name fast_regression
NULL

#' @export
#' @rdname fast_regression
#'

fast_regression <- function(.data, .rec_obj, .parsnip_fns = "all",
                            .parsnip_eng = "all", .split_type = "initial_split",
                            .split_args = NULL){
  
  # Tidy Eval ----
  call <- list(.parsnip_fns) %>%
    purrr::flatten_chr()
  engine <- list(.parsnip_eng) %>%
    purrr::flatten_chr()
  
  rec_obj <- .rec_obj
  split_type <- .split_type
  split_args <- .split_args
  
  # Checks ----
  
  # Get data splits
  df <- dplyr::as_tibble(.data)
  splits_obj <- create_splits(
    .data = df,
    .split_type = split_type,
    .split_args = split_args
  )
  
  # Generate Model Spec Tbl
  mod_spec_tbl <- fast_regression_parsnip_spec_tbl(
    .parsnip_fns = call,
    .parsnip_eng = engine
  )
  
  # Generate Workflow object
  mod_tbl <- mod_spec_tbl %>%
    dplyr::mutate(
      wflw = internal_make_wflw(mod_spec_tbl, .rec_obj = rec_obj)
    )
  
  mod_fitted_tbl <- mod_tbl %>%
    dplyr::mutate(
      fitted_wflw = internal_make_fitted_wflw(mod_tbl, splits_obj)
    )

  mod_pred_tbl <- mod_fitted_tbl %>%
    dplyr::mutate(
      pred_wflw = internal_make_wflw_predictions(mod_fitted_tbl, splits_obj)
    )
    
  
  # Return ----
  # Attach the class and attributes to the object that is actually returned
  class(mod_pred_tbl) <- c("fst_reg_tbl", class(mod_pred_tbl))
  attr(mod_pred_tbl, ".parsnip_engines") <- .parsnip_eng
  attr(mod_pred_tbl, ".parsnip_functions") <- .parsnip_fns
  attr(mod_pred_tbl, ".split_type") <- .split_type
  attr(mod_pred_tbl, ".split_args") <- .split_args
  
  return(mod_pred_tbl)
}

Example:

> library(recipes, quietly = TRUE)
> library(dplyr, quietly = TRUE)
> 
> rec_obj <- recipe(mpg ~ ., data = mtcars)
> frt_tbl <- fast_regression(mtcars, rec_obj, .parsnip_eng = c("lm","glm"),
+ .parsnip_fns = "linear_reg")
> glimpse(frt_tbl)
Rows: 2
Columns: 8
$ .model_id       <int> 1, 2
$ .parsnip_engine <chr> "lm", "glm"
$ .parsnip_mode   <chr> "regression", "regression"
$ .parsnip_fns    <chr> "linear_reg", "linear_reg"
$ model_spec      <list> [~NULL, ~NULL, NULL, regression, TRUE, NULL, lm, TRUE], [~NULL, …
$ wflw            <list> [cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb, mpg, double,…
$ fitted_wflw     <list> [cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb, mpg, double,…
$ pred_wflw       <list> [<tbl_df[24 x 1]>], [<tbl_df[24 x 1]>]

Function `argument_matcher()`

Match provided arguments to the function arguments dynamically.

Function:

argument_matcher <- function(.f = "linear_reg", .args = list()){
  
  # TidyEval ----
  fns <- as.character(.f)
  
  fns_args <- formalArgs(fns)
  fns_args_list <- as.list(fns_args)
  names(fns_args_list) <- fns_args
  
  arg_list <- .args
  arg_list_names <- unique(names(arg_list))
  
  l <- list(arg_list, fns_args_list)
  
  arg_idx <- which(arg_list_names %in% fns_args_list)
  bad_arg_idx <- which(!arg_list_names %in% fns_args_list)
  
  bad_args <- arg_list[bad_arg_idx]
  bad_arg_names <- unique(names(bad_args))
  
  final_args <- arg_list[arg_idx]
  
  # Return ----
  if (length(bad_arg_names) > 0){
    rlang::inform(
      message = paste0("bad arguments passed: ", bad_arg_names),
      use_cli_format = TRUE
    )
  }

  return(final_args)
}

Examples:

argument_matcher(.args = list(mode = "regression", engine = "lm"))
args_to_parsnip <- argument_matcher(.args = list(mode = "regression", engine = "lm"))
> do.call("linear_reg", args_to_parsnip)
Linear Regression Model Specification (regression)

Computational engine: lm 

> argument_matcher(.args = list(mode = "regression", engine = "lm", trees = 1, mtry = 1))
$mode
[1] "regression"

$engine
[1] "lm"

bad arguments passed: trees
bad arguments passed: mtry 
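The matching idea can be tried without parsnip installed. The sketch below is an assumption-laden illustration, not the function from this issue: `toy_model()` and `match_args_sketch()` are hypothetical names, and the bad argument names are collapsed into a single message, whereas the output above prints one line per name.

```r
# Hypothetical target function, used only to have some formals to match against.
toy_model <- function(mode = "regression", engine = "lm", penalty = NULL) {
  list(mode = mode, engine = engine, penalty = penalty)
}

# Sketch: split supplied arguments into those the function accepts and the rest.
match_args_sketch <- function(.f, .args = list()) {
  accepted <- names(formals(.f))
  supplied <- names(.args)
  list(
    good = .args[supplied %in% accepted],
    bad  = supplied[!supplied %in% accepted]
  )
}

matched <- match_args_sketch(
  toy_model,
  list(mode = "regression", engine = "lm", trees = 1, mtry = 1)
)

# One message listing every rejected name.
if (length(matched$bad) > 0) {
  message("bad arguments passed: ", paste(matched$bad, collapse = ", "))
}

# Only the accepted arguments are forwarded to the target function.
spec <- do.call(toy_model, matched$good)
```

Returning the rejected names alongside the accepted ones lets the caller decide whether to inform, warn, or abort.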

Update `fast_regression_parsnip_spec_tbl()` function

Update the arguments and params, let's keep them standardized.

Function:

fast_regression_parsnip_spec_tbl <- function(.parsnip_fns = "all",
                                             .parsnip_eng = "all") {
  
  # Thank you https://stackoverflow.com/questions/74691333/build-a-tibble-of-parsnip-model-calls-with-match-fun/74691529#74691529
  # Tidyeval ----
  call <- list(.parsnip_fns) %>%
    purrr::flatten_chr()
  engine <- list(.parsnip_eng) %>%
    purrr::flatten_chr() 
  
  # Make tibble
  mod_tbl <- dplyr::tibble(
    .parsnip_engine = c(
      "lm",
      "brulee",
      "gee",
      "glm",
      "glmer",
      "glmnet",
      "gls",
      "h2o",
      "keras",
      "lme",
      "lmer",
      "spark",
      "stan",
      "stan_glmer",
      "Cubist",
      "glm",
      "gee",
      "glmer",
      "glmnet",
      "h2o",
      "hurdle",
      "stan",
      "stan_glmer",
      "zeroinfl",
      "survival",
      "flexsurv",
      "flexsurvspline"
    ),
    .parsnip_mode = c(
      rep("regression", 24),
      rep("censored regression", 3)
    ),
    .parsnip_fns = c(
      rep("linear_reg", 14),
      "cubist_rules",
      rep("poisson_reg",9),
      rep("survival_reg", 3)
    )
  )
  
  # Filter ----
  if (!"all" %in% engine){
    mod_tbl <- mod_tbl %>%
      dplyr::filter(.parsnip_engine %in% engine)
  }
  
  if (!"all" %in% call){
    mod_tbl <- mod_tbl %>%
      dplyr::filter(.parsnip_fns %in% call)
  }
  
  mod_filtered_tbl <- mod_tbl
  
  mod_spec_tbl <- mod_filtered_tbl %>%
    dplyr::mutate(
      .model_spec = purrr::pmap(
        dplyr::cur_data(),
        ~ match.fun(..3)(mode = ..2, engine = ..1)
        #~ get(..3)(mode = ..2, engine = ..1)
      )
    )
  
  # Return ----
  class(mod_spec_tbl) <- c("fst_reg_spec_tbl", class(mod_spec_tbl))
  attr(mod_spec_tbl, ".parsnip_engines") <- .parsnip_eng
  attr(mod_spec_tbl, ".parsnip_functions") <- .parsnip_fns
  
  return(mod_spec_tbl)
  
}

Example:

> fast_regression_parsnip_spec_tbl(.parsnip_eng = c("lm","glm"))
# A tibble: 3 × 4
  .parsnip_engine .parsnip_mode .parsnip_fns .model_spec
  <chr>           <chr>         <chr>        <list>     
1 lm              regression    linear_reg   <spec[+]>  
2 glm             regression    linear_reg   <spec[+]>  
3 glm             regression    poisson_reg  <spec[+]>  

fitted_wflw

Function:

# Safely make fitted workflow
internal_make_fitted_wflw <- function(.model_tbl, .splits_obj){
  
  # Tidyeval ----
  model_tbl <- .model_tbl
  splits_obj <- .splits_obj
  col_nms <- colnames(model_tbl)
  
  # Checks ----
  if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")){
    rlang::abort(
      message = "'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl'",
      use_cli_format = TRUE
    )
  }
  
  if (!"wflw" %in% col_nms){
    rlang::abort(
      message = "Missing the column 'wflw'",
      use_cli_format = TRUE
    )
  }
  
  if (!".model_id" %in% col_nms){
    rlang::abort(
      message = "Missing the column '.model_id'",
      use_cli_format = TRUE
    )
  }
  
  # Manipulation
  # Make a group split object list
  models_list <- model_tbl %>%
    dplyr::group_split(.model_id)
  
  # Make the fitted workflow object using purrr imap
  fitted_wflw_list <- models_list %>%
    purrr::imap(
      .f = function(obj, id){
        
        # Pull the workflow column and then pluck it
        wflw <- obj %>% dplyr::pull("wflw") %>% purrr::pluck(1)
        
        # Create a safe parsnip::fit function
        safe_parsnip_fit <- purrr::safely(
          parsnip::fit,
          otherwise = "Error - Could not fit the workflow.",
          quiet = FALSE
        )
        
        # Return the fitted workflow
        ret <- safe_parsnip_fit(
          wflw, data = rsample::training(splits_obj$splits)
        )
        
        res <- ret %>% purrr::pluck("result")
        
        return(res)
      }
    )
    
  return(fitted_wflw_list)
  
}

Example:

> internal_make_fitted_wflw(mod_tbl, splits_obj)
[[1]]
══ Workflow [trained] ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ──────────────────────────────────────────────────────────────────────────
0 Recipe Steps

── Model ─────────────────────────────────────────────────────────────────────────────────

Call:
stats::lm(formula = ..y ~ ., data = data)

Coefficients:
(Intercept)          cyl         disp           hp         drat           wt  
  28.907650    -0.664742    -0.009334    -0.014871     0.197659    -0.188327  
       qsec           vs           am         gear         carb  
  -0.190551     0.132323     1.732139     1.372764    -1.184251  


[[2]]
══ Workflow [trained] ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ──────────────────────────────────────────────────────────────────────────
0 Recipe Steps

── Model ─────────────────────────────────────────────────────────────────────────────────

Call:  stats::glm(formula = ..y ~ ., family = stats::gaussian, data = data)

Coefficients:
(Intercept)          cyl         disp           hp         drat           wt  
  28.907650    -0.664742    -0.009334    -0.014871     0.197659    -0.188327  
       qsec           vs           am         gear         carb  
  -0.190551     0.132323     1.732139     1.372764    -1.184251  

Degrees of Freedom: 23 Total (i.e. Null);  13 Residual
Null Deviance:	    736.5 
Residual Deviance: 100.5 	AIC: 126.5

`internal_make_spec_tbl()`

Current:

if (!".parsnip_engine" %in% nms | !".parsnip_mode" %in% nms | !".parsnip_fns" %in% nms){
    rlang::abort(
      message = "The model tibble must come from the class/reg to parsnip function.",
      use_cli_format = TRUE
    )
  }

New:

if (!inherits(df, "tidyaml_base_tbl")){
    rlang::abort(
      message = "The model tibble must come from the class/reg to parsnip function.",
      use_cli_format = TRUE
    )
  }

pred_wflw

Function:

# Safely make predictions on fitted workflow
internal_make_wflw_predictions <- function(.model_tbl, .splits_obj){
  
  # Tidyeval ----
  model_tbl <- .model_tbl
  splits_obj <- .splits_obj
  col_nms <- colnames(model_tbl)
  
  # Checks ----
  if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")){
    rlang::abort(
      message = "'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl'",
      use_cli_format = TRUE
    )
  }
  
  if (!"fitted_wflw" %in% col_nms){
    rlang::abort(
      message = "Missing the column 'fitted_wflw'",
      use_cli_format = TRUE
    )
  }
  
  if (!".model_id" %in% col_nms){
    rlang::abort(
      message = "Missing the column '.model_id'",
      use_cli_format = TRUE
    )
  }
  
  # Manipulation
  # Make a group split object list
  model_factor_tbl <- model_tbl %>%
    dplyr::mutate(.model_id = forcats::as_factor(.model_id))
  
  models_list <- model_factor_tbl %>%
    dplyr::group_split(.model_id)
  
  # Make the predictions on the fitted workflow object using purrr imap
  wflw_preds_list <- models_list %>%
    purrr::imap(
      .f = function(obj, id){
        
        # Pull the fitted workflow column and then pluck it
        fitted_wflw <- obj %>% dplyr::pull("fitted_wflw") %>% purrr::pluck(1)
        
        # Create a safe stats::predict
        safe_stats_predict <- purrr::safely(
          stats::predict,
          otherwise = "Error - Could not make predictions",
          quiet = FALSE
        )
        
        # Return the predictions
        ret <- safe_stats_predict(
          fitted_wflw, 
          new_data = rsample::training(splits_obj$splits)
        )
        
        res <- ret %>% purrr::pluck("result")
        
        return(res)
      }
    )
  
  return(wflw_preds_list)
}

Example:

> internal_make_wflw_predictions(mod_fitted_tbl, splits_obj)
Error: no applicable method for 'predict' applied to an object of class "character"
[[1]]
[[1]]$result
# A tibble: 24 × 1
   .pred
   <dbl>
 1  23.2
 2  18.9
 3  15.4
 4  17.7
 5  15.6
 6  16.8
 7  15.5
 8  19.7
 9  11.7
10  22.6
# … with 14 more rows
# ℹ Use `print(n = ...)` to see more rows

[[1]]$error
NULL


[[2]]
[[2]]$result
[1] "Error - Could not make predictions"

[[2]]$error
<simpleError in UseMethod("predict"): no applicable method for 'predict' applied to an object of class "character">


[[3]]
[[3]]$result
# A tibble: 24 × 1
   .pred
   <dbl>
 1  23.2
 2  18.9
 3  15.4
 4  17.7
 5  15.6
 6  16.8
 7  15.5
 8  19.7
 9  11.7
10  22.6
# … with 14 more rows
# ℹ Use `print(n = ...)` to see more rows

[[3]]$error
NULL
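
The `$result` / `$error` pairs in the output above are the shape that a `purrr::safely()` wrapper returns before the result is plucked. A minimal base R analogue is sketched below; `safe_call()` is a hypothetical helper written for illustration, and `log()` merely stands in for `stats::predict()`.

```r
# Hypothetical base R analogue of purrr::safely(): always return a list with
# `result` and `error`, so one failing model does not stop the others
safe_call <- function(f, otherwise = NULL) {
  function(...) {
    tryCatch(
      list(result = f(...), error = NULL),
      error = function(e) list(result = otherwise, error = e)
    )
  }
}

safe_log <- safe_call(log, otherwise = "Error - Could not make predictions")

good <- safe_log(1)    # result holds the value, error is NULL
bad  <- safe_log("a")  # result holds the fallback string, error holds the condition
```

Plucking `"result"` afterwards, as the function does, keeps either the predictions or the fallback string for each model.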

Make base `regression` and `classification` tibble functions

These will act as the base table makers for fast_regression/classification etc.

This means we can update fast_classification_parsnip_spec_tbl() and fast_regression_parsnip_spec_tbl() to use these underlying functions for the mod_spec_tbl variable.

Function:

make_regression_base_tbl <- function(){
  
  # Make the regression tribble
  mod_tbl <- tibble::tribble(
    ~.parsnip_engine, ~.parsnip_mode, ~.parsnip_fns,
    "lm", "regression", "linear_reg",
    "brulee", "regression", "linear_reg",
    "gee", "regression", "linear_reg",
    "glm","regression","linear_reg",
    "glmer","regression","linear_reg",
    "glmnet","regression","linear_reg",
    "gls","regression","linear_reg",
    "lme","regression","linear_reg",
    "lmer","regression","linear_reg",
    "stan","regression","linear_reg",
    "stan_glmer","regression","linear_reg",
    "Cubist","regression","cubist_rules",
    "glm","regression","poisson_reg",
    "gee","regression","poisson_reg",
    "glmer","regression","poisson_reg",
    "glmnet","regression","poisson_reg",
    "hurdle","regression","poisson_reg",
    "stan","regression","poisson_reg",
    "stan_glmer","regression","poisson_reg",
    "zeroinfl","regression","poisson_reg",
    "survival","censored regression","survival_reg",
    "flexsurv","censored regression","survival_reg",
    "flexsurvspline","censored regression","survival_reg",
    "earth","regression","bag_mars",
    "rpart","regression","bag_tree",
    "dbarts","regression","bart",
    "xgboost","regression","boost_tree",
    "lightgbm","regression","boost_tree",
    "mboost","censored regression","boost_tree",
    "rpart","regression","decision_tree",
    "partykit","regression","decision_tree",
    "mgcv","regression","gen_additive_mod",
    "earth","regression","mars",
    "nnet","regression","mlp",
    "brulee","regression","mlp",
    "kknn","regression","nearest_neighbor",
    "mixOmics","regression","pls",
    "ranger","regression","rand_forest",
    "randomForest","regression","rand_forest",
    "partykit","censored regression","rand_forest",
    "aorsf","censored regression","rand_forest",
    "xrf","regression","rule_fit",
    "LiblineaR","regression","svm_linear",
    "kernlab","regression","svm_linear",
    "kernlab","regression","svm_poly",
    "kernlab","regression","svm_rbf"
  )
  
  # Return
  return(mod_tbl)
}

make_classification_base_tbl <- function(){
  
  # Make the classification tribble
  mod_tbl <- tibble::tribble(
    ~.parsnip_engine, ~.parsnip_mode, ~.parsnip_fns,
    "earth","classification","bag_mars",
    "earth","classification","discrim_flexible",
    "dbarts","classification","bart",
    "MASS","classification","discrim_linear",
    "mda","classification","discrim_linear",
    "sda","classification","discrim_linear",
    "sparsediscrim","classification","discrim_linear",
    "MASS","classification","discrim_quad",
    "sparsediscrim","classification","discrim_quad",
    "klaR","classification","discrim_regularized",
    "mgcv","classification","gen_additive_mod",
    "brulee","classification","logistic_reg",
    "gee","classification","logistic_reg",
    "glm","classification","logistic_reg",
    "glmer","classification","logistic_reg",
    "glmnet","classification","logistic_reg",
    "LiblineaR","classification","logistic_reg",
    "earth","classification","mars",
    "brulee","classification","mlp",
    "nnet","classification","mlp",
    "brulee","classification","multinom_reg",
    "glmnet","classification","multinom_reg",
    "nnet","classification","multinom_reg",
    "klaR","classification","naive_Bayes",
    "kknn","classification","nearest_neighbor",
    "mixOmics","classification","pls",
    "xrf","classification","rule_fit",
    "kernlab","classification","svm_linear",
    "LiblineaR","classification","svm_linear",
    "kernlab","classification","svm_poly",
    "kernlab","classification","svm_rbf",
    "liquidSVM","classification","svm_rbf"
  )
  
  # Return
  return(mod_tbl)
}

Examples:

> make_regression_base_tbl()
# A tibble: 46 × 3
   .parsnip_engine .parsnip_mode .parsnip_fns
   <chr>           <chr>         <chr>       
 1 lm              regression    linear_reg  
 2 brulee          regression    linear_reg  
 3 gee             regression    linear_reg  
 4 glm             regression    linear_reg  
 5 glmer           regression    linear_reg  
 6 glmnet          regression    linear_reg  
 7 gls             regression    linear_reg  
 8 lme             regression    linear_reg  
 9 lmer            regression    linear_reg  
10 stan            regression    linear_reg  
# … with 36 more rows
# ℹ Use `print(n = ...)` to see more rows

> make_classification_base_tbl()
# A tibble: 32 × 3
   .parsnip_engine .parsnip_mode  .parsnip_fns       
   <chr>           <chr>          <chr>              
 1 earth           classification bag_mars           
 2 earth           classification discrim_flexible   
 3 dbarts          classification bart               
 4 MASS            classification discrim_linear     
 5 mda             classification discrim_linear     
 6 sda             classification discrim_linear     
 7 sparsediscrim   classification discrim_linear     
 8 MASS            classification discrim_quad       
 9 sparsediscrim   classification discrim_quad       
10 klaR            classification discrim_regularized
# … with 22 more rows
# ℹ Use `print(n = ...)` to see more rows
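
As a rough sketch of how a spec-table function could sit on top of these base tables, the filtering step might look like the following. The helper `filter_base_tbl()` and the three-row toy table are assumptions for illustration only; the real functions work on the full tibbles shown above.

```r
# Toy subset of the regression base table (three rows only, for illustration)
base_tbl <- data.frame(
  .parsnip_engine = c("lm", "glm", "glm"),
  .parsnip_mode   = c("regression", "regression", "regression"),
  .parsnip_fns    = c("linear_reg", "linear_reg", "poisson_reg"),
  stringsAsFactors = FALSE
)

# Hypothetical helper: "all" means no filtering on that column
filter_base_tbl <- function(tbl, fns = "all", eng = "all") {
  if (!"all" %in% eng) tbl <- tbl[tbl$.parsnip_engine %in% eng, , drop = FALSE]
  if (!"all" %in% fns) tbl <- tbl[tbl$.parsnip_fns %in% fns, , drop = FALSE]
  tbl
}

glm_only    <- filter_base_tbl(base_tbl, eng = "glm")
linear_only <- filter_base_tbl(base_tbl, fns = "linear_reg")
```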

workflowsets object exports

Create a workflow set object from a model_spec object, specifically from an object with a class of tidyaml_base_tbl.

Release tidyAML 0.0.1

First release:

Prepare for release:

  • git pull
  • devtools::build_readme()
  • urlchecker::url_check()
  • devtools::check(remote = TRUE, manual = TRUE)
  • devtools::check_win_devel()
  • rhub::check_for_cran()
  • git push

Submit to CRAN:

  • usethis::use_version('patch')
  • devtools::submit_cran()
  • Approve email

Wait for CRAN...

  • Accepted 🎉
  • git push
  • usethis::use_github_release()
  • usethis::use_dev_version()
  • git push

Make `fast_regression()` function

Function:

fast_regression <- function(.data, .rec_obj, .parsnip_fns = "all",
                            .parsnip_eng = "all", .split_type = "initial_split",
                            .split_args = NULL){
  
  # Tidy Eval ----
  call <- list(.parsnip_fns) %>%
    purrr::flatten_chr()
  engine <- list(.parsnip_eng) %>%
    purrr::flatten_chr()
  
  rec_obj <- .rec_obj
  split_type <- .split_type
  split_args <- .split_args
  
  # Checks ----
  
  # Get data splits
  df <- dplyr::as_tibble(.data)
  splits_obj <- create_splits(
    .data = df, 
    .split_type = split_type,
    .split_args = split_args
  )
  
  # Generate Model Spec Tbl
  mod_spec_tbl <- fast_regression_parsnip_spec_tbl(
    .parsnip_fns = call,
    .parsnip_eng = engine
  )
  
  mod_rec_tbl <- mod_spec_tbl %>%
    dplyr::mutate(.model_recipe = list(rec_obj))
  
  mod_tbl <- mod_rec_tbl %>%
    dplyr::mutate(
      .wflw = list(
        workflows::workflow() %>%
          workflows::add_recipe(.model_recipe[[1]]) %>%
          workflows::add_model(.model_spec[[1]])
      )
    ) %>%
    dplyr::mutate(
      .fitted_wflw = list(
        parsnip::fit(.wflw[[1]], data = rsample::training(splits_obj$splits))
      )
    ) %>%
    dplyr::mutate(
      .pred_wflw = list(
        predict(.fitted_wflw[[1]], new_data = rsample::testing(splits_obj$splits))
      )
    )
  
  # Return ----
  class(mod_tbl) <- c("fst_reg_tbl", class(mod_tbl))
  attr(mod_tbl, ".parsnip_engines") <- .parsnip_eng
  attr(mod_tbl, ".parsnip_functions") <- .parsnip_fns
  attr(mod_tbl, ".split_type") <- .split_type
  attr(mod_tbl, ".split_args") <- .split_args
  return(mod_tbl)
}

Example:

> rec_obj <- recipes::recipe(mpg ~ ., data = mtcars)
> frt_tbl <- fast_regression(mtcars, rec_obj, .parsnip_eng = c("lm","glm"))
> frt_tbl
# A tibble: 3 × 8
  .parsnip_engine .parsn…¹ .pars…² .model_…³ .model…⁴ .wflw      .fitted_…⁵ .pred_…⁶
  <chr>           <chr>    <chr>   <list>    <list>   <list>     <list>     <list>  
1 lm              regress… linear… <spec[+]> <recipe> <workflow> <workflow> <tibble>
2 glm             regress… linear… <spec[+]> <recipe> <workflow> <workflow> <tibble>
3 glm             regress… poisso… <spec[+]> <recipe> <workflow> <workflow> <tibble>
# … with abbreviated variable names ¹​.parsnip_mode, ²​.parsnip_fns, ³​.model_spec,
#   ⁴​.model_recipe, ⁵​.fitted_wflw, ⁶​.pred_wflw
> class(frt_tbl)
[1] "fst_reg_tbl"      "fst_reg_spec_tbl" "tbl_df"           "tbl"             
[5] "data.frame"      
> attributes(frt_tbl)
$names
[1] ".parsnip_engine" ".parsnip_mode"   ".parsnip_fns"    ".model_spec"    
[5] ".model_recipe"   ".wflw"           ".fitted_wflw"    ".pred_wflw"     

$row.names
[1] 1 2 3

$class
[1] "fst_reg_tbl"      "fst_reg_spec_tbl" "tbl_df"           "tbl"             
[5] "data.frame"      

$.parsnip_engines
[1] "lm"  "glm"

$.parsnip_functions
[1] "all"

$.split_type
[1] "initial_split"

> 
> frt_tbl$.fitted_wflw[[1]] %>%
+   broom::glance()
# A tibble: 1 × 12
  r.squared adj.r.s…¹ sigma stati…² p.value    df logLik   AIC   BIC devia…³ df.re…⁴
      <dbl>     <dbl> <dbl>   <dbl>   <dbl> <dbl>  <dbl> <dbl> <dbl>   <dbl>   <int>
1     0.880     0.788  2.58    9.55 1.76e-4    10  -49.5  123.  137.    86.8      13
# … with 1 more variable: nobs <int>, and abbreviated variable names
#   ¹​adj.r.squared, ²​statistic, ³​deviance, ⁴​df.residual
# ℹ Use `colnames()` to see all variable names
> 
> frt_tbl$.fitted_wflw[[1]] %>%
+   broom::tidy()
# A tibble: 11 × 5
   term        estimate std.error statistic p.value
   <chr>          <dbl>     <dbl>     <dbl>   <dbl>
 1 (Intercept)  14.3      20.6       0.692    0.501
 2 cyl          -0.507     1.14     -0.443    0.665
 3 disp          0.0119    0.0252    0.473    0.644
 4 hp           -0.0279    0.0249   -1.12     0.282
 5 drat          1.72      1.89      0.912    0.379
 6 wt           -2.99      2.34     -1.28     0.223
 7 qsec          0.670     0.786     0.852    0.409
 8 vs           -0.391     2.43     -0.161    0.874
 9 am            2.27      2.22      1.02     0.326
10 gear         -0.0871    1.59     -0.0549   0.957
11 carb          0.257     0.957     0.269    0.792

tidyAML in the README Example: only 24 predictions output - mtcars has 32 rows ?

Hi Steven,

tidyAML
sounds like a real time-saving idea.

Q:
in the GIT README file example,
you mention:
"Let’s look at the fitted workflow predictions...".

The example shows 24 mpg predictions
in the output for each model in:

  • frt_tbl$pred_wflw

But doesn't mtcars have 32 rows (cars)?

Hope you can guide me here, Steven.
Really want to understand
how to use tidyAML ,
in the simplest possible terms :-).

SFd99
San Francisco
Linux, Rstudio
latest tidyAML PKG f/CRAN.
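
A likely explanation, assuming the default `initial_split` with a training proportion of 3/4 and predictions made on the training split (as `internal_make_wflw_predictions()` does with `rsample::training()`), is that only the training rows are predicted. The arithmetic can be sketched as:

```r
n_rows  <- nrow(mtcars)          # mtcars has 32 rows
prop    <- 3 / 4                 # assumed default training proportion
n_train <- floor(n_rows * prop)  # rows in the training split
n_test  <- n_rows - n_train      # rows held out for testing
```

Under these assumptions `n_train` is 24 and `n_test` is 8, which matches the 24 predictions per model in the README output.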

Make function `fast_classification()`

Function:

#' Generate Model Specification calls to `parsnip`
#'
#' @family Model_Generator
#'
#' @author Steven P. Sanderson II, MPH
#'
#' @details With this function you can generate a tibble output of any classification
#' model specification and its fitted `workflow` object. Per the recipes documentation,
#' specifically for `step_string2factor()`, you are encouraged to convert your predictor
#' into a factor before you create your recipe.
#'
#' @description Creates a list/tibble of parsnip model specifications.
#'
#' @param .data The data being passed to the function for the classification problem
#' @param .rec_obj The recipe object being passed.
#' @param .parsnip_fns The default is 'all' which will create all possible
#' classification model specifications supported.
#' @param .parsnip_eng The default is 'all' which will use all possible
#' classification engines supported.
#' @param .split_type The default is 'initial_split', you can pass any type of
#' split supported by `rsample`
#' @param .split_args The default is NULL, when NULL then the default parameters
#' of the split type will be executed for the rsample split type.
#'
#' @examples
#' library(recipes, quietly = TRUE)
#' library(dplyr, quietly = TRUE)
#' 
#' df <- mtcars %>% mutate(cyl = as.factor(cyl))
#' rec_obj <- recipe(cyl ~ ., data = df)
#'   
#' fct_tbl <- fast_classification(
#'   .data = df, 
#'   .rec_obj = rec_obj, 
#'   .parsnip_eng = c("glm","LiblineaR"))
#'   
#' glimpse(fct_tbl)
#'
#' @return
#' A list or a tibble.
#'
#' @name fast_classification
NULL

#' @export
#' @rdname fast_classification
#'

fast_classification <- function(.data, .rec_obj, .parsnip_fns = "all",
                            .parsnip_eng = "all", .split_type = "initial_split",
                            .split_args = NULL){
  
  # Tidy Eval ----
  call <- list(.parsnip_fns) %>%
    purrr::flatten_chr()
  engine <- list(.parsnip_eng) %>%
    purrr::flatten_chr()
  
  rec_obj <- .rec_obj
  split_type <- .split_type
  split_args <- .split_args
  
  # Checks ----
  
  # Get data splits
  df <- dplyr::as_tibble(.data)
  splits_obj <- create_splits(
    .data = df,
    .split_type = split_type,
    .split_args = split_args
  )
  
  # Generate Model Spec Tbl
  mod_spec_tbl <- fast_classification_parsnip_spec_tbl(
    .parsnip_fns = call,
    .parsnip_eng = engine
  )
  
  # Generate Workflow object
  mod_tbl <- mod_spec_tbl %>%
    dplyr::mutate(
      wflw = internal_make_wflw(mod_spec_tbl, .rec_obj = rec_obj)
    )
  
  mod_fitted_tbl <- mod_tbl %>%
    dplyr::mutate(
      fitted_wflw = internal_make_fitted_wflw(mod_tbl, splits_obj)
    )
  
  mod_pred_tbl <- mod_fitted_tbl %>%
    dplyr::mutate(
      pred_wflw = internal_make_wflw_predictions(mod_fitted_tbl, splits_obj)
    )
  
  
  # Return ----
  # Set the class and attributes on the object that is actually returned
  class(mod_pred_tbl) <- c("fst_reg_tbl", class(mod_pred_tbl))
  attr(mod_pred_tbl, ".parsnip_engines") <- .parsnip_eng
  attr(mod_pred_tbl, ".parsnip_functions") <- .parsnip_fns
  attr(mod_pred_tbl, ".split_type") <- .split_type
  attr(mod_pred_tbl, ".split_args") <- .split_args
  
  return(mod_pred_tbl)
}

Example:

> library(recipes, quietly = TRUE)
> library(dplyr, quietly = TRUE)
> 
> df <- mtcars %>% mutate(cyl = as.factor(cyl))
> rec_obj <- recipe(cyl ~ ., data = df)
>   
> fct_tbl <- fast_classification(
+   .data = df, 
+   .rec_obj = rec_obj, 
+   .parsnip_eng = c("glm","LiblineaR"))
Warning message:
Problem while computing `fitted_wflw = internal_make_fitted_wflw(mod_tbl, splits_obj)`.
glm.fit: fitted probabilities numerically 0 or 1 occurred 
>   
> glimpse(fct_tbl)
Rows: 3
Columns: 8
$ .model_id       <int> 1, 2, 3
$ .parsnip_engine <chr> "glm", "LiblineaR", "LiblineaR"
$ .parsnip_mode   <chr> "classification", "classification", "classification"
$ .parsnip_fns    <chr> "logistic_reg", "logistic_reg", "svm_linear"
$ model_spec      <list> [~NULL, ~NULL, NULL, classification, TRUE, NULL, glm, TRUE], [~N…
$ wflw            <list> [mpg, disp, hp, drat, wt, qsec, vs, am, gear, carb, cyl, double,…
$ fitted_wflw     <list> [mpg, disp, hp, drat, wt, qsec, vs, am, gear, carb, cyl, double,…
$ pred_wflw       <list> [<tbl_df[24 x 1]>], [<tbl_df[24 x 1]>], [<tbl_df[24 x 1]>]

Create function `internal_make_tuned_wflw()`

Function:

#' Internals Make a Tunable Model Specification
#'
#' @family Internals
#'
#' @author Steven P. Sanderson II, MPH
#'
#' @description Make a tuned model specification object.
#'
#' @details This will take a model specification that is created from a function
#' like [tidyAML::fast_regression_parsnip_spec_tbl()] and update the __model_spec__
#' `args` to `tune::tune()`. This is done dynamically, meaning you do not need
#' to know the names of the parameters inside of the model specification.
#'
#' @param .model_tbl The model table that is generated from a function like
#' `fast_regression_parsnip_spec_tbl()`, must have a class of "tidyaml_mod_spec_tbl".
#'
#' @examples
#' library(dplyr)
#' 
#' mod_tbl <- fast_regression_parsnip_spec_tbl()
#' mod_tbl$model_spec[[1]]
#' 
#' updated_tbl <- mod_tbl %>%
#'   mutate(model_spec = internal_set_args_to_tune(mod_tbl))
#' updated_tbl$model_spec[[1]]
#'
#' @return
#' A list of model specification objects with their arguments set to tune.
#'
#' @name internal_set_args_to_tune
NULL

#' @export
#' @rdname internal_set_args_to_tune

internal_set_args_to_tune <- function(.model_tbl){
  
  # Tidyeval
  model_tbl <- .model_tbl
  
  # Checks ----
  if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")){
    rlang::abort(
      message = "'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl'",
      use_cli_format = TRUE
    )
  }
  
  # Use the validated function argument, not a `mod_tbl` from the calling environment
  model_tbl_with_params <- model_tbl %>% 
    dplyr::mutate(
      model_params = purrr::pmap(
        dplyr::cur_data(), 
        ~ list(methods::formalArgs(..4))
      ) 
    )
  
  models_list_new <- model_tbl_with_params %>%
    dplyr::group_split(.model_id)
  
  tuned_params_list <- models_list_new %>%
    purrr::imap(
      .f = function(obj, id){
        
        # Pull the model params
        mod_params <- obj %>% dplyr::pull(6) %>% purrr::pluck(1)
        mod_params_list <- unlist(mod_params) %>% as.list()
        #param_names <- unlist(mod_params)
        names(mod_params_list) <- unlist(mod_params)
        
        # Set mode and engine
        p_mode <- obj %>% dplyr::pull(3) %>% purrr::pluck(1)
        p_engine <- obj %>% dplyr::pull(2) %>% purrr::pluck(1)
        me_list <- list(
          mode = paste0("mode = ", p_mode),
          engine = paste0("engine = ", p_engine)
        )
        
        # Get all other params
        me_vec <- c("mode","engine")
        pv <- unlist(mod_params)
        params_to_modify <- pv[!pv %in% me_vec] %>% as.list()
        names(params_to_modify) <- unlist(params_to_modify)
        
        # Set each item equal to .x = tune::tune()
        tuned_params_list <- purrr::map(
          params_to_modify,
          ~ paste0("tune::tune()")
        )
        
        # use modifyList()
        res <- utils::modifyList(mod_params_list, tuned_params_list)
        res <- utils::modifyList(res, me_list)
        
        # Return      
        return(res)
        
      }
    )
  
  models_with_params_list <- purrr::map2(
    .x = tuned_params_list, 
    .y = models_list_new, 
    ~ {.y$model_params <- list(.x[.y$model_params[[1]][[1]]]);.y}
  )
  
  new_mod_obj <- models_with_params_list %>%
    purrr::imap(
      .f = function(obj, id){
        
        # Get Model Specification
        mod_spec <- obj %>%
          dplyr::pull(5) %>%
          purrr::pluck(1)
        
        # Get the tuned params
        new_mod_args <- obj %>%
          dplyr::pull(6) %>%
          purrr::pluck(1)
        
        # Drop the ones we don't need to set
        new_mod_args <- new_mod_args %>%
          unlist() %>%
          subset(!names(.) %in% c('mode','engine')) %>%
          as.list()
        
        # Set the new model arguments
        mod_spec$args <- new_mod_args
        
        # Return the newly modified model specification
        return(mod_spec)
      }
    )
  
  return(new_mod_obj)
  
}

Example:

library(dplyr)

> mod_tbl <- fast_regression_parsnip_spec_tbl()
> mod_tbl$model_spec[[1]]
Linear Regression Model Specification (regression)

Computational engine: lm 

> 
> updated_tbl <- mod_tbl %>%
+   mutate(model_spec = internal_set_args_to_tune(mod_tbl))
> updated_tbl$model_spec[[1]]
Linear Regression Model Specification (regression)

Main Arguments:
  penalty = tune::tune()
  mixture = tune::tune()

Computational engine: lm 
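
The dynamic part of this, discovering the argument names and marking everything except `mode` and `engine` for tuning, can be sketched in base R. `toy_linear_reg()` below is a hypothetical stand-in with the same argument names as the output above; the placeholders are stored as the string `"tune::tune()"`, mirroring what the function does with `paste0()`.

```r
# Hypothetical stand-in with the same argument names as a linear_reg()-style spec
toy_linear_reg <- function(mode = "regression", engine = "lm",
                           penalty = NULL, mixture = NULL) NULL

# Discover the argument names dynamically
arg_names <- names(formals(toy_linear_reg))

# Everything except mode/engine is marked for tuning
to_tune   <- setdiff(arg_names, c("mode", "engine"))
tune_args <- setNames(as.list(rep("tune::tune()", length(to_tune))), to_tune)
```

Here `tune_args` is a named list with one entry each for `penalty` and `mixture`.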

wflw

Function:

# Safely make workflow
internal_make_wflw <- function(.model_tbl, .rec_obj){
  
  # Tidyeval ----
  model_tbl <- .model_tbl
  rec_obj <- .rec_obj
  
  # Checks ----
  if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")){
    rlang::abort(
      message = "'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl'",
      use_cli_format = TRUE
    )
  }
  
  # Manipulation
  model_factor_tbl <- model_tbl %>%
    dplyr::mutate(.model_id = forcats::as_factor(.model_id)) %>%
    dplyr::mutate(rec_obj = list(rec_obj))
  
  # Make a group split object list
  models_list <- model_factor_tbl %>%
    dplyr::group_split(.model_id)
  
  # Make the Workflow Object using purrr imap
  wflw_list <- models_list %>%
    purrr::imap(
      .f = function(obj, id){
        
        # Pull the model column and then pluck the model
        mod <- obj %>% dplyr::pull(5) %>% purrr::pluck(1)
        
        # Pull the recipe column and then pluck the recipe
        rec_obj <- obj %>% dplyr::pull(6) %>% purrr::pluck(1)
        
        # Create a safe add_model function
        safe_add_model <- purrr::safely(
          workflows::add_model,
          otherwise = "Error - Could not make workflow object.",
          quiet = FALSE
        )
        
        # Return the workflow object with recipe and model
        ret <- workflows::workflow() %>%
          workflows::add_recipe(rec_obj) %>%
          safe_add_model(mod)
        
        # Pluck the result
        res <- ret %>% purrr::pluck("result")
        
        # Return the result
        return(res)
      }
    )
  
  
  # Return
  return(wflw_list)
}

Example:

library(tidyverse)
library(tidymodels)
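tidymodels_prefer()

# Recipe assumed here, as in the fast_regression() example
rec_obj <- recipe(mpg ~ ., data = mtcars)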

mod_spec_tbl <- fast_regression_parsnip_spec_tbl(
  .parsnip_fns = "linear_reg", 
  .parsnip_eng = c("lm","glm")
)

# A tibble: 2 × 5
  .model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec
      <int> <chr>           <chr>         <chr>        <list>    
1         1 lm              regression    linear_reg   <spec[+]> 
2         2 glm             regression    linear_reg   <spec[+]> 

# Generate Workflow object
mod_tbl <- mod_spec_tbl %>%
  dplyr::mutate(
    wflw = internal_make_wflw(mod_spec_tbl, .rec_obj = rec_obj)
  )

> mod_tbl
# A tibble: 2 × 6
  .model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec wflw      
      <int> <chr>           <chr>         <chr>        <list>     <list>    
1         1 lm              regression    linear_reg   <spec[+]>  <workflow>
2         2 glm             regression    linear_reg   <spec[+]>  <workflow>

Function install_all_extensions()

Function:

core_packages <- function(){
  c(
    "multilevelmod","rules","poissonreg","censored","baguette","bonsai",
    "plsmod","brulee","rstanarm","dbarts","kknn","ranger","randomForest",
    "LiblineaR","rule_fit","mixOmics"
  )
}

install_deps <- function(){
  
  ans <- utils::menu(c("Yes","No"), title = "Do you want to install all of the dependencies?")
  
  pkgs <- core_packages()
  
  if (ans == 1){
    
    # Loop through each name 
    for (lib in pkgs){
      
      # check if already installed (without attaching the package)
      if (!requireNamespace(lib, quietly = TRUE)){
        
        # If the library is not installed then install it
        utils::install.packages(lib, dependencies = TRUE)
      }
    }
  }
  
}

load_deps <- function(){
  
  ans <- utils::menu(c("Yes","No"), title = "Do you want to load core dependencies?")
  
  pkgs <- core_packages()
  
  pkgs_unloaded <- function(){
    search <- paste0("package:", pkgs)
    pkgs[!search %in% search()]
  }
  
  same_lib <- function(pkg){
    loc <- if (pkg %in% loadedNamespaces()) dirname(getNamespaceInfo(pkg, "path"))
    library(pkg, lib.loc = loc, character.only = TRUE, warn.conflicts = FALSE)
  }
  
  tidyaml_pkg_attach <- function(){
    to_load <- pkgs_unloaded()
    
    suppressPackageStartupMessages(lapply(to_load, same_lib))
    
    invisible(to_load)
  }
  
  if (ans == 1){
    
    pkgs_unloaded()
    tidyaml_pkg_attach()
    
  }
}

Example:

> core_packages()
 [1] "multilevelmod" "aqua"          "rules"         "poissonreg"    "censored"     
 [6] "baguette"      "bonsai"        "plsmod"        "brulee"        "rstanarm"     
[11] "dbarts"        "kknn"          "ranger"        "randomForest"  "LiblineaR"    
[16] "rule_fit"     

> install_deps()
Do you want to install all of the dependencies? 

1: Yes
2: No

Selection: 2

> load_deps()
Do you want to load core dependencies? 

1: Yes
2: No

Selection: 2
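
Before prompting, it can be useful to see which packages are actually missing. The sketch below uses `requireNamespace()`, which checks for a package without attaching it; `missing_packages()` is a hypothetical helper and the second package name is made up for illustration.

```r
# Hypothetical helper: report which packages are not installed,
# without attaching any of them to the search path
missing_packages <- function(pkgs) {
  installed <- vapply(pkgs, requireNamespace, logical(1), quietly = TRUE)
  pkgs[!installed]
}

# "stats" ships with R; the second name is a made-up package
to_install <- missing_packages(c("stats", "notARealPackage12345"))
```

In practice the helper would be called on `core_packages()`.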

Drop support for aqua, spark and keras

These require their own host of installations and special treatments. Not necessary for an initial release.

Here is the updated version of fast_regression_parsnip_spec_tbl(). It uses tibble::tribble() so that models can be added or removed more easily in the future.

Function:

#' Utility Regression call to `parsnip`
#'
#' @family Utility
#'
#' @author Steven P. Sanderson II, MPH
#'
#' @details Creates a tibble of parsnip regression model specifications. This will
#' create a tibble of 46 different regression model specifications which can be
#' filtered. The table is filtered first and then the model specs are created. This will
#' only create models for __regression__ problems. To find all of the supported
#' models in this package you can visit \url{https://www.tidymodels.org/find/parsnip/}
#'
#' @seealso \url{https://parsnip.tidymodels.org/reference/linear_reg.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/cubist_rules.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/survival_reg.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/poisson_reg.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/bag_mars.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/bag_tree.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/bart.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/boost_tree.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/decision_tree.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/gen_additive_mod.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/mars.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/mlp.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/nearest_neighbor.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/pls.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/rand_forest.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/rule_fit.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/svm_linear.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/svm_poly.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/svm_rbf.html}
#'
#' @description Creates a tibble of parsnip regression model specifications.
#'
#' @param .parsnip_fns The default for this is set to `all`. This means that all
#' of the parsnip __linear regression__ functions will be used, for example `linear_reg()`,
#' or `cubist_rules`. You can also choose to pass a c() vector like `c("linear_reg","cubist_rules")`
#' @param .parsnip_eng The default for this is set to `all`. This means that all
#' of the parsnip __linear regression engines__ will be used, for example `lm`, or
#' `glm`. You can also choose to pass a c() vector like `c('lm', 'glm')`
#'
#' @examples
#' fast_regression_parsnip_spec_tbl(.parsnip_fns = "linear_reg")
#' fast_regression_parsnip_spec_tbl(.parsnip_eng = c("lm","glm"))
#'
#' @return
#' A tibble with an added class of 'fst_reg_spec_tbl'
#'
#' @importFrom parsnip linear_reg cubist_rules poisson_reg survival_reg
#'
#' @name fast_regression_parsnip_spec_tbl
NULL

#' @export
#' @rdname fast_regression_parsnip_spec_tbl

fast_regression_parsnip_spec_tbl <- function(.parsnip_fns = "all",
                                             .parsnip_eng = "all") {
  
  # Thank you https://stackoverflow.com/questions/74691333/build-a-tibble-of-parsnip-model-calls-with-match-fun/74691529#74691529
  # Tidyeval ----
  call <- list(.parsnip_fns) %>%
    purrr::flatten_chr()
  engine <- list(.parsnip_eng) %>%
    purrr::flatten_chr()
  
  # Make tibble
  mod_tbl <- tibble::tribble(
    ~.parsnip_engine, ~.parsnip_mode, ~.parsnip_fns,
    "lm", "regression", "linear_reg",
    "brulee", "regression", "linear_reg",
    "gee", "regression", "linear_reg",
    "glm","regression","linear_reg",
    "glmer","regression","linear_reg",
    "glmnet","regression","linear_reg",
    "gls","regression","linear_reg",
    "lme","regression","linear_reg",
    "lmer","regression","linear_reg",
    "stan","regression","linear_reg",
    "stan_glmer","regression","linear_reg",
    "Cubist","regression","cubist_rules",
    "glm","regression","poisson_reg",
    "gee","regression","poisson_reg",
    "glmer","regression","poisson_reg",
    "glmnet","regression","poisson_reg",
    "hurdle","regression","poisson_reg",
    "stan","regression","poisson_reg",
    "stan_glmer","regression","poisson_reg",
    "zeroinfl","regression","poisson_reg",
    "survival","censored regression","survival_reg",
    "flexsurv","censored regression","survival_reg",
    "flexsurvspline","censored regression","survival_reg",
    "earth","regression","bag_mars",
    "rpart","regression","bag_tree",
    "dbarts","regression","bart",
    "xgboost","regression","boost_tree",
    "lightgbm","regression","boost_tree",
    "mboost","censored regression","boost_tree",
    "rpart","regression","decision_tree",
    "partykit","regression","decision_tree",
    "mgcv","regression","gen_additive_mod",
    "earth","regression","mars",
    "nnet","regression","mlp",
    "brulee","regression","mlp",
    "kknn","regression","nearest_neighbor",
    "mixOmics","regression","pls",
    "ranger","regression","rand_forest",
    "randomForest","regression","rand_forest",
    "partykit","censored regression","rand_forest",
    "aorsf","censored regression","rand_forest",
    "xrf","regression","rule_fit",
    "LiblineaR","regression","svm_linear",
    "kernlab","regression","svm_linear",
    "kernlab","regression","svm_poly",
    "kernlab","regression","svm_rbf"
  )
  
  # Filter ----
  if (!"all" %in% engine){
    mod_tbl <- mod_tbl %>%
      dplyr::filter(.parsnip_engine %in% engine)
  }
  
  if (!"all" %in% call){
    mod_tbl <- mod_tbl %>%
      dplyr::filter(.parsnip_fns %in% call)
  }
  
  mod_filtered_tbl <- mod_tbl
  
  mod_spec_tbl <- mod_filtered_tbl %>%
    dplyr::mutate(
      model_spec = purrr::pmap(
        dplyr::cur_data(),
        ~ match.fun(..3)(mode = ..2, engine = ..1)
        #~ get(..3)(mode = ..2, engine = ..1)
      )
    ) %>%
    # add .model_id column
    dplyr::mutate(.model_id = dplyr::row_number()) %>%
    dplyr::select(.model_id, dplyr::everything())
  
  # Return ----
  class(mod_spec_tbl) <- c("fst_reg_spec_tbl", class(mod_spec_tbl))
  class(mod_spec_tbl) <- c("tidyaml_mod_spec_tbl", class(mod_spec_tbl))
  attr(mod_spec_tbl, ".parsnip_engines") <- .parsnip_eng
  attr(mod_spec_tbl, ".parsnip_functions") <- .parsnip_fns
  
  return(mod_spec_tbl)
  
}
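
The core trick in the `pmap()` call is looking a constructor up by its name and calling it with that row's mode and engine. A base R sketch of the same idea follows; the two `toy_*` constructors and the two-row table are hypothetical stand-ins, not parsnip functions.

```r
# Hypothetical stand-ins for parsnip constructors such as linear_reg()
toy_linear_reg  <- function(mode, engine) list(type = "linear_reg",  mode = mode, engine = engine)
toy_poisson_reg <- function(mode, engine) list(type = "poisson_reg", mode = mode, engine = engine)

spec_tbl <- data.frame(
  engine = c("lm", "glm"),
  mode   = c("regression", "regression"),
  fns    = c("toy_linear_reg", "toy_poisson_reg"),
  stringsAsFactors = FALSE
)

# Resolve each function name to the function itself
constructors <- lapply(spec_tbl$fns, match.fun)

# Call each constructor with the mode and engine from the same row
specs <- Map(
  function(f, mode, engine) f(mode = mode, engine = engine),
  constructors, spec_tbl$mode, spec_tbl$engine
)
```

Each element of `specs` is built from its own row, which is what makes the tribble easy to extend: adding a model is one more row.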

Make `create_splits()` function

This function will create a splits object for any modeling problem. You provide a character vector naming the type of split you want, along with any necessary arguments. These are passed to the appropriate rsample function, which returns the matching splits object.
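
One way such a dispatcher could be structured is sketched below in base R. Everything here is an assumption for illustration: `toy_initial_split()` is a stand-in for an rsample splitter, and `toy_create_splits()` is not the package's `create_splits()`; it only shows selecting a splitter by name and passing extra arguments through with `do.call()`.

```r
# Hypothetical stand-in for an rsample splitter
toy_initial_split <- function(data, prop = 3 / 4) {
  n_train <- floor(nrow(data) * prop)
  idx <- sample.int(nrow(data), n_train)
  list(training = data[idx, , drop = FALSE], testing = data[-idx, , drop = FALSE])
}

# Hypothetical dispatcher: pick the splitter by name, pass extra arguments through
toy_create_splits <- function(.data, .split_type = "initial_split", .split_args = NULL) {
  splitter <- switch(
    .split_type,
    initial_split = toy_initial_split,
    stop("Unsupported split type: ", .split_type)
  )
  do.call(splitter, c(list(data = .data), .split_args))
}

splits <- toy_create_splits(mtcars, .split_args = list(prop = 0.5))
```

When `.split_args` is `NULL`, the splitter's own defaults apply, which matches the behaviour described for `.split_args` in the function documentation.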

tidy, glance, augment `brulee`

coef(brulee_fitted_wflw$fit$fit$fit) gets the coefficients

> coef(brulee_fitted_wflw$fit$fit$fit)
(Intercept)         cyl        disp          hp        drat          wt        qsec 
 0.53397143  0.46540201 -0.01057179  0.01011792  0.91103208 -1.81364763  1.05920923 
         vs          am        gear        carb 
 0.52524698  5.03626394  0.63133454 -1.23109210 
> coef(brulee_fitted_wflw$fit$fit$fit) %>% enframe()
# A tibble: 11 × 2
   name          value
   <chr>         <dbl>
 1 (Intercept)  0.534 
 2 cyl          0.465 
 3 disp        -0.0106
 4 hp           0.0101
 5 drat         0.911 
 6 wt          -1.81  
 7 qsec         1.06  
 8 vs           0.525 
 9 am           5.04  
10 gear         0.631 
11 carb        -1.23  

`fast_regression()` is producing incorrect workflow objects

The following objects are being produced incorrectly

This stems from using wflw[[1]] in the code and not mapping a function to the list column object.
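The failure mode can be illustrated in base R with a toy table. This is only a sketch: the strings below are hypothetical stand-ins for real model specs and workflows, and the column names `wflw_wrong` and `wflw_right` are made up for the illustration.

```r
# Toy stand-in for the model table: a list column of "specs" (plain strings).
tbl <- data.frame(.model_id = 1:3)
tbl$model_spec <- list("lm", "glm", "gee")

# Buggy pattern: take the first element once and reuse it for every row.
tbl$wflw_wrong <- rep(
  list(paste("workflow for", tbl$model_spec[[1]])),
  nrow(tbl)
)

# Fixed pattern: map a function over the list column, one result per row.
tbl$wflw_right <- lapply(
  tbl$model_spec,
  function(spec) paste("workflow for", spec)
)

wrong <- unlist(tbl$wflw_wrong)
right <- unlist(tbl$wflw_right)
```

With the buggy pattern every row ends up with the object built from the first spec, whereas mapping over the list column gives each row its own object.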

To fix the wflw column, make a helper function like so:

Function:

# Safely make workflow
internal_make_wflw <- function(.model_tbl, .rec_obj){
  
  # Tidyeval ----
  model_tbl <- .model_tbl
  rec_obj <- .rec_obj
  
  # Checks ----
  if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")){
    rlang::abort(
      message = "'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl'",
      use_cli_format = TRUE
    )
  }
  
  # Manipulation
  model_factor_tbl <- model_tbl %>%
    dplyr::mutate(.model_id = forcats::as_factor(.model_id)) %>%
    dplyr::mutate(rec_obj = list(rec_obj))
  
  # Make a group split object list
  models_list <- model_factor_tbl %>%
    dplyr::group_split(.model_id)
  
  # Make the Workflow Object using purrr imap
  wflw_list <- models_list %>%
    purrr::imap(
      .f = function(obj, id){
        
        # Pull the model column and then pluck the model
        mod <- obj %>% dplyr::pull(5) %>% purrr::pluck(1)
        
        # Pull the recipe column and then pluck the recipe
        rec_obj <- obj %>% dplyr::pull(6) %>% purrr::pluck(1)
        
        # Create a safe add_model function
        safe_add_model <- purrr::safely(
          workflows::add_model,
          otherwise = "Error - Could not make workflow object.",
          quiet = FALSE
        )
        
        # Return the workflow object with recipe and model
        ret <- workflows::workflow() %>%
          workflows::add_recipe(rec_obj) %>%
          safe_add_model(mod)
        
        # Pluck the result
        res <- ret %>% purrr::pluck("result")
        
        # Return the result
        return(res)
      }
    )
  
  
  # Return
  return(wflw_list)
}

Example:

library(tidyverse)
library(tidymodels)
tidymodels_prefer()

mod_spec_tbl <- fast_regression_parsnip_spec_tbl(
  .parsnip_fns = "linear_reg", 
  .parsnip_eng = c("lm","glm")
)

# A tibble: 2 × 5
  .model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec
      <int> <chr>           <chr>         <chr>        <list>    
1         1 lm              regression    linear_reg   <spec[+]> 
2         2 glm             regression    linear_reg   <spec[+]> 

# Recipe object used to build the workflows
rec_obj <- recipe(mpg ~ ., data = mtcars)

# Generate Workflow object
mod_tbl <- mod_spec_tbl %>%
  dplyr::mutate(
    wflw = internal_make_wflw(mod_spec_tbl, .rec_obj = rec_obj)
  )

> mod_tbl
# A tibble: 2 × 6
  .model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec wflw      
      <int> <chr>           <chr>         <chr>        <list>     <list>    
1         1 lm              regression    linear_reg   <spec[+]>  <workflow>
2         2 glm             regression    linear_reg   <spec[+]>  <workflow>

Make `create_model_spec()` function

The purpose

The purpose of this function is to allow a user to create a model specification of any parsnip type, or several specifications of different types. It can return either a list of the model specifications or a tibble, both of which hold the same information.

Function:

library(purrr)
library(parsnip)

create_model_spec <- function(.parsnip_engine = list("lm"), 
                              .mode = list("regression"),
                              .parsnip_fns = list("linear_reg"),
                              .return_tibble = TRUE) {
  
  # Tidyeval ----
  engine <- .parsnip_engine %>%
    purrr::flatten_chr() %>%
    as.list()
  mode <- .mode %>%
    purrr::flatten_chr() %>%
    as.list()
  call <- .parsnip_fns %>%
    purrr::flatten_chr() %>%
    as.list()
  ret_tibble <- as.logical(.return_tibble)

  # Make Model list for purrr call
  model_spec_list <- list(
    call,
    engine,
    mode
  )

  # Use purrr pmap to make model specs
  models <- purrr::pmap(
    .l = model_spec_list,
    .f = function(call, engine, mode) {
      match.fun(call)(engine = engine, mode = mode)
    }
  )

  # Return ----
  models_list <- list(
    .parsnip_engine = engine,
    .parsnip_mode = mode,
    .parsnip_fns = call,
    .model_spec = models
  )

  ret_tbl <- dplyr::tibble(
    .parsnip_engine = unlist(engine),
    .parsnip_mode   = unlist(mode),
    .parsnip_fns    = unlist(call),
    .model_spec     = models
  )

  if (ret_tibble) return(ret_tbl) else return(models_list)
}

Examples:

> create_model_spec(
+   .parsnip_engine = list("lm","glm","glmnet","cubist"), 
+   .parsnip_fns = list(
+     rep("linear_reg", 3),
+     "cubist_rules")
+   )
# A tibble: 4 × 4
  .parsnip_engine .parsnip_mode .parsnip_fns .model_spec
  <chr>           <chr>         <chr>        <list>     
1 lm              regression    linear_reg   <spec[+]>  
2 glm             regression    linear_reg   <spec[+]>  
3 glmnet          regression    linear_reg   <spec[+]>  
4 cubist          regression    cubist_rules <spec[+]>  
> model_spec_list <- create_model_spec(
+   .parsnip_engine = list("lm","glm","glmnet","cubist"), 
+   .parsnip_fns = list(
+     rep("linear_reg", 3),
+     "cubist_rules"),
+   .return_tibble = FALSE
+ )
> model_spec_list$.model_spec
[[1]]
Linear Regression Model Specification (regression)

Computational engine: lm 


[[2]]
Linear Regression Model Specification (regression)

Computational engine: glm 


[[3]]
Linear Regression Model Specification (regression)

Computational engine: glmnet 


[[4]]
! parsnip could not locate an implementation for `cubist_rules` regression model
  specifications using the `cubist` engine.

Cubist Model Specification (regression)

Computational engine: cubist 

Create function `internal_make_spec_tbl()`

Function:

internal_make_spec_tbl <- function(.data){
  
  # Checks ----
  df <- dplyr::as_tibble(.data)
  
  nms <- unique(names(df))
  
  if (!".parsnip_engine" %in% nms | !".parsnip_mode" %in% nms | !".parsnip_fns" %in% nms){
    rlang::abort(
      message = "The model tibble must come from the class/reg to parsnip function.",
      use_cli_format = TRUE
    )
  }
  
  # Make tibble ----
  mod_spec_tbl <- df %>%
    dplyr::mutate(
      model_spec = purrr::pmap(
        dplyr::cur_data(),
        ~ match.fun(..3)(mode = ..2, engine = ..1)
      )
    ) %>%
    # add .model_id column
    dplyr::mutate(.model_id = dplyr::row_number()) %>%
    dplyr::select(.model_id, dplyr::everything())
  
  # Return ----
  return(mod_spec_tbl)
  
}

Example:

> internal_make_spec_tbl(mod_tbl)
# A tibble: 46 × 5
   .model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec
       <int> <chr>           <chr>         <chr>        <list>    
 1         1 lm              regression    linear_reg   <spec[+]> 
 2         2 brulee          regression    linear_reg   <spec[+]> 
 3         3 gee             regression    linear_reg   <spec[+]> 
 4         4 glm             regression    linear_reg   <spec[+]> 
 5         5 glmer           regression    linear_reg   <spec[+]> 
 6         6 glmnet          regression    linear_reg   <spec[+]> 
 7         7 gls             regression    linear_reg   <spec[+]> 
 8         8 lme             regression    linear_reg   <spec[+]> 
 9         9 lmer            regression    linear_reg   <spec[+]> 
10        10 stan            regression    linear_reg   <spec[+]> 
# … with 36 more rows
# ℹ Use `print(n = ...)` to see more rows

update safely message to NULL

Update all safe functions so that `otherwise` is NULL. A failed step then returns NULL, which makes it easier to drop models that fail for the end user.
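A small sketch of why a NULL fallback helps, assuming the purrr package is installed. The `fit_one()` function and the engine names are hypothetical stand-ins for a real fitting step; only `purrr::safely()`, `purrr::map()` and `purrr::compact()` are real functions.

```r
library(purrr)

# Hypothetical fitting step that fails for one input.
fit_one <- function(x) {
  if (x == "bad_engine") stop("engine not installed")
  paste("fitted", x)
}

# Wrap it so that a failure yields result = NULL instead of an error.
safe_fit <- safely(fit_one, otherwise = NULL, quiet = TRUE)

engines <- c("lm", "bad_engine", "glm")
out <- map(engines, safe_fit)

# Failed calls carry result = NULL, so compact() removes them.
results <- map(out, "result")
kept <- compact(results)
```

Had `otherwise` been an error string, the failed entry would survive `compact()` and would need a separate check to filter out.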

Function 1:

#' Internals Safely Make Workflow from Model Spec tibble
#'
#' @family Internals
#'
#' @author Steven P. Sanderson II, MPH
#'
#' @description Safely Make a workflow from a model spec tibble.
#'
#' @details Create a model specification tibble that has a [workflows::workflow()]
#' list column.
#'
#' @param .model_tbl The model table that is generated from a function like
#' `fast_regression_parsnip_spec_tbl()`, must have a class of "tidyaml_mod_spec_tbl".
#' @param .rec_obj The recipe object that is going to be used to make the workflow
#' object.
#'
#' @examples
#' library(recipes, quietly = TRUE)
#' library(dplyr, quietly = TRUE)
#'
#' mod_spec_tbl <- fast_regression_parsnip_spec_tbl(
#'   .parsnip_eng = c("lm","glm","gee"),
#'   .parsnip_fns = "linear_reg"
#' )
#'
#' rec_obj <- recipe(mpg ~ ., data = mtcars)
#'
#' internal_make_wflw(mod_spec_tbl, rec_obj)
#'
#' @return
#' A list object of workflows.
#'
#' @name internal_make_wflw
NULL

#' @export
#' @rdname internal_make_wflw

# Safely make workflow
internal_make_wflw <- function(.model_tbl, .rec_obj){
  
  # Tidyeval ----
  model_tbl <- .model_tbl
  rec_obj <- .rec_obj
  
  # Checks ----
  if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")){
    rlang::abort(
      message = "'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl'",
      use_cli_format = TRUE
    )
  }
  
  # Manipulation
  model_factor_tbl <- model_tbl %>%
    dplyr::mutate(.model_id = forcats::as_factor(.model_id)) %>%
    dplyr::mutate(rec_obj = list(rec_obj))
  
  # Make a group split object list
  models_list <- model_factor_tbl %>%
    dplyr::group_split(.model_id)
  
  # Make the Workflow Object using purrr imap
  wflw_list <- models_list %>%
    purrr::imap(
      .f = function(obj, id){
        
        # Pull the model column and then pluck the model
        mod <- obj %>% dplyr::pull(5) %>% purrr::pluck(1)
        
        # Pull the recipe column and then pluck the recipe
        rec_obj <- obj %>% dplyr::pull(6) %>% purrr::pluck(1)
        
        # Create a safe add_model function
        safe_add_model <- purrr::safely(
          workflows::add_model,
          otherwise = NULL,
          quiet = TRUE
        )
        
        # Return the workflow object with recipe and model
        ret <- workflows::workflow() %>%
          workflows::add_recipe(rec_obj) %>%
          safe_add_model(mod)
        
        # Pluck the result
        res <- ret %>% purrr::pluck("result")
        
        if (!is.null(ret$error)) message(stringr::str_glue("{ret$error}"))
        
        # Return the result
        return(res)
      }
    )
  
  
  # Return
  return(wflw_list)
}

Function 2:

#' Internals Safely Make a Fitted Workflow from Model Spec tibble
#'
#' @family Internals
#'
#' @author Steven P. Sanderson II, MPH
#'
#' @description Safely Make a fitted workflow from a model spec tibble.
#'
#' @details Create a fitted `parsnip` model from a `workflow` object.
#'
#' @param .model_tbl The model table that is generated from a function like
#' `fast_regression_parsnip_spec_tbl()`, must have a class of "tidyaml_mod_spec_tbl".
#' This is meant to be used after the function `internal_make_wflw()` has been
#' run and the tibble has been saved.
#' @param .splits_obj The splits object from the auto_ml function. It is internal
#' to the `auto_ml_` function.
#'
#' @examples
#' library(recipes, quietly = TRUE)
#' library(dplyr, quietly = TRUE)
#'
#' mod_spec_tbl <- fast_regression_parsnip_spec_tbl(
#'   .parsnip_eng = c("lm","glm","gee"),
#'   .parsnip_fns = "linear_reg"
#' )
#'
#' rec_obj <- recipe(mpg ~ ., data = mtcars)
#' splits_obj <- create_splits(mtcars, "initial_split")
#'
#' mod_tbl <- mod_spec_tbl %>%
#'   mutate(wflw = internal_make_wflw(mod_spec_tbl, rec_obj))
#'
#' internal_make_fitted_wflw(mod_tbl, splits_obj)
#'
#' @return
#' A list object of workflows.
#'
#' @name internal_make_fitted_wflw
NULL

#' @export
#' @rdname internal_make_fitted_wflw

# Safely make fitted workflow
internal_make_fitted_wflw <- function(.model_tbl, .splits_obj){
  
  # Tidyeval ----
  model_tbl <- .model_tbl
  splits_obj <- .splits_obj
  col_nms <- colnames(model_tbl)
  
  # Checks ----
  if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")){
    rlang::abort(
      message = "'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl'",
      use_cli_format = TRUE
    )
  }
  
  if (!"wflw" %in% col_nms){
    rlang::abort(
      message = "Missing the column 'wflw'",
      use_cli_format = TRUE
    )
  }
  
  if (!".model_id" %in% col_nms){
    rlang::abort(
      message = "Missing the column '.model_id'",
      use_cli_format = TRUE
    )
  }
  
  # Manipulation
  # Make a group split object list
  models_list <- model_tbl %>%
    dplyr::group_split(.model_id)
  
  # Make the fitted workflow object using purrr imap
  fitted_wflw_list <- models_list %>%
    purrr::imap(
      .f = function(obj, id){
        
        # Pull the workflow column and then pluck it
        wflw <- obj %>% dplyr::pull(6) %>% purrr::pluck(1)
        
        # Create a safe parsnip::fit function
        safe_parsnip_fit <- purrr::safely(
          parsnip::fit,
          otherwise = NULL,
          quiet = TRUE
        )
        
        # Return the fitted workflow
        ret <- safe_parsnip_fit(
          wflw, data = rsample::training(splits_obj$splits)
        )
        
        res <- ret %>% purrr::pluck("result")
        
        if (!is.null(ret$error)) message(stringr::str_glue("{ret$error}"))
        
        return(res)
      }
    )
  
  return(fitted_wflw_list)
  
}

Function 3:

#' Internals Safely Make Predictions on a Fitted Workflow from Model Spec tibble
#'
#' @family Internals
#'
#' @author Steven P. Sanderson II, MPH
#'
#' @description Safely Make predictions on a fitted workflow from a model spec tibble.
#'
#' @details Create predictions on a fitted `parsnip` model from a `workflow` object.
#'
#' @param .model_tbl The model table that is generated from a function like
#' `fast_regression_parsnip_spec_tbl()`, must have a class of "tidyaml_mod_spec_tbl".
#' This is meant to be used after the function `internal_make_fitted_wflw()` has been
#' run and the tibble has been saved.
#' @param .splits_obj The splits object from the auto_ml function. It is internal
#' to the `auto_ml_` function.
#'
#' @examples
#' library(recipes, quietly = TRUE)
#' library(dplyr, quietly = TRUE)
#'
#' mod_spec_tbl <- fast_regression_parsnip_spec_tbl(
#'   .parsnip_eng = c("lm","glm","gee"),
#'   .parsnip_fns = "linear_reg"
#' )
#'
#' rec_obj <- recipe(mpg ~ ., data = mtcars)
#' splits_obj <- create_splits(mtcars, "initial_split")
#'
#' mod_tbl <- mod_spec_tbl %>%
#'   mutate(wflw = internal_make_wflw(mod_spec_tbl, rec_obj))
#'
#' mod_fitted_tbl <- mod_tbl %>%
#'   mutate(fitted_wflw = internal_make_fitted_wflw(mod_tbl, splits_obj))
#'
#' internal_make_wflw_predictions(mod_fitted_tbl, splits_obj)
#'
#' @return
#' A list object of predictions.
#'
#' @name internal_make_wflw_predictions
NULL

#' @export
#' @rdname internal_make_wflw_predictions

# Safely make predictions on fitted workflow
internal_make_wflw_predictions <- function(.model_tbl, .splits_obj){
  
  # Tidyeval ----
  model_tbl <- .model_tbl
  splits_obj <- .splits_obj
  col_nms <- colnames(model_tbl)
  
  # Checks ----
  if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")){
    rlang::abort(
      message = "'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl'",
      use_cli_format = TRUE
    )
  }
  
  if (!"fitted_wflw" %in% col_nms){
    rlang::abort(
      message = "Missing the column 'fitted_wflw'",
      use_cli_format = TRUE
    )
  }
  
  if (!".model_id" %in% col_nms){
    rlang::abort(
      message = "Missing the column '.model_id'",
      use_cli_format = TRUE
    )
  }
  
  # Manipulation
  # Make a group split object list
  model_factor_tbl <- model_tbl %>%
    dplyr::mutate(.model_id = forcats::as_factor(.model_id))
  
  models_list <- model_factor_tbl %>%
    dplyr::group_split(.model_id)
  
  # Make the predictions on the fitted workflow object using purrr imap
  wflw_preds_list <- models_list %>%
    purrr::imap(
      .f = function(obj, id){
        
        # Pull the fitted workflow column and then pluck it
        fitted_wflw <- obj %>% dplyr::pull(7) %>% purrr::pluck(1)
        
        # Create a safe stats::predict
        safe_stats_predict <- purrr::safely(
          stats::predict,
          otherwise = NULL,
          quiet = TRUE
        )
        
        # Return the predictions
        ret <- safe_stats_predict(
          fitted_wflw,
          new_data = rsample::training(splits_obj$splits)
        )
        
        res <- ret %>% purrr::pluck("result")
        
        if (!is.null(ret$error)) message(stringr::str_glue("{ret$error}"))
        
        return(res)
      }
    )
  
  return(wflw_preds_list)
}

Make function `fast_classification_parsnip_spec_tbl()`

Function:

#' Utility Classification call to `parsnip`
#'
#' @family Utility
#'
#' @author Steven P. Sanderson II, MPH
#'
#' @details Creates a tibble of parsnip classification model specifications. This will
#' create a tibble of 32 different classification model specifications which can be
#' filtered. The model specs are created first and then filtered out. This will
#' only create models for __classification__ problems. To find all of the supported
#' models in this package you can visit \url{https://www.tidymodels.org/find/parsnip/}
#'
#' @description Creates a tibble of parsnip classification model specifications.
#'
#' @param .parsnip_fns The default for this is set to `all`. This means that all
#' of the parsnip __classification__ functions will be used, for example `bag_mars()`,
#' or `bart()`. You can also choose to pass a c() vector like `c("bag_mars","bart")`
#' @param .parsnip_eng The default for this is set to `all`. This means that all
#' of the parsnip __classification engines__ will be used, for example `earth`, or
#' `dbarts`. You can also choose to pass a c() vector like `c('earth', 'dbarts')`
#'
#' @examples
#' fast_classification_parsnip_spec_tbl(.parsnip_fns = "logistic_reg")
#' fast_classification_parsnip_spec_tbl(.parsnip_eng = c("earth","dbarts"))
#'
#' @return
#' A tibble with an added class of 'fst_class_spec_tbl'
#'
#' @importFrom parsnip linear_reg cubist_rules poisson_reg survival_reg
#'
#' @name fast_classification_parsnip_spec_tbl
NULL

#' @export
#' @rdname fast_classification_parsnip_spec_tbl

fast_classification_parsnip_spec_tbl <- function(.parsnip_fns = "all",
                                             .parsnip_eng = "all") {
  
  # Thank you https://stackoverflow.com/questions/74691333/build-a-tibble-of-parsnip-model-calls-with-match-fun/74691529#74691529
  # Tidyeval ----
  call <- list(.parsnip_fns) %>%
    purrr::flatten_chr()
  engine <- list(.parsnip_eng) %>%
    purrr::flatten_chr()
  
  # Make tibble
  mod_tbl <- tibble::tribble(
    ~.parsnip_engine, ~.parsnip_mode, ~.parsnip_fns,
    "earth","classification","bag_mars",
    "earth","classification","discrim_flexible",
    "dbarts","classification","bart",
    "MASS","classification","discrim_linear",
    "mda","classification","discrim_linear",
    "sda","classification","discrim_linear",
    "sparsediscrim","classification","discrim_linear",
    "MASS","classification","discrim_quad",
    "sparsediscrim","classification","discrim_quad",
    "klaR","classification","discrim_regularized",
    "mgcv","classification","gen_additive_mod",
    "brulee","classification","logistic_reg",
    "gee","classification","logistic_reg",
    "glm","classification","logistic_reg",
    "glmer","classification","logistic_reg",
    "glmnet","classification","logistic_reg",
    "LiblineaR","classification","logistic_reg",
    "earth","classification","mars",
    "brulee","classification","mlp",
    "nnet","classification","mlp",
    "brulee","classification","multinom_reg",
    "glmnet","classification","multinom_reg",
    "nnet","classification","multinom_reg",
    "klaR","classification","naive_Bayes",
    "kknn","classification","nearest_neighbor",
    "mixOmics","classification","pls",
    "xrf","classification","rule_fit",
    "kernlab","classification","svm_linear",
    "LiblineaR","classification","svm_linear",
    "kernlab","classification","svm_poly",
    "kernlab","classification","svm_rbf",
    "liquidSVM","classification","svm_rbf"
  )
  
  # Filter ----
  if (!"all" %in% engine){
    mod_tbl <- mod_tbl %>%
      dplyr::filter(.parsnip_engine %in% engine)
  }
  
  if (!"all" %in% call){
    mod_tbl <- mod_tbl %>%
      dplyr::filter(.parsnip_fns %in% call)
  }
  
  mod_filtered_tbl <- mod_tbl
  
  mod_spec_tbl <- mod_filtered_tbl %>%
    dplyr::mutate(
      model_spec = purrr::pmap(
        dplyr::cur_data(),
        ~ match.fun(..3)(mode = ..2, engine = ..1)
        #~ get(..3)(mode = ..2, engine = ..1)
      )
    ) %>%
    # add .model_id column
    dplyr::mutate(.model_id = dplyr::row_number()) %>%
    dplyr::select(.model_id, dplyr::everything())
  
  # Return ----
  class(mod_spec_tbl) <- c("fst_class_spec_tbl", class(mod_spec_tbl))
  class(mod_spec_tbl) <- c("tidyaml_mod_spec_tbl", class(mod_spec_tbl))
  attr(mod_spec_tbl, ".parsnip_engines") <- .parsnip_eng
  attr(mod_spec_tbl, ".parsnip_functions") <- .parsnip_fns
  
  return(mod_spec_tbl)
  
}

Examples:

> fast_classification_parsnip_spec_tbl(.parsnip_fns = "logistic_reg")
# A tibble: 6 × 5
  .model_id .parsnip_engine .parsnip_mode  .parsnip_fns model_spec
      <int> <chr>           <chr>          <chr>        <list>    
1         1 brulee          classification logistic_reg <spec[+]> 
2         2 gee             classification logistic_reg <spec[+]> 
3         3 glm             classification logistic_reg <spec[+]> 
4         4 glmer           classification logistic_reg <spec[+]> 
5         5 glmnet          classification logistic_reg <spec[+]> 
6         6 LiblineaR       classification logistic_reg <spec[+]> 
> fast_classification_parsnip_spec_tbl(.parsnip_eng = c("earth","dbarts"))
# A tibble: 4 × 5
  .model_id .parsnip_engine .parsnip_mode  .parsnip_fns     model_spec
      <int> <chr>           <chr>          <chr>            <list>    
1         1 earth           classification bag_mars         <spec[+]> 
2         2 earth           classification discrim_flexible <spec[+]> 
3         3 dbarts          classification bart             <spec[+]> 
4         4 earth           classification mars             <spec[+]> 

Bugs Identified in `fast_regression_parsnip_spec_tbl()`

Updated the code to 58 models, dropping C5.0 because it is only for classification.

Function:

#' Utility Regression call to `parsnip`
#'
#' @family Utility
#'
#' @author Steven P. Sanderson II, MPH
#'
#' @details Creates a tibble of parsnip regression model specifications. This will
#' create a tibble of 58 different regression model specifications which can be
#' filtered. The model specs are created first and then filtered out. This will
#' only create models for __regression__ problems. To find all of the supported
#' models in this package you can visit \url{https://www.tidymodels.org/find/parsnip/}
#'
#' @seealso \url{https://parsnip.tidymodels.org/reference/linear_reg.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/cubist_rules.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/survival_reg.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/poisson_reg.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/bag_mars.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/bag_tree.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/bart.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/boost_tree.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/decision_tree.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/gen_additive_mod.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/mars.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/mlp.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/nearest_neighbor.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/pls.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/rand_forest.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/rule_fit.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/svm_linear.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/svm_poly.html}
#' @seealso \url{https://parsnip.tidymodels.org/reference/svm_rbf.html}
#'
#' @description Creates a tibble of parsnip regression model specifications.
#'
#' @param .parsnip_fns The default for this is set to `all`. This means that all
#' of the parsnip __regression__ functions will be used, for example `linear_reg()`,
#' or `cubist_rules()`. You can also choose to pass a c() vector like `c("linear_reg","cubist_rules")`
#' @param .parsnip_eng The default for this is set to `all`. This means that all
#' of the parsnip __regression engines__ will be used, for example `lm`, or
#' `glm`. You can also choose to pass a c() vector like `c('lm', 'glm')`
#'
#' @examples
#' fast_regression_parsnip_spec_tbl(.parsnip_fns = "linear_reg")
#' fast_regression_parsnip_spec_tbl(.parsnip_eng = c("lm","glm"))
#'
#' @return
#' A tibble with an added class of 'fst_reg_spec_tbl'
#'
#' @importFrom parsnip linear_reg cubist_rules poisson_reg survival_reg
#'
#' @name fast_regression_parsnip_spec_tbl
NULL

#' @export
#' @rdname fast_regression_parsnip_spec_tbl

fast_regression_parsnip_spec_tbl <- function(.parsnip_fns = "all",
                                             .parsnip_eng = "all") {
  
  # Thank you https://stackoverflow.com/questions/74691333/build-a-tibble-of-parsnip-model-calls-with-match-fun/74691529#74691529
  # Tidyeval ----
  call <- list(.parsnip_fns) %>%
    purrr::flatten_chr()
  engine <- list(.parsnip_eng) %>%
    purrr::flatten_chr()
  
  # Make tibble
  mod_tbl <- dplyr::tibble(
    .parsnip_engine = c(
      # linear_reg
      "lm", "brulee", "gee", "glm", "glmer", "glmnet", "gls", "h2o", "keras",
      "lme", "lmer", "spark", "stan", "stan_glmer",
      # cubist_rules
      "Cubist",
      # poisson_reg
      "glm", "gee", "glmer", "glmnet", "h2o", "hurdle", "stan", "stan_glmer",
      "zeroinfl",
      # survival_reg
      "survival", "flexsurv", "flexsurvspline",
      # bag_mars
      "earth",
      # bag_tree
      "rpart",
      # bart
      "dbarts",
      # boost_tree
      "xgboost","h2o","lightgbm","spark","mboost",
      # decision_tree
      "rpart","spark","partykit",
      # gen_additive_mod
      "mgcv",
      # mars
      "earth",
      # mlp
      "nnet","brulee","h2o","keras",
      # nearest_neighbor
      "kknn",
      # pls
      "mixOmics",
      # rand_forest
      "ranger","h2o","randomForest","spark","partykit","aorsf",
      # rule_fit
      "xrf","h2o",
      # svm_linear
      "LiblineaR","kernlab",
      # svm_poly
      "kernlab",
      # svm_rbf
      "kernlab"
    ),
    .parsnip_mode = c(
      # linear_reg
      rep("regression", 14),
      # cubist_rules
      "regression",
      # poisson_reg
      rep("regression", 9),
      # survival_reg
      rep("censored regression", 3),
      # bag_mars
      "regression",
      # bag_tree
      "regression",
      # bart
      "regression",
      # boost_tree
      rep("regression", 4),
      "censored regression",
      # decision_tree
      rep("regression", 3),
      # gen_additive_mod
      "regression",
      # mars
      "regression",
      # mlp
      rep("regression", 4),
      # nearest_neighbor
      "regression",
      # pls
      "regression",
      # rand_forest
      rep("regression", 4),
      rep("censored regression", 2),
      # rule_fit
      rep("regression", 2),
      # svm_linear
      rep("regression", 2),
      # svm_poly
      "regression",
      # svm_rbf
      "regression"
    ),
    .parsnip_fns = c(
      rep("linear_reg", 14),
      "cubist_rules",
      rep("poisson_reg",9),
      rep("survival_reg", 3),
      "bag_mars",
      rep("bag_tree",1),
      "bart",
      rep("boost_tree",5),
      rep("decision_tree", 3),
      "gen_additive_mod",
      "mars",
      rep("mlp", 4),
      "nearest_neighbor",
      "pls",
      rep("rand_forest",6),
      rep("rule_fit",2),
      rep("svm_linear", 2),
      "svm_poly",
      "svm_rbf"
    )
  )
  
  # Filter ----
  if (!"all" %in% engine){
    mod_tbl <- mod_tbl %>%
      dplyr::filter(.parsnip_engine %in% engine)
  }
  
  if (!"all" %in% call){
    mod_tbl <- mod_tbl %>%
      dplyr::filter(.parsnip_fns %in% call)
  }
  
  mod_filtered_tbl <- mod_tbl
  
  mod_spec_tbl <- mod_filtered_tbl %>%
    dplyr::mutate(
      model_spec = purrr::pmap(
        dplyr::cur_data(),
        ~ match.fun(..3)(mode = ..2, engine = ..1)
        #~ get(..3)(mode = ..2, engine = ..1)
      )
    )
  
  # Return ----
  class(mod_spec_tbl) <- c("fst_reg_spec_tbl", class(mod_spec_tbl))
  attr(mod_spec_tbl, ".parsnip_engines") <- .parsnip_eng
  attr(mod_spec_tbl, ".parsnip_functions") <- .parsnip_fns
  
  return(mod_spec_tbl)
  
}

Example:

> df <- mtcars
> rec_obj <- recipe(mpg ~ ., data = df)
> fast_regression(
+   .data = df,
+   .rec_obj = rec_obj,
+   .split_type = "initial_split"
+ )
# A tibble: 58 × 8
   .parsnip_engine .pars…¹ .pars…² model_s…³ model_…⁴ wflw       fitted_w…⁵ pred_w…⁶
   <chr>           <chr>   <chr>   <list>    <list>   <list>     <list>     <list>  
 1 lm              regres… linear… <spec[+]> <recipe> <workflow> <workflow> <tibble>
 2 brulee          regres… linear… <spec[+]> <recipe> <workflow> <workflow> <tibble>
 3 gee             regres… linear… <spec[+]> <recipe> <workflow> <workflow> <tibble>
 4 glm             regres… linear… <spec[+]> <recipe> <workflow> <workflow> <tibble>
 5 glmer           regres… linear… <spec[+]> <recipe> <workflow> <workflow> <tibble>
 6 glmnet          regres… linear… <spec[+]> <recipe> <workflow> <workflow> <tibble>
 7 gls             regres… linear… <spec[+]> <recipe> <workflow> <workflow> <tibble>
 8 h2o             regres… linear… <spec[+]> <recipe> <workflow> <workflow> <tibble>
 9 keras           regres… linear… <spec[+]> <recipe> <workflow> <workflow> <tibble>
10 lme             regres… linear… <spec[+]> <recipe> <workflow> <workflow> <tibble>
# … with 48 more rows, and abbreviated variable names ¹​.parsnip_mode,
#   ²​.parsnip_fns, ³​model_spec, ⁴​model_recipe, ⁵​fitted_wflw, ⁶​pred_wflw
# ℹ Use `print(n = ...)` to see more rows
