Skip to content

Commit

Permalink
Merge pull request #62 from bgreenwell/devel
Browse files Browse the repository at this point in the history
fix print issues
  • Loading branch information
bgreenwell committed May 5, 2023
2 parents 23cc419 + 0de5b3d commit fd6858a
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 123 deletions.
2 changes: 1 addition & 1 deletion .Rproj.user/shared/notebooks/paths
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/Users/bgreenwell/.R/Makevars="EEC1896A"
/Users/bgreenwell/Dropbox/devel/fastshap/.Rbuildignore="B0549DA2"
/Users/bgreenwell/Dropbox/devel/fastshap/DESCRIPTION="300503D2"
/Users/bgreenwell/Dropbox/devel/fastshap/NAMESPACE="0716B2F8"
Expand All @@ -22,6 +23,5 @@
/Users/bgreenwell/Dropbox/devel/fastshap/slowtests/fastshap-genOMat.cpp="99FEC81E"
/Users/bgreenwell/Dropbox/devel/fastshap/slowtests/slowtest-benchmark.R="29ADFB84"
/Users/bgreenwell/Dropbox/devel/fastshap/slowtests/slowtest-parallel.R="7B058F98"
/Users/bgreenwell/Dropbox/devel/fastshap/slowtests/test-shapviz.R="7F402C66"
/Users/bgreenwell/Dropbox/devel/fastshap/vignettes/fastshap.Rmd="536A2979"
/Users/bgreenwell/Dropbox/trees/book.tex="4ECC8BA9"
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ S3method(explain,default)
S3method(explain,lgb.Booster)
S3method(explain,lm)
S3method(explain,xgb.Booster)
S3method(print,explain)
export(explain)
export(gen_friedman)
importFrom(Rcpp,sourceCpp)
Expand Down
21 changes: 14 additions & 7 deletions R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ explain.default <- function(object, feature_names = NULL, X = NULL, nsim = 1,
newdata.stacked <- newdata[rep(1L, times = nsim), ] # replicate obs `nsim` times
phis <- explain.default(object, feature_names = feature_names, X = X,
nsim = 1L, pred_wrapper = pred_wrapper,
newdata = newdata.stacked, adjust = FALSE, ...)
newdata = newdata.stacked, adjust = FALSE,
parallel = parallel, ...)
phi.avg <- t(colMeans(phis)) # transpose to keep as row matrix
if (isTRUE(adjust)) {
# Adjust sum of approximate Shapley values using the same technique from
Expand Down Expand Up @@ -408,7 +409,8 @@ explain.default <- function(object, feature_names = NULL, X = NULL, nsim = 1,
#' @export
explain.lm <- function(object, feature_names = NULL, X, nsim = 1,
pred_wrapper, newdata = NULL, adjust = FALSE,
exact = FALSE, shap_only = TRUE, ...) {
exact = FALSE, baseline = NULL, shap_only = TRUE,
parallel = FALSE, ...) {
if (isTRUE(exact)) { # use Linear SHAP
phis <- if (is.null(newdata)) {
stats::predict(object, type = "terms", ...)
Expand All @@ -429,7 +431,8 @@ explain.lm <- function(object, feature_names = NULL, X, nsim = 1,
} else {
explain.default(object, feature_names = feature_names, X = X, nsim = nsim,
pred_wrapper = pred_wrapper, newdata = newdata,
adjust = adjust, ...)
adjust = adjust, baseline = baseline, shap_only = shap_only,
parallel = parallel, ...)
}
}

Expand All @@ -439,7 +442,8 @@ explain.lm <- function(object, feature_names = NULL, X, nsim = 1,
#' @export
explain.xgb.Booster <- function(object, feature_names = NULL, X = NULL, nsim = 1,
pred_wrapper, newdata = NULL, adjust = FALSE,
exact = FALSE, shap_only = TRUE, ...) {
exact = FALSE, baseline = NULL,
shap_only = TRUE, parallel = FALSE, ...) {
if (isTRUE(exact)) { # use Tree SHAP
if (is.null(X) && is.null(newdata)) {
stop("Must supply `X` or `newdata` argument (but not both).",
Expand All @@ -463,7 +467,8 @@ explain.xgb.Booster <- function(object, feature_names = NULL, X = NULL, nsim = 1
} else {
explain.default(object, feature_names = feature_names, X = X, nsim = nsim,
pred_wrapper = pred_wrapper, newdata = newdata,
adjust = adjust, ...)
adjust = adjust, baseline = baseline, shap_only = shap_only,
parallel = parallel, ...)
}
}

Expand All @@ -473,7 +478,8 @@ explain.xgb.Booster <- function(object, feature_names = NULL, X = NULL, nsim = 1
#' @export
explain.lgb.Booster <- function(object, feature_names = NULL, X = NULL, nsim = 1,
pred_wrapper, newdata = NULL, adjust = FALSE,
exact = FALSE, shap_only = TRUE, ...) {
exact = FALSE, baseline = NULL,
shap_only = TRUE, parallel = FALSE, ...) {
if (isTRUE(exact)) { # use Tree SHAP
if (is.null(X) && is.null(newdata)) {
stop("Must supply `X` or `newdata` argument (but not both).",
Expand Down Expand Up @@ -504,6 +510,7 @@ explain.lgb.Booster <- function(object, feature_names = NULL, X = NULL, nsim = 1
} else {
explain.default(object, feature_names = feature_names, X = X, nsim = nsim,
pred_wrapper = pred_wrapper, newdata = newdata,
adjust = adjust, shap_only = shap_only, ...)
adjust = adjust, baseline = baseline, shap_only = shap_only,
parallel = parallel, ...)
}
}
17 changes: 10 additions & 7 deletions R/print.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#' @keywords internal
#'
#' @export
print.explain <- function(x, ...) {
print(data.matrix(as.data.frame(x)))
invisible(x)
}
#' #' @keywords internal
#' #'
#' #' @export
#' print.explain <- function(x, ...) {
#' if (is.matrix(x)) {
#' x <- data.matrix(as.data.frame(x))
#' }
#' print(x)
#' invisible(x)
#' }
5 changes: 4 additions & 1 deletion inst/tinytest/test-adjust.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ if (!requireNamespace("xgboost", quietly = TRUE)) {
exit_file("Package 'xgboost' missing")
}

library(fastshap)

# Use one of the available (imputed) versions of the Titanic data
titanic <- titanic_mice[[1L]]

Expand Down Expand Up @@ -35,7 +37,8 @@ jack.dawson <- data.matrix(data.frame(
params.lgb <- list(
num_leaves = 4L,
learning_rate = 0.1,
objective = "binary"
objective = "binary",
force_row_wise = TRUE
)

set.seed(1420) # for reproducibility
Expand Down
55 changes: 55 additions & 0 deletions inst/tinytest/test-shapviz.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Exits
if (!requireNamespace("shapviz", quietly = TRUE)) {
exit_file("Package shapviz missing")
}
if (!requireNamespace("ranger", quietly = TRUE)) {
exit_file("Package ranger missing")
}

library(shapviz)

# Read in the data and clean it up a bit
set.seed(2220) # for reproducibility
trn <- gen_friedman(500)
tst <- gen_friedman(10)

# Features only
X <- subset(trn, select = -y)
newX <- subset(tst, select = -y)

# Fit a default random forest
set.seed(2222) # for reproducibility
rfo <- ranger::ranger(y ~ ., data = trn)

# Prediction wrapper
pfun <- function(object, newdata) {
predict(object, data = newdata)$predictions
}

# Generate explanations for test set
set.seed(2024) # for reproducibility
ex1 <- explain(rfo, X = X, newdata = newX, pred_wrapper = pfun, adjust = TRUE,
nsim = 50)

# Same, but set `shap_only = FALSE` for convenience with shapviz
set.seed(2024) # for reproducibility
ex2 <- explain(rfo, X = X, newdata = newX, pred_wrapper = pfun, adjust = TRUE,
nsim = 50, shap_only = FALSE)

# Create "shapviz" objects
shv1 <- shapviz(ex1, X = newX)
shv2 <- shapviz(ex2)
shv3 <- shapviz(ex2$shapley_values, X = newX, baseline = ex2$baseline)

# Expectations
expect_error(shapviz(ex1))
expect_identical(ex2$baseline, mean(pfun(rfo, X)))
expect_identical(shv1$X, shv2$X)
expect_identical(shv1$X, shv3$X)
expect_identical(shv1$baseline, shv2$baseline)
expect_identical(shv1$baseline, shv3$baseline)

# # SHAP waterfall plots
# sv_waterfall(shv1, row_id = 1)
# sv_waterfall(shv2, row_id = 1)
# sv_waterfall(shv3, row_id = 1)
6 changes: 6 additions & 0 deletions man/explain.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

102 changes: 0 additions & 102 deletions slowtests/test-shapviz.R

This file was deleted.

2 changes: 0 additions & 2 deletions vignettes/fastshap.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ To illustrate, we'll use the `explain()` function to estimate how each of jack f


```r
library(fastshap)

X <- subset(t1, select = -survived) # features only
set.seed(2113) # for reproducibility
(ex.jack <- explain(rfo, X = X, pred_wrapper = pfun, newdata = jack.dawson))
Expand Down
2 changes: 0 additions & 2 deletions vignettes/fastshap.Rmd.orig
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ To illustrate, we'll use the `explain()` function to estimate how each of jack f
[^2]: Note that we need to supply the training features via the `X` argument (i.e., no response column) and that `newdata` should also only contain columns of feature values.

```{r titanic-explain-jack}
library(fastshap)

X <- subset(t1, select = -survived) # features only
set.seed(2113) # for reproducibility
(ex.jack <- explain(rfo, X = X, pred_wrapper = pfun, newdata = jack.dawson))
Expand Down

0 comments on commit fd6858a

Please sign in to comment.