Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix print issues #62

Merged
merged 1 commit into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .Rproj.user/shared/notebooks/paths
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/Users/bgreenwell/.R/Makevars="EEC1896A"
/Users/bgreenwell/Dropbox/devel/fastshap/.Rbuildignore="B0549DA2"
/Users/bgreenwell/Dropbox/devel/fastshap/DESCRIPTION="300503D2"
/Users/bgreenwell/Dropbox/devel/fastshap/NAMESPACE="0716B2F8"
Expand All @@ -22,6 +23,5 @@
/Users/bgreenwell/Dropbox/devel/fastshap/slowtests/fastshap-genOMat.cpp="99FEC81E"
/Users/bgreenwell/Dropbox/devel/fastshap/slowtests/slowtest-benchmark.R="29ADFB84"
/Users/bgreenwell/Dropbox/devel/fastshap/slowtests/slowtest-parallel.R="7B058F98"
/Users/bgreenwell/Dropbox/devel/fastshap/slowtests/test-shapviz.R="7F402C66"
/Users/bgreenwell/Dropbox/devel/fastshap/vignettes/fastshap.Rmd="536A2979"
/Users/bgreenwell/Dropbox/trees/book.tex="4ECC8BA9"
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ S3method(explain,default)
S3method(explain,lgb.Booster)
S3method(explain,lm)
S3method(explain,xgb.Booster)
S3method(print,explain)
export(explain)
export(gen_friedman)
importFrom(Rcpp,sourceCpp)
Expand Down
21 changes: 14 additions & 7 deletions R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ explain.default <- function(object, feature_names = NULL, X = NULL, nsim = 1,
newdata.stacked <- newdata[rep(1L, times = nsim), ] # replicate obs `nsim` times
phis <- explain.default(object, feature_names = feature_names, X = X,
nsim = 1L, pred_wrapper = pred_wrapper,
newdata = newdata.stacked, adjust = FALSE, ...)
newdata = newdata.stacked, adjust = FALSE,
parallel = parallel, ...)
phi.avg <- t(colMeans(phis)) # transpose to keep as row matrix
if (isTRUE(adjust)) {
# Adjust sum of approximate Shapley values using the same technique from
Expand Down Expand Up @@ -408,7 +409,8 @@ explain.default <- function(object, feature_names = NULL, X = NULL, nsim = 1,
#' @export
explain.lm <- function(object, feature_names = NULL, X, nsim = 1,
pred_wrapper, newdata = NULL, adjust = FALSE,
exact = FALSE, shap_only = TRUE, ...) {
exact = FALSE, baseline = NULL, shap_only = TRUE,
parallel = FALSE, ...) {
if (isTRUE(exact)) { # use Linear SHAP
phis <- if (is.null(newdata)) {
stats::predict(object, type = "terms", ...)
Expand All @@ -429,7 +431,8 @@ explain.lm <- function(object, feature_names = NULL, X, nsim = 1,
} else {
explain.default(object, feature_names = feature_names, X = X, nsim = nsim,
pred_wrapper = pred_wrapper, newdata = newdata,
adjust = adjust, ...)
adjust = adjust, baseline = baseline, shap_only = shap_only,
parallel = parallel, ...)
}
}

Expand All @@ -439,7 +442,8 @@ explain.lm <- function(object, feature_names = NULL, X, nsim = 1,
#' @export
explain.xgb.Booster <- function(object, feature_names = NULL, X = NULL, nsim = 1,
pred_wrapper, newdata = NULL, adjust = FALSE,
exact = FALSE, shap_only = TRUE, ...) {
exact = FALSE, baseline = NULL,
shap_only = TRUE, parallel = FALSE, ...) {
if (isTRUE(exact)) { # use Tree SHAP
if (is.null(X) && is.null(newdata)) {
stop("Must supply `X` or `newdata` argument (but not both).",
Expand All @@ -463,7 +467,8 @@ explain.xgb.Booster <- function(object, feature_names = NULL, X = NULL, nsim = 1
} else {
explain.default(object, feature_names = feature_names, X = X, nsim = nsim,
pred_wrapper = pred_wrapper, newdata = newdata,
adjust = adjust, ...)
adjust = adjust, baseline = baseline, shap_only = shap_only,
parallel = parallel, ...)
}
}

Expand All @@ -473,7 +478,8 @@ explain.xgb.Booster <- function(object, feature_names = NULL, X = NULL, nsim = 1
#' @export
explain.lgb.Booster <- function(object, feature_names = NULL, X = NULL, nsim = 1,
pred_wrapper, newdata = NULL, adjust = FALSE,
exact = FALSE, shap_only = TRUE, ...) {
exact = FALSE, baseline = NULL,
shap_only = TRUE, parallel = FALSE, ...) {
if (isTRUE(exact)) { # use Tree SHAP
if (is.null(X) && is.null(newdata)) {
stop("Must supply `X` or `newdata` argument (but not both).",
Expand Down Expand Up @@ -504,6 +510,7 @@ explain.lgb.Booster <- function(object, feature_names = NULL, X = NULL, nsim = 1
} else {
explain.default(object, feature_names = feature_names, X = X, nsim = nsim,
pred_wrapper = pred_wrapper, newdata = newdata,
adjust = adjust, shap_only = shap_only, ...)
adjust = adjust, baseline = baseline, shap_only = shap_only,
parallel = parallel, ...)
}
}
17 changes: 10 additions & 7 deletions R/print.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#' Print method for objects of class "explain"
#'
#' Matrices and data frames are printed as a plain numeric matrix. Any other
#' object (e.g., a list) cannot be coerced that way, so it is printed with its
#' class attribute removed instead of raising a coercion error.
#'
#' @param x Object to print.
#' @param ... Additional arguments passed on to \code{print()}.
#'
#' @return \code{x}, invisibly.
#'
#' @keywords internal
#'
#' @export
print.explain <- function(x, ...) {
  if (is.matrix(x) || is.data.frame(x)) {
    # Drop the extra class so the values print like an ordinary matrix
    print(data.matrix(as.data.frame(x)), ...)
  } else {
    # unclass() avoids dispatching back to this method (infinite recursion)
    print(unclass(x), ...)
  }
  invisible(x)
}
#' #' @keywords internal
#' #'
#' #' @export
#' print.explain <- function(x, ...) {
#' if (is.matrix(x)) {
#' x <- data.matrix(as.data.frame(x))
#' }
#' print(x)
#' invisible(x)
#' }
5 changes: 4 additions & 1 deletion inst/tinytest/test-adjust.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ if (!requireNamespace("xgboost", quietly = TRUE)) {
exit_file("Package 'xgboost' missing")
}

library(fastshap)

# Use one of the available (imputed) versions of the Titanic data
titanic <- titanic_mice[[1L]]

Expand Down Expand Up @@ -35,7 +37,8 @@ jack.dawson <- data.matrix(data.frame(
params.lgb <- list(
num_leaves = 4L,
learning_rate = 0.1,
objective = "binary"
objective = "binary",
force_row_wise = TRUE
)

set.seed(1420) # for reproducibility
Expand Down
55 changes: 55 additions & 0 deletions inst/tinytest/test-shapviz.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Exits: skip this test file when the optional packages it relies on are not
# installed
if (!requireNamespace("shapviz", quietly = TRUE)) {
exit_file("Package shapviz missing")
}
if (!requireNamespace("ranger", quietly = TRUE)) {
exit_file("Package ranger missing")
}

# NOTE(review): `gen_friedman()` and `explain()` are called below without
# `library(fastshap)`; presumably the test runner has the package loaded
# already -- confirm.
library(shapviz)

# Generate training and test data with gen_friedman() (called with 500 and 10,
# respectively); both draws follow from the single seed set here
set.seed(2220) # for reproducibility
trn <- gen_friedman(500)
tst <- gen_friedman(10)

# Features only (drop the response column `y`)
X <- subset(trn, select = -y)
newX <- subset(tst, select = -y)

# Fit a default random forest
set.seed(2222) # for reproducibility
rfo <- ranger::ranger(y ~ ., data = trn)

# Prediction wrapper: calls predict() on the fitted model with the new
# observations and hands back the `predictions` component of the result
pfun <- function(object, newdata) {
  fit <- predict(object, data = newdata)
  fit$predictions
}

# Generate explanations for test set (default output, i.e., `shap_only` is not
# supplied)
set.seed(2024) # for reproducibility
ex1 <- explain(rfo, X = X, newdata = newX, pred_wrapper = pfun, adjust = TRUE,
nsim = 50)

# Same, but set `shap_only = FALSE` for convenience with shapviz; the seed is
# reset to the same value so both calls start from the same RNG state
set.seed(2024) # for reproducibility
ex2 <- explain(rfo, X = X, newdata = newX, pred_wrapper = pfun, adjust = TRUE,
nsim = 50, shap_only = FALSE)

# Create "shapviz" objects in three ways:
#   shv1 - from the default output, supplying the feature values explicitly
#   shv2 - from the `shap_only = FALSE` output alone
#   shv3 - from the individual components of `ex2`
shv1 <- shapviz(ex1, X = newX)
shv2 <- shapviz(ex2)
shv3 <- shapviz(ex2$shapley_values, X = newX, baseline = ex2$baseline)

# Expectations
# Calling shapviz() on the default output without `X` should fail
expect_error(shapviz(ex1))
# The stored baseline should equal the mean prediction over the training
# features
expect_identical(ex2$baseline, mean(pfun(rfo, X)))
# All three construction routes should agree on `X` and `baseline`
expect_identical(shv1$X, shv2$X)
expect_identical(shv1$X, shv3$X)
expect_identical(shv1$baseline, shv2$baseline)
expect_identical(shv1$baseline, shv3$baseline)

# # SHAP waterfall plots
# sv_waterfall(shv1, row_id = 1)
# sv_waterfall(shv2, row_id = 1)
# sv_waterfall(shv3, row_id = 1)
6 changes: 6 additions & 0 deletions man/explain.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

102 changes: 0 additions & 102 deletions slowtests/test-shapviz.R

This file was deleted.

2 changes: 0 additions & 2 deletions vignettes/fastshap.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ To illustrate, we'll use the `explain()` function to estimate how each of jack f


```r
library(fastshap)

X <- subset(t1, select = -survived) # features only
set.seed(2113) # for reproducibility
(ex.jack <- explain(rfo, X = X, pred_wrapper = pfun, newdata = jack.dawson))
Expand Down
2 changes: 0 additions & 2 deletions vignettes/fastshap.Rmd.orig
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ To illustrate, we'll use the `explain()` function to estimate how each of jack f
[^2]: Note that we need to supply the training features via the `X` argument (i.e., no response column) and that `newdata` should also only contain columns of feature values.

```{r titanic-explain-jack}
library(fastshap)

X <- subset(t1, select = -survived) # features only
set.seed(2113) # for reproducibility
(ex.jack <- explain(rfo, X = X, pred_wrapper = pfun, newdata = jack.dawson))
Expand Down