Skip to content

Commit

Permalink
add lightgbm support
Browse files Browse the repository at this point in the history
  • Loading branch information
bgreenwell committed Jul 15, 2023
1 parent b8bd961 commit 88a8126
Show file tree
Hide file tree
Showing 13 changed files with 4,204 additions and 30 deletions.
8 changes: 4 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ URL: https://github.com/koalaverse/vip/,
BugReports: https://github.com/koalaverse/vip/issues
Encoding: UTF-8
VignetteBuilder: knitr
Depends: R (>= 2.10)
Depends: R (>= 4.10)
Imports:
foreach,
ggplot2 (>= 0.9.0),
stats,
tibble,
yardstick,
utils
utils,
yardstick
Suggests:
bookdown,
DT,
Expand All @@ -51,7 +51,7 @@ Suggests:
pdp,
rmarkdown,
tinytest (>= 1.4.1),
varImp,
varImp
Enhances:
C50,
caret,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ S3method(vi_model,default)
S3method(vi_model,earth)
S3method(vi_model,gbm)
S3method(vi_model,glmnet)
S3method(vi_model,lgb.Booster)
S3method(vi_model,lm)
S3method(vi_model,mixo_pls)
S3method(vi_model,mixo_spls)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ The format of this NEWS file, starting with [vip](https://cran.r-project.org/pac

### Added

* `vi_model()` now supports [lightgbm](https://cran.r-project.org/package=lightgbm) models. Thanks to @nipnipj for the suggestion [(#146)](https://github.com/koalaverse/vip/issues/146).

* The permutation importance method (i.e., function `vi_permute()`) now integrates with and uses [yardstick](https://cran.r-project.org/package=yardstick) performance metrics.

* `list_metrics()` gained an additional `smaller_is_better` column indicating whether or not the corresponding metric should be minimized (`smaller_is_better = TRUE`) or maximized (`smaller_is_better = FALSE`); thanks to @topedo. Additionally, all the column names are now in lower case.
Expand Down
9 changes: 6 additions & 3 deletions R/get_feature_names.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
#' Extract model formula
#'
#' Calls \code{\link[stats]{formula}} to extract the formulae from various
#' modeling objects, but returns \code{NULL} instead of an error for objects
#' Calls [formula][stats::formula] to extract the formulae from various
#' modeling objects, but returns `NULL` instead of an error for objects
#' that do not contain a formula component.
#'
#' @param object An appropriate fitted model object.
#'
#' @return Either a \code{\link[stats]{formula}} object or \code{NULL}.
#' @return Either a \code{\link[stats]{formula}} object or `NULL`.
#'
#' @keywords internal
#' @noRd
get_formula <- function(object) {
UseMethod("get_formula")
}
Expand Down
1 change: 1 addition & 0 deletions R/get_training_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ msg <- paste0(


#' @keywords internal
#' @noRd
get_training_data <- function(object) {
UseMethod("get_training_data")
}
Expand Down
48 changes: 48 additions & 0 deletions R/vi_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,13 @@
#' the relative number of times each feature has been used throughout each
#' tree in the ensemble.
#'
#' * [lightgbm][lightgbm::lightgbm] - Same as for [xgboost][xgboost::xgboost]
#' models, except [lgb.importance][lightgbm::lgb.importance] (which this method
#' calls internally) has an additional argument, `percentage`, that defaults to
#' `TRUE`, resulting in the VI scores shown as a relative percentage; pass
#' `percentage = FALSE` in the call to `vi_model()` to produce VI scores for
#' [lightgbm][lightgbm::lightgbm] models on the raw scale.
#'
#' @source
#' Johan Bring (1994) How to Standardize Regression Coefficients, The American
#' Statistician, 48:3, 209-213, DOI: 10.1080/00031305.1994.10476059.
Expand Down Expand Up @@ -622,6 +629,47 @@ vi_model.H2ORegressionModel <- function(object, ...) {
}


# Package: lightgbm ------------------------------------------------------------

#' @rdname vi_model
#'
#' @export
vi_model.lgb.Booster <- function(object, type = c("gain", "cover", "frequency"),
...) {

# # Check for dependency
# if (!requireNamespace("xgboost", quietly = TRUE)) {
# stop("Package \"xgboost\" needed for this function to work. Please ",
# "install it.", call. = FALSE)
# }

# Determine which type of variable importance to compute
type <- match.arg(type)

# Construct model-specific variable importance scores
imp <- lightgbm::lgb.importance(model = object, ...)
names(imp) <- tolower(names(imp))
# if ("weight" %in% names(imp)) {
# type <- "weight" # gblinear
# }
vis <- tibble::as_tibble(imp)[, c("feature", type)]
tib <- tibble::tibble(
"Variable" = vis$feature,
"Importance" = vis[[2L]]
)

# Add variable importance type attribute
attr(tib, which = "type") <- type

# Add "vi" class
class(tib) <- c("vi", class(tib))

# Return results
tib

}


# Package: mixOmics -----------------------------------------------------------

#' @rdname vi_model
Expand Down
6 changes: 6 additions & 0 deletions R/vi_permute.R
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,12 @@ vi_permute.default <- function(
...
) {

# # Check for yardstick package
# if (!requireNamespace("yardstick", quietly = TRUE)) {
# stop("Package \"yardstick\" needed for this function to work. ",
# "Please install it.", call. = FALSE)
# }

# FIXEME: Is there a better way to fix this?
#
# ❯ checking R code for possible problems ... NOTE
Expand Down
2,033 changes: 2,033 additions & 0 deletions inst/tinytest/lightgbm.model

Large diffs are not rendered by default.

57 changes: 57 additions & 0 deletions inst/tinytest/test_pkg_lightgbm.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Exits
if (!requireNamespace("lightgbm", quietly = TRUE)) {
exit_file("Package 'lightgbm' missing")
}

# # Load required packages
# suppressMessages({
# library(lightgbm)
# })

# Basic example using imputed titanic data set
t3 <- titanic_mice[[1L]]

# Fit a simple model
set.seed(1449) # for reproducibility
bst <- lightgbm::lightgbm(
data = data.matrix(subset(t3, select = -survived)),
label = ifelse(t3$survived == "yes", 1, 0),
params = list("objective" = "binary", "force_row_wise" = TRUE),
verbose = 0
)

# Compute VI scores
vi_gain <- vi_model(bst)
vi_cover <- vi_model(bst, type = "cover")
vi_frequency <- vi_model(bst, type = "frequency")
vi_lightgbm <- lightgbm::lgb.importance(model = bst)

# Expectations for `vi_model()`
expect_identical(
current = vi_gain$Importance,
target = vi_lightgbm$Gain
)
expect_identical(
current = vi_cover$Importance,
target = vi_lightgbm$Cover
)
expect_identical(
current = vi_frequency$Importance,
target = vi_lightgbm$Frequency
)
expect_identical(
current = vi_model(bst, percentage = FALSE)$Importance,
target = lightgbm::lgb.importance(bst, percentage = FALSE)$Gain
)

# Expectations for `get_training_data()`
expect_error(vip:::get_training_data.default(bst))

# Call `vip::vip()` directly
p <- vip(bst, method = "model", include_type = TRUE)

# Expect `p` to be a `"gg" "ggplot"` object
expect_identical(
current = class(p),
target = c("gg", "ggplot")
)
8 changes: 4 additions & 4 deletions inst/tinytest/test_pkg_tidymodels.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ if (!requireNamespace("tidymodels", quietly = TRUE)) {
data("bivariate", package = "modeldata")

# Define a 'ranger'-based random forest model
ranger_spec <- parsnip::rand_forest(trees = 1e3, mode = "classification") %>%
ranger_spec <- parsnip::rand_forest(trees = 1e3, mode = "classification") |>
parsnip::set_engine("ranger", importance = "impurity")

# Fit models
set.seed(421) # for reproduicbility
ranger_fit_workflow <- # worflows
workflows::workflow(Class ~ ., ranger_spec) %>%
workflows::workflow(Class ~ ., ranger_spec) |>
parsnip::fit(bivariate_train)
ranger_fit_parsnip <- # parsnip
ranger_spec %>%
ranger_spec |>
parsnip::fit(Class ~ ., data = bivariate_train)

# Extract underlying 'ranger' fits
Expand Down Expand Up @@ -72,7 +72,7 @@ pfun <- function(object, newdata) {

# Compute permutation-based VI scores using AUC metric
set.seed(912) # for reproducibility
vi_auc <- ranger_fit_workflow %>%
vi_auc <- ranger_fit_workflow |>
vi(method = "permute",
target = "Class",
metric = "roc_auc",
Expand Down
Loading

0 comments on commit 88a8126

Please sign in to comment.