From 70471838eb27bda0b30f4f13500547b4093c1d23 Mon Sep 17 00:00:00 2001 From: bgreenwell Date: Sun, 16 Jul 2023 21:38:59 -0400 Subject: [PATCH 1/4] fix CITATION and other tweaks for CRAN --- R/get_training_data.R | 7 +++---- README.Rmd | 4 ++-- README.md | 4 ++-- inst/CITATION | 4 ++-- inst/references.bib | 4 ++-- vignettes/vip.Rmd | 4 ++-- vignettes/vip.Rmd.orig | 3 ++- 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/R/get_training_data.R b/R/get_training_data.R index 0592f18..7c13819 100644 --- a/R/get_training_data.R +++ b/R/get_training_data.R @@ -6,7 +6,6 @@ msg <- paste0( #' @keywords internal -#' @noRd get_training_data <- function(object) { UseMethod("get_training_data") } @@ -134,19 +133,19 @@ get_training_data.earth <- function(object) { # Package: h2o ----------------------------------------------------------------- #' @keywords internal -get_training_data.H2OBinomialModel <- function(object, ...) { +get_training_data.H2OBinomialModel <- function(object) { as.data.frame(h2o::h2o.getFrame(object@allparameters$training_frame)) } #' @keywords internal -get_training_data.H2OMultinomialModel <- function(object, ...) { +get_training_data.H2OMultinomialModel <- function(object) { as.data.frame(h2o::h2o.getFrame(object@allparameters$training_frame)) } #' @keywords internal -get_training_data.H2ORegressionModel <- function(object, ...) { +get_training_data.H2ORegressionModel <- function(object) { as.data.frame(h2o::h2o.getFrame(object@allparameters$training_frame)) } diff --git a/README.Rmd b/README.Rmd index f42e589..6b15a32 100644 --- a/README.Rmd +++ b/README.Rmd @@ -14,11 +14,11 @@ knitr::opts_chunk$set( # vip: Variable Importance Plots -[![CRAN_Status_Badge](http://www.r-pkg.org/badges/version/vip)](https://cran.r-project.org/package=vip) +[![CRAN_Status_Badge](https://www.r-pkg.org/badges/version/vip)](https://cran.r-project.org/package=vip) [![R-CMD-check](https://github.com/koalaverse/vip/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/koalaverse/vip/actions/workflows/R-CMD-check.yaml) [![Coverage Status](https://codecov.io/gh/koalaverse/vip/graph/badge.svg)](https://app.codecov.io/github/koalaverse/vip?branch=master) [![Lifecycle: stable](https://img.shields.io/badge/lifecycle-stable-brightgreen.svg)](https://lifecycle.r-lib.org/articles/stages.html#stable) -[![Downloads](http://cranlogs.r-pkg.org/badges/vip)](http://cran.r-project.org/package=vip/) +[![Downloads](https://cranlogs.r-pkg.org/badges/vip)](https://cran.r-project.org/package=vip/) [![The R Journal](https://img.shields.io/badge/The%20R%20Journal-10.32614%2FRJ--2020--013-brightgreen)](https://doi.org/10.32614/RJ-2020-013) diff --git a/README.md b/README.md index cb064c9..ca85523 100644 --- a/README.md +++ b/README.md @@ -3,13 +3,13 @@ -[![CRAN_Status_Badge](http://www.r-pkg.org/badges/version/vip)](https://cran.r-project.org/package=vip) +[![CRAN_Status_Badge](https://www.r-pkg.org/badges/version/vip)](https://cran.r-project.org/package=vip) [![R-CMD-check](https://github.com/koalaverse/vip/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/koalaverse/vip/actions/workflows/R-CMD-check.yaml) [![Coverage Status](https://codecov.io/gh/koalaverse/vip/graph/badge.svg)](https://app.codecov.io/github/koalaverse/vip?branch=master) [![Lifecycle: stable](https://img.shields.io/badge/lifecycle-stable-brightgreen.svg)](https://lifecycle.r-lib.org/articles/stages.html#stable) -[![Downloads](http://cranlogs.r-pkg.org/badges/vip)](http://cran.r-project.org/package=vip/) +[![Downloads](https://cranlogs.r-pkg.org/badges/vip)](https://cran.r-project.org/package=vip/) [![The R Journal](https://img.shields.io/badge/The%20R%20Journal-10.32614%2FRJ--2020--013-brightgreen)](https://doi.org/10.32614/RJ-2020-013) diff --git a/inst/CITATION b/inst/CITATION index e1f1c0c..12cd558 100644 --- a/inst/CITATION +++ b/inst/CITATION @@ -1,8 +1,8 @@ citHeader("To cite vip in publications use:") -citEntry(entry = "Article", +bibentry(entry = "Article", title = "Variable Importance Plots---An Introduction to the vip Package", - author = personList(as.person("Brandon M. Greenwell"), as.person("Bradley C. Boehmke")), + author = c(as.person("Brandon M. Greenwell"), as.person("Bradley C. Boehmke")), journal = "The R Journal", year = "2020", volume = "12", diff --git a/inst/references.bib b/inst/references.bib index 55471af..f0a3438 100644 --- a/inst/references.bib +++ b/inst/references.bib @@ -492,7 +492,7 @@ @Article{party2007a year = {2007}, volume = {8}, number = {25}, - url = {https://www.biomedcentral.com/1471-2105/8/25}, + url = {https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-8-25}, } @Article{party2008b, title = {Conditional Variable Importance for Random Forests}, @@ -501,7 +501,7 @@ @Article{party2008b year = {2008}, volume = {9}, number = {307}, - url = {https://www.biomedcentral.com/1471-2105/9/307}, + url = {https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-9-307}, } @Article{partykit2015, title = {{partykit}: A Modular Toolkit for Recursive Partytioning in {R}}, diff --git a/vignettes/vip.Rmd b/vignettes/vip.Rmd index 1877361..308d567 100644 --- a/vignettes/vip.Rmd +++ b/vignettes/vip.Rmd @@ -14,7 +14,7 @@ vignette: > %\VignetteIndexEntry{vip} %\VignetteEngine{knitr::rmarkdown} %\VignetteEncoding{UTF-8} -bibliography: '`r system.file("references.bib", package = "vip")`' +bibliography: '/Users/bgreenwell/Dropbox/devel/vip/inst/references.bib' --- @@ -808,7 +808,7 @@ Much better (and just the negative of the previous results, as expected)! For a In this section, we compare the performance of four implementations of permutation-based VI scores: `iml::FeatureImp()` (version 0.11.1), `ingredients::feature_importance()` (version 2.3.0), `mmpf::permutationImportance` (version 0.0.5), and `vip::vi()` (version 0.4.0). -We simulated 10,000 training observations from the Friedman 1 benchmark problem and trained a random forest using the [ranger](https://cran.r-project.org/package=ranger) package. For each implementation, we computed permutation-based VI scores 100 times using the [microbenchmark](https://cran.r-project.org/package=microbenchmark) package [@R-microbenchmark]. For this benchmark we did not use any of the parallel processing capability available in the [iml](https://cran.r-project.org/package=iml) and [vip](https://cran.r-project.org/package=vip) implementations. The results from [microbenchmark](https://cran.r-project.org/package=microbenchmark) are displayed in Figure \ref@(fig:benchmark) and summarized in the output below. In this case, the [vip](https://cran.r-project.org/package=vip) package (version 0.4.0) was the fastest, followed closely by [ingredients](https://cran.r-project.org/package=ingredients) and [mmpf](https://cran.r-project.org/package=mmpf). It should be noted, however, that the implementations in [vip](https://cran.r-project.org/package=vip) and [iml](https://cran.r-project.org/package=iml) can be parallelized. To the best of our knowledge, this is not the case for [ingredients](https://cran.r-project.org/package=ingredients) or [mmpf](https://cran.r-project.org/package=mmpf) (although it would not be difficult to write a simple parallel wrapper for either). The code used to generate these benchmarks can be found at https://bit.ly/2TogXrq. +We simulated 10,000 training observations from the Friedman 1 benchmark problem and trained a random forest using the [ranger](https://cran.r-project.org/package=ranger) package. For each implementation, we computed permutation-based VI scores 100 times using the [microbenchmark](https://cran.r-project.org/package=microbenchmark) package [@R-microbenchmark]. For this benchmark we did not use any of the parallel processing capability available in the [iml](https://cran.r-project.org/package=iml) and [vip](https://cran.r-project.org/package=vip) implementations. The results from [microbenchmark](https://cran.r-project.org/package=microbenchmark) are displayed in Figure \ref@(fig:benchmark) and summarized in the output below. In this case, the [vip](https://cran.r-project.org/package=vip) package (version 0.4.0) was the fastest, followed closely by [ingredients](https://cran.r-project.org/package=ingredients) and [mmpf](https://cran.r-project.org/package=mmpf). It should be noted, however, that the implementations in [vip](https://cran.r-project.org/package=vip) and [iml](https://cran.r-project.org/package=iml) can be parallelized. To the best of our knowledge, this is not the case for [ingredients](https://cran.r-project.org/package=ingredients) or [mmpf](https://cran.r-project.org/package=mmpf) (although it would not be difficult to write a simple parallel wrapper for either). The code used to generate these benchmarks can be found at https://github.com/koalaverse/vip/blob/master/slowtests/slowtests-benchmarks.R. Violin plots comparing the computation time from three different implementations of permutation-based VI scores across 100 simulations. diff --git a/vignettes/vip.Rmd.orig b/vignettes/vip.Rmd.orig index 9187c06..cb21782 100644 --- a/vignettes/vip.Rmd.orig +++ b/vignettes/vip.Rmd.orig @@ -34,6 +34,7 @@ knitr::opts_chunk$set( # Execute the code from the vignette # knitr::knit("vignettes/vip.Rmd.orig", output = "vignettes/vip.Rmd") +# And don't forget to replace "man/" with "../man/" in `vip.Rmd` file ``` This vignette is essentially an up-to-date version of @RJ-2020-013. Please use that if you'd like to cite our work. @@ -582,7 +583,7 @@ Much better (and just the negative of the previous results, as expected)! For a In this section, we compare the performance of four implementations of permutation-based VI scores: `iml::FeatureImp()` (version `r packageVersion("iml")`), `ingredients::feature_importance()` (version `r packageVersion("ingredients")`), `mmpf::permutationImportance` (version `r packageVersion("mmpf")`), and `vip::vi()` (version `r packageVersion("vip")`). -We simulated 10,000 training observations from the Friedman 1 benchmark problem and trained a random forest using the [ranger](https://cran.r-project.org/package=ranger) package. For each implementation, we computed permutation-based VI scores 100 times using the [microbenchmark](https://cran.r-project.org/package=microbenchmark) package [@R-microbenchmark]. For this benchmark we did not use any of the parallel processing capability available in the [iml](https://cran.r-project.org/package=iml) and [vip](https://cran.r-project.org/package=vip) implementations. The results from [microbenchmark](https://cran.r-project.org/package=microbenchmark) are displayed in Figure \ref@(fig:benchmark) and summarized in the output below. In this case, the [vip](https://cran.r-project.org/package=vip) package (version `r packageVersion("vip")`) was the fastest, followed closely by [ingredients](https://cran.r-project.org/package=ingredients) and [mmpf](https://cran.r-project.org/package=mmpf). It should be noted, however, that the implementations in [vip](https://cran.r-project.org/package=vip) and [iml](https://cran.r-project.org/package=iml) can be parallelized. To the best of our knowledge, this is not the case for [ingredients](https://cran.r-project.org/package=ingredients) or [mmpf](https://cran.r-project.org/package=mmpf) (although it would not be difficult to write a simple parallel wrapper for either). The code used to generate these benchmarks can be found at https://bit.ly/2TogXrq. +We simulated 10,000 training observations from the Friedman 1 benchmark problem and trained a random forest using the [ranger](https://cran.r-project.org/package=ranger) package. For each implementation, we computed permutation-based VI scores 100 times using the [microbenchmark](https://cran.r-project.org/package=microbenchmark) package [@R-microbenchmark]. For this benchmark we did not use any of the parallel processing capability available in the [iml](https://cran.r-project.org/package=iml) and [vip](https://cran.r-project.org/package=vip) implementations. The results from [microbenchmark](https://cran.r-project.org/package=microbenchmark) are displayed in Figure \ref@(fig:benchmark) and summarized in the output below. In this case, the [vip](https://cran.r-project.org/package=vip) package (version `r packageVersion("vip")`) was the fastest, followed closely by [ingredients](https://cran.r-project.org/package=ingredients) and [mmpf](https://cran.r-project.org/package=mmpf). It should be noted, however, that the implementations in [vip](https://cran.r-project.org/package=vip) and [iml](https://cran.r-project.org/package=iml) can be parallelized. To the best of our knowledge, this is not the case for [ingredients](https://cran.r-project.org/package=ingredients) or [mmpf](https://cran.r-project.org/package=mmpf) (although it would not be difficult to write a simple parallel wrapper for either). The code used to generate these benchmarks can be found at https://github.com/koalaverse/vip/blob/master/slowtests/slowtests-benchmarks.R. ```{r benchmark, echo=FALSE, fig.cap="Violin plots comparing the computation time from three different implementations of permutation-based VI scores across 100 simulations."} # Load required packages From d1d6885bfa3d795a77a52ee51b7f7410875accd6 Mon Sep 17 00:00:00 2001 From: bgreenwell Date: Sun, 16 Jul 2023 21:56:03 -0400 Subject: [PATCH 2/4] another fix to CITATION --- inst/CITATION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inst/CITATION b/inst/CITATION index 12cd558..b73e316 100644 --- a/inst/CITATION +++ b/inst/CITATION @@ -1,6 +1,6 @@ citHeader("To cite vip in publications use:") -bibentry(entry = "Article", +bibentry("Article", title = "Variable Importance Plots---An Introduction to the vip Package", author = c(as.person("Brandon M. Greenwell"), as.person("Bradley C. Boehmke")), journal = "The R Journal", From 9fba77eae4c15f00625e3d14d51d08774b886998 Mon Sep 17 00:00:00 2001 From: bgreenwell Date: Sun, 16 Jul 2023 22:05:27 -0400 Subject: [PATCH 3/4] fix vignette bib entry...again --- vignettes/vip.Rmd | 2 +- vignettes/vip.Rmd.orig | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/vignettes/vip.Rmd b/vignettes/vip.Rmd index 308d567..8bf9dc0 100644 --- a/vignettes/vip.Rmd +++ b/vignettes/vip.Rmd @@ -14,7 +14,7 @@ vignette: > %\VignetteIndexEntry{vip} %\VignetteEngine{knitr::rmarkdown} %\VignetteEncoding{UTF-8} -bibliography: '/Users/bgreenwell/Dropbox/devel/vip/inst/references.bib' +bibliography: '`r system.file("references.bib", package = "vip")`' --- diff --git a/vignettes/vip.Rmd.orig b/vignettes/vip.Rmd.orig index cb21782..9c5f0b7 100644 --- a/vignettes/vip.Rmd.orig +++ b/vignettes/vip.Rmd.orig @@ -34,7 +34,12 @@ knitr::opts_chunk$set( # Execute the code from the vignette # knitr::knit("vignettes/vip.Rmd.orig", output = "vignettes/vip.Rmd") -# And don't forget to replace "man/" with "../man/" in `vip.Rmd` file +# +# TODO: +# * Don't forget to replace "man/" with "../man/" in `vip.Rmd` file. +# * Don't forget to manually add +# '`r system.file("references.bib", package = "vip")`' to the `bibliography` +# field in `vip.Rmd`. ``` This vignette is essentially an up-to-date version of @RJ-2020-013. Please use that if you'd like to cite our work. From 05ede557d1968fa91357bbdec9307bc4b9935370 Mon Sep 17 00:00:00 2001 From: bgreenwell Date: Sun, 16 Jul 2023 22:35:24 -0400 Subject: [PATCH 4/4] tweak tests --- NEWS.md | 2 + R/get_training_data.R | 166 +++++++++----------------- R/vi.R | 4 +- R/vi_permute.R | 9 +- inst/tinytest/test_pkg_C50.R | 10 -- inst/tinytest/test_pkg_Cubist.R | 6 - inst/tinytest/test_pkg_earth.R | 6 - inst/tinytest/test_pkg_glmnet.R | 6 - inst/tinytest/test_pkg_neuralnet.R | 6 - inst/tinytest/test_pkg_nnet.R | 6 - inst/tinytest/test_pkg_randomForest.R | 10 -- inst/tinytest/test_pkg_ranger.R | 6 - inst/tinytest/test_pkg_rpart.R | 6 - inst/tinytest/test_pkg_stats.R | 6 - inst/tinytest/test_vi_shap.R | 6 +- man/vi.Rd | 4 +- man/vi_permute.Rd | 9 +- 17 files changed, 74 insertions(+), 194 deletions(-) diff --git a/NEWS.md b/NEWS.md index 637a3c2..7bcfa94 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,8 @@ * This NEWS file now follows the [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) format. +* The training data has to be explicitly passed in more cases. + * Raised R version dependency to >= 4.1.0 (the introduction of the native piper operator `|>`). * The `vi_permute` function now uses [yardstick](https://cran.r-project.org/package=yardstick); consequently, metric functions now conform to [yardstick](https://cran.r-project.org/package=yardstick) metric argument names. diff --git a/R/get_training_data.R b/R/get_training_data.R index 7c13819..a2818ce 100644 --- a/R/get_training_data.R +++ b/R/get_training_data.R @@ -12,58 +12,60 @@ get_training_data <- function(object) { #' @keywords internal -get_training_data.default <- function(object, env = parent.frame(), - arg = "data") { - - # Throw error message for S4 objects (for now) - if (isS4(object)) { - stop(msg, call. = FALSE) - } - - # Grab the call - mcall <- tryCatch(stats::getCall(object), error = function(e) { - stop(msg, call. = FALSE) - }) - - # If data component of the call is NULL, then try to make sure each - # component is named before proceeding (taken from Advanced R, 2nd ed.) - if (is.null(mcall[[arg]])) { - f <- tryCatch(eval(mcall[[1L]], envir = env), error = function(e) { - stop(msg, call. = FALSE) - }) - if (!is.primitive(f)) { - mcall <- match.call(f, call = mcall) - } - } - - # Grab the data component (if it exists) - n <- 1 - while(length(env) != 0) { - train <- tryCatch(eval(mcall[[arg]], envir = env), error = function(e) { - NULL - }) - if (!is.null(train) || identical(env, globalenv())) { - break - } - env <- parent.frame(n) # inspect calling environment - n <- n + 1 - } - if (is.null(train)) { - stop(msg, call. = FALSE) - } else { - if (!(is.data.frame(train))) { - if (is.matrix(train) || is.list(train)) { - train <- as.data.frame(train) - # } else if (inherits(train, what = "dgCMatrix")) { - # train <- as.data.frame(data.matrix(train)) - } else { - stop(msg, call. = FALSE) - } - } - } - - # Return original training data - train +get_training_data.default <- function(object) { + + # # Throw error message for S4 objects (for now) + # if (isS4(object)) { + # stop(msg, call. = FALSE) + # } + # + # # Grab the call + # mcall <- tryCatch(stats::getCall(object), error = function(e) { + # stop(msg, call. = FALSE) + # }) + # + # # If data component of the call is NULL, then try to make sure each + # # component is named before proceeding (taken from Advanced R, 2nd ed.) + # if (is.null(mcall[[arg]])) { + # f <- tryCatch(eval(mcall[[1L]], envir = env), error = function(e) { + # stop(msg, call. = FALSE) + # }) + # if (!is.primitive(f)) { + # mcall <- match.call(f, call = mcall) + # } + # } + # + # # Grab the data component (if it exists) + # n <- 1 + # while(length(env) != 0) { + # train <- tryCatch(eval(mcall[[arg]], envir = env), error = function(e) { + # NULL + # }) + # if (!is.null(train) || identical(env, globalenv())) { + # break + # } + # env <- parent.frame(n) # inspect calling environment + # n <- n + 1 + # } + # if (is.null(train)) { + # stop(msg, call. = FALSE) + # } else { + # if (!(is.data.frame(train))) { + # if (is.matrix(train) || is.list(train)) { + # train <- as.data.frame(train) + # # } else if (inherits(train, what = "dgCMatrix")) { + # # train <- as.data.frame(data.matrix(train)) + # } else { + # stop(msg, call. = FALSE) + # } + # } + # } + # + # # Return original training data + # train + stop("Training data cannot be extracted from fitted model object. Please ", + "supply the raw training data using the `train` argument.", + call. = FALSE) } @@ -84,52 +86,6 @@ get_training_data.train <- function(object) { } -# Package: C50 ----------------------------------------------------------------- - -#' @keywords internal -get_training_data.C5.0 <- function(object) { - tryCatch( - expr = get_training_data.default(object, arg = "data"), - error = function(e) { - get_training_data.default(object, arg = "x") - } - ) -} - -# Package: Cubist -------------------------------------------------------------- - -#' @keywords internal -get_training_data.cubist <- function(object) { - get_training_data.default(object, arg = "x") -} - - -# Package: e1071 --------------------------------------------------------------- - -#' @keywords internal -get_training_data.svm <- function(object) { - tryCatch( - expr = get_training_data.default(object, arg = "data"), - error = function(e) { - get_training_data.default(object, arg = "x") - } - ) -} - - -# Package: earth --------------------------------------------------------------- - -#' @keywords internal -get_training_data.earth <- function(object) { - tryCatch( - expr = get_training_data.default(object, arg = "data"), - error = function(e) { - get_training_data.default(object, arg = "x") - } - ) -} - - # Package: h2o ----------------------------------------------------------------- #' @keywords internal @@ -167,19 +123,7 @@ get_training_data.RandomForest <- function(object) { object@data@get("input") } - -# Package: randomForest -------------------------------------------------------- - -#' @keywords internal -get_training_data.randomForest <- function(object) { - if (inherits(object, what = "randomForest.formula")) { - get_training_data.default(object, env = parent.frame(), arg = "data") - } else { - get_training_data.default(object, env = parent.frame(), arg = "x") - } -} - - +library # Package: workflow ------------------------------------------------------------ #' @keywords internal diff --git a/R/vi.R b/R/vi.R index b2becd8..ef5af7a 100644 --- a/R/vi.R +++ b/R/vi.R @@ -87,7 +87,7 @@ #' # Compute permutation-based variable importance scores #' set.seed(1434) # for reproducibility #' (vis <- vi(mtcars.ppr, method = "permute", target = "mpg", nsim = 10, -#' metric = "rmse", pred_wrapper = pfun)) +#' metric = "rmse", pred_wrapper = pfun, train = mtcars)) #' #' # Plot variable importance scores #' vip(vis, include_type = TRUE, all_permutations = TRUE, @@ -126,7 +126,7 @@ #' # Permutation-based importance (note that only the predictors that show up #' # in the final tree have non-zero importance) #' set.seed(1046) # for reproducibility -#' vi(tree2, method = "permute", nsim = 10, target = "Class", +#' vi(tree2, method = "permute", nsim = 10, target = "Class", train = bc, #' metric = "logloss", pred_wrapper = pfun, reference_class = "malignant") #' #' # Equivalent (but not sorted) diff --git a/R/vi_permute.R b/R/vi_permute.R index d742128..6516cbf 100644 --- a/R/vi_permute.R +++ b/R/vi_permute.R @@ -142,12 +142,13 @@ #' # Compute permutation-based VI scores #' set.seed(2021) # for reproducibility #' vis <- vi(rfo, method = "permute", target = "y", metric = "rsq", -#' pred_wrapper = pfun) +#' pred_wrapper = pfun, train = trn) #' print(vis) #' #' # Same as above, but using `vi_permute()` directly #' set.seed(2021) # for reproducibility -#' vi_permute(rfo, target = "y", metric = "rsq", pred_wrapper = pfun) +#' vi_permute(rfo, target = "y", metric = "rsq", pred_wrapper = pfun +#' train = trn) #' #' # Plot VI scores (could also replace `vi()` with `vip()` in above example) #' vip(vis, include_type = TRUE) @@ -160,12 +161,12 @@ #' # Permutation-based VIP with user-defined MAE metric #' set.seed(1101) # for reproducibility #' vi_permute(rfo, target = "y", metric = mae, smaller_is_better = TRUE, -#' pred_wrapper = pfun) +#' pred_wrapper = pfun, train = trn) #' #' # Same as above, but using `yardstick` package instead of user-defined metric #' set.seed(1101) # for reproducibility #' vi_permute(rfo, target = "y", metric = yardstick::mae_vec, -#' smaller_is_better = TRUE, pred_wrapper = pfun) +#' smaller_is_better = TRUE, pred_wrapper = pfun, train = trn) #' #' # #' # Classification (binary) example diff --git a/inst/tinytest/test_pkg_C50.R b/inst/tinytest/test_pkg_C50.R index 7b5c771..a642723 100644 --- a/inst/tinytest/test_pkg_C50.R +++ b/inst/tinytest/test_pkg_C50.R @@ -33,16 +33,6 @@ expect_identical( C50::C5imp(fit1, metric = "splits", pct = FALSE)$Overall ) -# Expectations for `get_training_data()` -expect_identical( - current = vip:::get_training_data.C5.0(fit1), - target = friedman2 -) -expect_identical( - current = vip:::get_training_data.C5.0(fit2), - target = subset(friedman2, select = -y) -) - # Expectations for `get_feature_names()` expect_identical( current = vip:::get_feature_names.C5.0(fit1), diff --git a/inst/tinytest/test_pkg_Cubist.R b/inst/tinytest/test_pkg_Cubist.R index b6e2cc4..39e39cf 100644 --- a/inst/tinytest/test_pkg_Cubist.R +++ b/inst/tinytest/test_pkg_Cubist.R @@ -32,12 +32,6 @@ expect_identical( target = vis2[vis1$Variable, , drop = TRUE] ) -# Expectations for `get_training_data()` -expect_identical( - current = vip:::get_training_data.cubist(fit), - target = subset(friedman1, select = -y) -) - # Expectations for `get_feature_names()` expect_identical( current = vip:::get_feature_names.cubist(fit), diff --git a/inst/tinytest/test_pkg_earth.R b/inst/tinytest/test_pkg_earth.R index aff5180..5333b59 100644 --- a/inst/tinytest/test_pkg_earth.R +++ b/inst/tinytest/test_pkg_earth.R @@ -34,12 +34,6 @@ expect_identical( target = unname(vis_earth[, "gcv", drop = TRUE]) ) -# Expectations for `get_training_data()` -expect_identical( - current = vip:::get_training_data.earth(fit), - target = friedman1 -) - # Expectations for `get_feature_names()` expect_identical( current = vip:::get_feature_names.earth(fit), diff --git a/inst/tinytest/test_pkg_glmnet.R b/inst/tinytest/test_pkg_glmnet.R index 208c40f..302bdea 100644 --- a/inst/tinytest/test_pkg_glmnet.R +++ b/inst/tinytest/test_pkg_glmnet.R @@ -58,12 +58,6 @@ expect_identical( target = abs(coef(fit3, s = fit3$lambda[5L])[[1L]][-1L]) ) -# # Expectations for `get_training_data()` -# expect_identical( -# current = vip:::get_training_data.glmnet(fit), -# target = friedman1 -# ) - # Expectations for `get_feature_names()` expect_identical( current = vip:::get_feature_names.glmnet(fit1), diff --git a/inst/tinytest/test_pkg_neuralnet.R b/inst/tinytest/test_pkg_neuralnet.R index e287308..a8ce3a7 100644 --- a/inst/tinytest/test_pkg_neuralnet.R +++ b/inst/tinytest/test_pkg_neuralnet.R @@ -33,12 +33,6 @@ expect_identical( target = NeuralNetTools::garson(fit, bar_plot = FALSE)$rel_imp ) -# Expectations for `get_training_data()` -expect_identical( - current = vip:::get_training_data.default(fit), - target = friedman1 -) - # Expectations for `get_feature_names()` expect_identical( current = vip:::get_feature_names.nn(fit), diff --git a/inst/tinytest/test_pkg_nnet.R b/inst/tinytest/test_pkg_nnet.R index d7cac3a..22b5d74 100644 --- a/inst/tinytest/test_pkg_nnet.R +++ b/inst/tinytest/test_pkg_nnet.R @@ -34,12 +34,6 @@ expect_identical( target = NeuralNetTools::garson(fit, bar_plot = FALSE)$rel_imp ) -# Expectations for `get_training_data()` -expect_identical( - current = vip:::get_training_data.default(fit), - target = friedman1 -) - # Expectations for `get_feature_names()` expect_identical( current = vip:::get_feature_names.nnet(fit), diff --git a/inst/tinytest/test_pkg_randomForest.R b/inst/tinytest/test_pkg_randomForest.R index f85ef88..43e7dd8 100644 --- a/inst/tinytest/test_pkg_randomForest.R +++ b/inst/tinytest/test_pkg_randomForest.R @@ -48,16 +48,6 @@ expect_identical( target = unname(fit4$importance[, "MeanDecreaseGini", drop = TRUE]) ) -# Expectations for `get_training_data()` -expect_identical( - current = vip:::get_training_data.randomForest(fit1), - target = friedman1 -) -expect_identical( # NOTE: Only x is passed in this call - current = vip:::get_training_data.randomForest(fit3), - target = subset(friedman1, select = -y) -) - # Expectations for `get_feature_names()` expect_identical( current = vip:::get_feature_names.randomForest(fit1), diff --git a/inst/tinytest/test_pkg_ranger.R b/inst/tinytest/test_pkg_ranger.R index 20c46ff..2e7c61e 100644 --- a/inst/tinytest/test_pkg_ranger.R +++ b/inst/tinytest/test_pkg_ranger.R @@ -27,12 +27,6 @@ expect_identical( target = unname(fit2$variable.importance) ) -# Expectations for `get_training_data()` -expect_identical( - current = vip:::get_training_data.default(fit1), - target = friedman1 -) - # Expectations for `get_feature_names()` expect_identical( current = vip:::get_feature_names.ranger(fit1), diff --git a/inst/tinytest/test_pkg_rpart.R b/inst/tinytest/test_pkg_rpart.R index 7db1503..df46e32 100644 --- a/inst/tinytest/test_pkg_rpart.R +++ b/inst/tinytest/test_pkg_rpart.R @@ -25,12 +25,6 @@ expect_identical( ) expect_error(vi(no_splits)) -# Expectations for `get_training_data()` -expect_identical( - current = vip:::get_training_data.default(fit), - target = friedman1 -) - # Expectations for `get_feature_names()` expect_identical( current = vip:::get_feature_names.rpart(fit), diff --git a/inst/tinytest/test_pkg_stats.R b/inst/tinytest/test_pkg_stats.R index cd20ffc..de1e165 100644 --- a/inst/tinytest/test_pkg_stats.R +++ b/inst/tinytest/test_pkg_stats.R @@ -22,12 +22,6 @@ expect_identical( target = unname(abs(summary(fit_glm)$coefficients[, "z value"])[-1]) ) -# Expectations for `get_training_data()` -expect_identical( - current = vip:::get_training_data.default(fit_lm), - target = friedman1 -) - # Expectations for `get_feature_names()` expect_identical( current = vip:::get_feature_names.lm(fit_lm), diff --git a/inst/tinytest/test_vi_shap.R b/inst/tinytest/test_vi_shap.R index eda77c7..4cc01a6 100644 --- a/inst/tinytest/test_vi_shap.R +++ b/inst/tinytest/test_vi_shap.R @@ -19,6 +19,6 @@ pfun <- function(object, newdata) { # Compute SHAP-based VI scores set.seed(1511) # for reproducibility -vis1 <- vi_shap(fit1, pred_wrapper = pfun, nsim = 10) -vis2 <- vi_shap(fit2, pred_wrapper = pfun, nsim = 10) -vis3 <- vi(fit1, method = "shap", pred_wrapper = pfun, nsim = 10) +vis1 <- vi_shap(fit1, pred_wrapper = pfun, nsim = 10, train = trn1) +vis2 <- vi_shap(fit2, pred_wrapper = pfun, nsim = 10, train = trn2) +vis3 <- vi(fit1, method = "shap", pred_wrapper = pfun, nsim = 10, train = trn1) diff --git a/man/vi.Rd b/man/vi.Rd index df255ae..7247890 100644 --- a/man/vi.Rd +++ b/man/vi.Rd @@ -104,7 +104,7 @@ pfun <- function(object, newdata) predict(object, newdata = newdata) # Compute permutation-based variable importance scores set.seed(1434) # for reproducibility (vis <- vi(mtcars.ppr, method = "permute", target = "mpg", nsim = 10, - metric = "rmse", pred_wrapper = pfun)) + metric = "rmse", pred_wrapper = pfun, train = mtcars)) # Plot variable importance scores vip(vis, include_type = TRUE, all_permutations = TRUE, @@ -143,7 +143,7 @@ pfun <- function(object, newdata) { # Permutation-based importance (note that only the predictors that show up # in the final tree have non-zero importance) set.seed(1046) # for reproducibility -vi(tree2, method = "permute", nsim = 10, target = "Class", +vi(tree2, method = "permute", nsim = 10, target = "Class", train = bc, metric = "logloss", pred_wrapper = pfun, reference_class = "malignant") # Equivalent (but not sorted) diff --git a/man/vi_permute.Rd b/man/vi_permute.Rd index bec6569..737cf4d 100644 --- a/man/vi_permute.Rd +++ b/man/vi_permute.Rd @@ -165,12 +165,13 @@ rfo <- ranger(y ~ ., data = trn) # Compute permutation-based VI scores set.seed(2021) # for reproducibility vis <- vi(rfo, method = "permute", target = "y", metric = "rsq", - pred_wrapper = pfun) + pred_wrapper = pfun, train = trn) print(vis) # Same as above, but using `vi_permute()` directly set.seed(2021) # for reproducibility -vi_permute(rfo, target = "y", metric = "rsq", pred_wrapper = pfun) +vi_permute(rfo, target = "y", metric = "rsq", pred_wrapper = pfun + train = trn) # Plot VI scores (could also replace `vi()` with `vip()` in above example) vip(vis, include_type = TRUE) @@ -183,12 +184,12 @@ mae <- function(truth, estimate) { # Permutation-based VIP with user-defined MAE metric set.seed(1101) # for reproducibility vi_permute(rfo, target = "y", metric = mae, smaller_is_better = TRUE, - pred_wrapper = pfun) + pred_wrapper = pfun, train = trn) # Same as above, but using `yardstick` package instead of user-defined metric set.seed(1101) # for reproducibility vi_permute(rfo, target = "y", metric = yardstick::mae_vec, - smaller_is_better = TRUE, pred_wrapper = pfun) + smaller_is_better = TRUE, pred_wrapper = pfun, train = trn) # # Classification (binary) example