Skip to content

Commit

Permalink
Merge pull request #149 from koalaverse/devel
Browse files Browse the repository at this point in the history
fix CITATION and other tweaks for CRAN
  • Loading branch information
bgreenwell committed Jul 17, 2023
2 parents 0afae9d + 05ede55 commit 33cc92e
Show file tree
Hide file tree
Showing 23 changed files with 93 additions and 208 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

* This NEWS file now follows the [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) format.

* The training data has to be explicitly passed in more cases.

* Raised R version dependency to >= 4.1.0 (the introduction of the native piper operator `|>`).

* The `vi_permute` function now uses [yardstick](https://cran.r-project.org/package=yardstick); consequently, metric functions now conform to [yardstick](https://cran.r-project.org/package=yardstick) metric argument names.
Expand Down
173 changes: 58 additions & 115 deletions R/get_training_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,65 +6,66 @@ msg <- paste0(


#' @keywords internal
#' @noRd
get_training_data <- function(object) {
UseMethod("get_training_data")
}


#' @keywords internal
get_training_data.default <- function(object, env = parent.frame(),
arg = "data") {

# Throw error message for S4 objects (for now)
if (isS4(object)) {
stop(msg, call. = FALSE)
}

# Grab the call
mcall <- tryCatch(stats::getCall(object), error = function(e) {
stop(msg, call. = FALSE)
})

# If data component of the call is NULL, then try to make sure each
# component is named before proceeding (taken from Advanced R, 2nd ed.)
if (is.null(mcall[[arg]])) {
f <- tryCatch(eval(mcall[[1L]], envir = env), error = function(e) {
stop(msg, call. = FALSE)
})
if (!is.primitive(f)) {
mcall <- match.call(f, call = mcall)
}
}

# Grab the data component (if it exists)
n <- 1
while(length(env) != 0) {
train <- tryCatch(eval(mcall[[arg]], envir = env), error = function(e) {
NULL
})
if (!is.null(train) || identical(env, globalenv())) {
break
}
env <- parent.frame(n) # inspect calling environment
n <- n + 1
}
if (is.null(train)) {
stop(msg, call. = FALSE)
} else {
if (!(is.data.frame(train))) {
if (is.matrix(train) || is.list(train)) {
train <- as.data.frame(train)
# } else if (inherits(train, what = "dgCMatrix")) {
# train <- as.data.frame(data.matrix(train))
} else {
stop(msg, call. = FALSE)
}
}
}

# Return original training data
train
get_training_data.default <- function(object) {

# # Throw error message for S4 objects (for now)
# if (isS4(object)) {
# stop(msg, call. = FALSE)
# }
#
# # Grab the call
# mcall <- tryCatch(stats::getCall(object), error = function(e) {
# stop(msg, call. = FALSE)
# })
#
# # If data component of the call is NULL, then try to make sure each
# # component is named before proceeding (taken from Advanced R, 2nd ed.)
# if (is.null(mcall[[arg]])) {
# f <- tryCatch(eval(mcall[[1L]], envir = env), error = function(e) {
# stop(msg, call. = FALSE)
# })
# if (!is.primitive(f)) {
# mcall <- match.call(f, call = mcall)
# }
# }
#
# # Grab the data component (if it exists)
# n <- 1
# while(length(env) != 0) {
# train <- tryCatch(eval(mcall[[arg]], envir = env), error = function(e) {
# NULL
# })
# if (!is.null(train) || identical(env, globalenv())) {
# break
# }
# env <- parent.frame(n) # inspect calling environment
# n <- n + 1
# }
# if (is.null(train)) {
# stop(msg, call. = FALSE)
# } else {
# if (!(is.data.frame(train))) {
# if (is.matrix(train) || is.list(train)) {
# train <- as.data.frame(train)
# # } else if (inherits(train, what = "dgCMatrix")) {
# # train <- as.data.frame(data.matrix(train))
# } else {
# stop(msg, call. = FALSE)
# }
# }
# }
#
# # Return original training data
# train
stop("Training data cannot be extracted from fitted model object. Please ",
"supply the raw training data using the `train` argument.",
call. = FALSE)

}

Expand All @@ -85,68 +86,22 @@ get_training_data.train <- function(object) {
}


# Package: C50 -----------------------------------------------------------------

#' @keywords internal
get_training_data.C5.0 <- function(object) {
tryCatch(
expr = get_training_data.default(object, arg = "data"),
error = function(e) {
get_training_data.default(object, arg = "x")
}
)
}

# Package: Cubist --------------------------------------------------------------

#' @keywords internal
get_training_data.cubist <- function(object) {
get_training_data.default(object, arg = "x")
}


# Package: e1071 ---------------------------------------------------------------

#' @keywords internal
get_training_data.svm <- function(object) {
tryCatch(
expr = get_training_data.default(object, arg = "data"),
error = function(e) {
get_training_data.default(object, arg = "x")
}
)
}


# Package: earth ---------------------------------------------------------------

#' @keywords internal
get_training_data.earth <- function(object) {
tryCatch(
expr = get_training_data.default(object, arg = "data"),
error = function(e) {
get_training_data.default(object, arg = "x")
}
)
}


# Package: h2o -----------------------------------------------------------------

#' @keywords internal
get_training_data.H2OBinomialModel <- function(object, ...) {
get_training_data.H2OBinomialModel <- function(object) {
as.data.frame(h2o::h2o.getFrame(object@allparameters$training_frame))
}


#' @keywords internal
get_training_data.H2OMultinomialModel <- function(object, ...) {
get_training_data.H2OMultinomialModel <- function(object) {
as.data.frame(h2o::h2o.getFrame(object@allparameters$training_frame))
}


#' @keywords internal
get_training_data.H2ORegressionModel <- function(object, ...) {
get_training_data.H2ORegressionModel <- function(object) {
as.data.frame(h2o::h2o.getFrame(object@allparameters$training_frame))
}

Expand All @@ -168,19 +123,7 @@ get_training_data.RandomForest <- function(object) {
object@data@get("input")
}


# Package: randomForest --------------------------------------------------------

#' @keywords internal
get_training_data.randomForest <- function(object) {
if (inherits(object, what = "randomForest.formula")) {
get_training_data.default(object, env = parent.frame(), arg = "data")
} else {
get_training_data.default(object, env = parent.frame(), arg = "x")
}
}


library
# Package: workflow ------------------------------------------------------------

#' @keywords internal
Expand Down
4 changes: 2 additions & 2 deletions R/vi.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
#' # Compute permutation-based variable importance scores
#' set.seed(1434) # for reproducibility
#' (vis <- vi(mtcars.ppr, method = "permute", target = "mpg", nsim = 10,
#' metric = "rmse", pred_wrapper = pfun))
#' metric = "rmse", pred_wrapper = pfun, train = mtcars))
#'
#' # Plot variable importance scores
#' vip(vis, include_type = TRUE, all_permutations = TRUE,
Expand Down Expand Up @@ -126,7 +126,7 @@
#' # Permutation-based importance (note that only the predictors that show up
#' # in the final tree have non-zero importance)
#' set.seed(1046) # for reproducibility
#' vi(tree2, method = "permute", nsim = 10, target = "Class",
#' vi(tree2, method = "permute", nsim = 10, target = "Class", train = bc,
#' metric = "logloss", pred_wrapper = pfun, reference_class = "malignant")
#'
#' # Equivalent (but not sorted)
Expand Down
9 changes: 5 additions & 4 deletions R/vi_permute.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,13 @@
#' # Compute permutation-based VI scores
#' set.seed(2021) # for reproducibility
#' vis <- vi(rfo, method = "permute", target = "y", metric = "rsq",
#' pred_wrapper = pfun)
#' pred_wrapper = pfun, train = trn)
#' print(vis)
#'
#' # Same as above, but using `vi_permute()` directly
#' set.seed(2021) # for reproducibility
#' vi_permute(rfo, target = "y", metric = "rsq", pred_wrapper = pfun)
#' vi_permute(rfo, target = "y", metric = "rsq", pred_wrapper = pfun
#' train = trn)
#'
#' # Plot VI scores (could also replace `vi()` with `vip()` in above example)
#' vip(vis, include_type = TRUE)
Expand All @@ -160,12 +161,12 @@
#' # Permutation-based VIP with user-defined MAE metric
#' set.seed(1101) # for reproducibility
#' vi_permute(rfo, target = "y", metric = mae, smaller_is_better = TRUE,
#' pred_wrapper = pfun)
#' pred_wrapper = pfun, train = trn)
#'
#' # Same as above, but using `yardstick` package instead of user-defined metric
#' set.seed(1101) # for reproducibility
#' vi_permute(rfo, target = "y", metric = yardstick::mae_vec,
#' smaller_is_better = TRUE, pred_wrapper = pfun)
#' smaller_is_better = TRUE, pred_wrapper = pfun, train = trn)
#'
#' #
#' # Classification (binary) example
Expand Down
4 changes: 2 additions & 2 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ knitr::opts_chunk$set(
# vip: Variable Importance Plots <img src="man/figures/logo-vip.png" align="right" width="130" height="150" />

<!-- badges: start -->
[![CRAN_Status_Badge](http://www.r-pkg.org/badges/version/vip)](https://cran.r-project.org/package=vip)
[![CRAN_Status_Badge](https://www.r-pkg.org/badges/version/vip)](https://cran.r-project.org/package=vip)
[![R-CMD-check](https://github.com/koalaverse/vip/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/koalaverse/vip/actions/workflows/R-CMD-check.yaml)
[![Coverage Status](https://codecov.io/gh/koalaverse/vip/graph/badge.svg)](https://app.codecov.io/github/koalaverse/vip?branch=master)
[![Lifecycle: stable](https://img.shields.io/badge/lifecycle-stable-brightgreen.svg)](https://lifecycle.r-lib.org/articles/stages.html#stable)
[![Downloads](http://cranlogs.r-pkg.org/badges/vip)](http://cran.r-project.org/package=vip/)
[![Downloads](https://cranlogs.r-pkg.org/badges/vip)](https://cran.r-project.org/package=vip/)
[![The R Journal](https://img.shields.io/badge/The%20R%20Journal-10.32614%2FRJ--2020--013-brightgreen)](https://doi.org/10.32614/RJ-2020-013)
<!-- badges: end -->

Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

<!-- badges: start -->

[![CRAN_Status_Badge](http://www.r-pkg.org/badges/version/vip)](https://cran.r-project.org/package=vip)
[![CRAN_Status_Badge](https://www.r-pkg.org/badges/version/vip)](https://cran.r-project.org/package=vip)
[![R-CMD-check](https://github.com/koalaverse/vip/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/koalaverse/vip/actions/workflows/R-CMD-check.yaml)
[![Coverage
Status](https://codecov.io/gh/koalaverse/vip/graph/badge.svg)](https://app.codecov.io/github/koalaverse/vip?branch=master)
[![Lifecycle:
stable](https://img.shields.io/badge/lifecycle-stable-brightgreen.svg)](https://lifecycle.r-lib.org/articles/stages.html#stable)
[![Downloads](http://cranlogs.r-pkg.org/badges/vip)](http://cran.r-project.org/package=vip/)
[![Downloads](https://cranlogs.r-pkg.org/badges/vip)](https://cran.r-project.org/package=vip/)
[![The R
Journal](https://img.shields.io/badge/The%20R%20Journal-10.32614%2FRJ--2020--013-brightgreen)](https://doi.org/10.32614/RJ-2020-013)
<!-- badges: end -->
Expand Down
4 changes: 2 additions & 2 deletions inst/CITATION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
citHeader("To cite vip in publications use:")

citEntry(entry = "Article",
bibentry("Article",
title = "Variable Importance Plots---An Introduction to the vip Package",
author = personList(as.person("Brandon M. Greenwell"), as.person("Bradley C. Boehmke")),
author = c(as.person("Brandon M. Greenwell"), as.person("Bradley C. Boehmke")),
journal = "The R Journal",
year = "2020",
volume = "12",
Expand Down
4 changes: 2 additions & 2 deletions inst/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ @Article{party2007a
year = {2007},
volume = {8},
number = {25},
url = {https://www.biomedcentral.com/1471-2105/8/25},
url = {https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-8-25},
}
@Article{party2008b,
title = {Conditional Variable Importance for Random Forests},
Expand All @@ -501,7 +501,7 @@ @Article{party2008b
year = {2008},
volume = {9},
number = {307},
url = {https://www.biomedcentral.com/1471-2105/9/307},
url = {https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-9-307},
}
@Article{partykit2015,
title = {{partykit}: A Modular Toolkit for Recursive Partytioning in {R}},
Expand Down
10 changes: 0 additions & 10 deletions inst/tinytest/test_pkg_C50.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,6 @@ expect_identical(
C50::C5imp(fit1, metric = "splits", pct = FALSE)$Overall
)

# Expectations for `get_training_data()`
expect_identical(
current = vip:::get_training_data.C5.0(fit1),
target = friedman2
)
expect_identical(
current = vip:::get_training_data.C5.0(fit2),
target = subset(friedman2, select = -y)
)

# Expectations for `get_feature_names()`
expect_identical(
current = vip:::get_feature_names.C5.0(fit1),
Expand Down
6 changes: 0 additions & 6 deletions inst/tinytest/test_pkg_Cubist.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ expect_identical(
target = vis2[vis1$Variable, , drop = TRUE]
)

# Expectations for `get_training_data()`
expect_identical(
current = vip:::get_training_data.cubist(fit),
target = subset(friedman1, select = -y)
)

# Expectations for `get_feature_names()`
expect_identical(
current = vip:::get_feature_names.cubist(fit),
Expand Down
6 changes: 0 additions & 6 deletions inst/tinytest/test_pkg_earth.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,6 @@ expect_identical(
target = unname(vis_earth[, "gcv", drop = TRUE])
)

# Expectations for `get_training_data()`
expect_identical(
current = vip:::get_training_data.earth(fit),
target = friedman1
)

# Expectations for `get_feature_names()`
expect_identical(
current = vip:::get_feature_names.earth(fit),
Expand Down
6 changes: 0 additions & 6 deletions inst/tinytest/test_pkg_glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,6 @@ expect_identical(
target = abs(coef(fit3, s = fit3$lambda[5L])[[1L]][-1L])
)

# # Expectations for `get_training_data()`
# expect_identical(
# current = vip:::get_training_data.glmnet(fit),
# target = friedman1
# )

# Expectations for `get_feature_names()`
expect_identical(
current = vip:::get_feature_names.glmnet(fit1),
Expand Down
6 changes: 0 additions & 6 deletions inst/tinytest/test_pkg_neuralnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@ expect_identical(
target = NeuralNetTools::garson(fit, bar_plot = FALSE)$rel_imp
)

# Expectations for `get_training_data()`
expect_identical(
current = vip:::get_training_data.default(fit),
target = friedman1
)

# Expectations for `get_feature_names()`
expect_identical(
current = vip:::get_feature_names.nn(fit),
Expand Down
Loading

0 comments on commit 33cc92e

Please sign in to comment.