Skip to content

Commit

Permalink
tweak tests, etc.
Browse files Browse the repository at this point in the history
  • Loading branch information
bgreenwell committed Jul 15, 2023
1 parent 88a8126 commit 63bdeb8
Show file tree
Hide file tree
Showing 10 changed files with 125 additions and 2,050 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
^codecov\.yml$
^data-raw$
^docs$
^lightgbm\.model$
^README\.Rmd$
^README_cache
^revdep
Expand Down
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ URL: https://github.com/koalaverse/vip/,
BugReports: https://github.com/koalaverse/vip/issues
Encoding: UTF-8
VignetteBuilder: knitr
Depends: R (>= 4.10)
Depends: R (>= 4.1.0)
Imports:
foreach,
ggplot2 (>= 0.9.0),
Expand Down Expand Up @@ -60,6 +60,7 @@ Enhances:
gbm,
glmnet,
h2o,
lightgbm,
mixOmics,
mlr,
mlr3,
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ The format of this NEWS file, starting with [vip](https://cran.r-project.org/pac

### Changed

* Raised R version dependency to >= 4.1.0 (the introduction of the native pipe operator `|>`).

* The `vi_permute` function now uses [yardstick](https://cran.r-project.org/package=yardstick); consequently, metric functions now conform to [yardstick](https://cran.r-project.org/package=yardstick) metric argument names.

* The `var_fun` argument in `vi_firm()` has been deprecated; use the new `var_continuous` and `var_categorical` instead.
Expand Down
12 changes: 11 additions & 1 deletion R/get_training_data.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Error message to display when training data cannot be extracted from object
msg <- paste0(
"The training data could not be extracted from object. You can supply the ",
"training data using the `train` argument in the call to `vi_model()`."
"training data using the `train` argument."
)


Expand Down Expand Up @@ -179,3 +179,13 @@ get_training_data.randomForest <- function(object) {
get_training_data.default(object, env = parent.frame(), arg = "x")
}
}


# Package: workflow ------------------------------------------------------------

#' @keywords internal
get_training_data.workflow <- function(object) {
  # This method never returns: it always raises an error directing the caller
  # to pass the data through the `train` argument instead.
  err_text <- paste0(
    "Training data cannot be extracted from workflow objects. Please ",
    "supply the raw training data using the `train` argument."
  )
  # `call. = FALSE` keeps the function call out of the error message
  stop(err_text, call. = FALSE)
}
25 changes: 17 additions & 8 deletions R/vi_firm.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@
#' [Greenwell et al. (2018)](https://arxiv.org/abs/1805.04755) for details and
#' additional examples.
#'
#' @note This approach can provide misleading results in the presence of
#' interaction effects (akin to interpreting main effect coefficients in a
#' linear model with higher-level interaction effects).
#'
#' @references
#' J. H. Friedman. Greedy function approximation: A gradient boosting machine.
#' *Annals of Statistics*, **29**: 1189-1232, 2001.
Expand Down Expand Up @@ -90,23 +94,28 @@
#' # Fit a projection pursuit regression model
#' mtcars.ppr <- ppr(mpg ~ ., data = mtcars, nterms = 1)
#'
#' # Compute variable importance scores using the FIRM method; note the the pdp
#' # Compute variable importance scores using the FIRM method; note that the pdp
#' # package knows how to work with a "ppr" object, so there's no need to pass
#' # the training data or a prediction wrapper, but it's good practice.
#' vi_firm(mtcars.ppr, train = mtcars)
#'
#' # Define prediction wrapper
#' pfun <- function(object, newdata) { # use PD
#' mean(predict(object, newdata = newdata)) # return averaged prediction
#' # For unsupported models, you need to define a prediction wrapper; this
#' # approach will work for ANY model (supported or unsupported), so it is
#' # better to always define it and pass it
#' pfun <- function(object, newdata) {
#' # To use partial dependence, this function needs to return the AVERAGE
#' # prediction (for ICE, simply omit the averaging step)
#' mean(predict(object, newdata = newdata))
#' }
#'
#' # Equivalent to the previous results
#' vi_firm(mtcars.ppr, train = mtcars, pred.fun = pfun)
#' # Equivalent to the previous results (but would work if this type of model
#' # was not explicitly supported)
#' vi_firm(mtcars.ppr, pred.fun = pfun, train = mtcars)
#'
#' # Equivalent VI scores, but the output is sorted by default
#' vi(mtcars.ppr, method = "firm")
#'
#' # Use MAD to estimate variability for the continuous feature effects
#' # Use MAD to estimate variability of the partial dependence values
#' vi_firm(mtcars.ppr, var_continuous = stats::mad)
#'
#' # Plot VI scores
Expand Down Expand Up @@ -156,7 +165,7 @@ vi_firm.default <- function(
# Construct PD/ICE-based variable importance scores
vis <- lapply(feature_names, function(x) {
firm(object, feature_name = x, var_continuous = var_continuous,
var_categorical = var_categorical, ...)
var_categorical = var_categorical, train = train, ...)
})
# vis <- numeric(length(feature_names)) # loses "effects" attribute
# for (i in seq_along(feature_names)) {
Expand Down
28 changes: 28 additions & 0 deletions R/vi_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,34 @@
#' @rdname vi_model
#'
#' @export
#'
#' @examples
#' \dontrun{
#' # Basic example using imputed titanic data set
#' t3 <- titanic_mice[[1L]]
#'
#' # Fit a simple model
#' set.seed(1449) # for reproducibility
#' bst <- lightgbm::lightgbm(
#' data = data.matrix(subset(t3, select = -survived)),
#' label = ifelse(t3$survived == "yes", 1, 0),
#' params = list("objective" = "binary", "force_row_wise" = TRUE),
#' verbose = 0
#' )
#'
#' # Compute VI scores
#' vi(bst) # defaults to `method = "model"`
#' vi_model(bst) # same as above
#'
#' # Same as above (since default is `method = "model"`), but returns a plot
#' vip(bst, geom = "point")
#' vi_model(bst, type = "cover")
#' vi_model(bst, type = "cover", percentage = FALSE)
#'
#' # Compare to
#' lightgbm::lgb.importance(bst)
#' }
#'
vi_model <- function(object, ...) {
  # S3 generic: hand off to the `vi_model` method matching the class of
  # `object`
  UseMethod("vi_model", object)
}
Expand Down
19 changes: 19 additions & 0 deletions inst/tinytest/test_pkg_tidymodels.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,22 @@ vi_auc <- ranger_fit_workflow |>

# Not always the case, but here we can expect these to be in (0, 1)
expect_true(all(vi_auc$Importance > 0 & vi_auc$Importance < 1))


################################################################################
# FIRM-based (i.e., model-agnostic) variable importance
################################################################################

pfun <- function(object, newdata) {
  # Request "prob"-type predictions for the supplied rows; note that the data
  # is passed on as `new_data`
  prob_tbl <- predict(object, new_data = newdata, type = "prob")
  # Collapse the ".pred_One" column to a single averaged value
  mean(prob_tbl[[".pred_One"]])
}
# Compute FIRM-based scores for the fitted workflow, restricted to features
# "A" and "B", with `pfun` supplied as the prediction wrapper
vi_pd <- ranger_fit_workflow |>
vi_firm(
feature_names = c("A", "B"), # required
train = bivariate_train, # required
# pdp::partial()-specific arguments
pred.fun = pfun
)

# Not always the case, but here we can expect these to be positive (only the
# lower bound is asserted)
expect_true(all(vi_pd$Importance > 0))
Loading

0 comments on commit 63bdeb8

Please sign in to comment.