Skip to content

Commit

Permalink
fixed n_trees bug - now works with 1.7 version of xgboost
Browse files Browse the repository at this point in the history
  • Loading branch information
David Foster committed Jun 18, 2018
1 parent bfcf902 commit 10e61bb
Show file tree
Hide file tree
Showing 9 changed files with 14 additions and 17 deletions.
Binary file modified .DS_Store
Binary file not shown.
8 changes: 4 additions & 4 deletions R/buildExplainer.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#' @param trainingData A DMatrix of data used to train the model
#' @param type The objective function of the model - either "binary" (for binary:logistic) or "regression" (for reg:linear)
#' @param base_score Default 0.5. The base_score variable of the xgboost model.
#' @param n_first_tree Default NULL. The number of trees to include in the model.
#' @param trees_idx Default NULL. An integer vector of tree indices that should be parsed. If set to NULL, all trees of the model are parsed.
#' @return The XGBoost Explainer for the model. This is a data table where each row is a leaf of a tree in the xgboost model
#' and each column is the impact of each feature on the prediction at the leaf.
#'
Expand Down Expand Up @@ -47,18 +47,18 @@
#' trees = xgb.model.dt.tree(col_names, model = xgb.model)
#'
#' #### The XGBoost Explainer
#' explainer = buildExplainer(xgb.model,xgb.train.data, type="binary", base_score = 0.5, n_first_tree = xgb.model$best_ntreelimit - 1)
#' explainer = buildExplainer(xgb.model,xgb.train.data, type="binary", base_score = 0.5, trees = NULL)
#' pred.breakdown = explainPredictions(xgb.model, explainer, xgb.test.data)
#'
#' showWaterfall(xgb.model, explainer, xgb.test.data, test.data, 2, type = "binary")
#' showWaterfall(xgb.model, explainer, xgb.test.data, test.data, 8, type = "binary")


buildExplainer = function(xgb.model, trainingData, type = "binary", base_score = 0.5, n_first_tree = NULL){
buildExplainer = function(xgb.model, trainingData, type = "binary", base_score = 0.5, trees_idx = NULL){

col_names = attr(trainingData, ".Dimnames")[[2]]
cat('\nCreating the trees of the xgboost model...')
trees = xgb.model.dt.tree(col_names, model = xgb.model, n_first_tree = n_first_tree)
trees = xgb.model.dt.tree(col_names, model = xgb.model, trees = trees_idx)
cat('\nGetting the leaf nodes for the training set observations...')
nodes.train = predict(xgb.model,trainingData,predleaf =TRUE)

Expand Down
2 changes: 1 addition & 1 deletion R/explainPredictions.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
#' trees = xgb.model.dt.tree(col_names, model = xgb.model)
#'
#' #### The XGBoost Explainer
#' explainer = buildExplainer(xgb.model,xgb.train.data, type="binary", base_score = 0.5, n_first_tree = xgb.model$best_ntreelimit - 1)
#' explainer = buildExplainer(xgb.model,xgb.train.data, type="binary", base_score = 0.5, trees = NULL)
#' pred.breakdown = explainPredictions(xgb.model, explainer, xgb.test.data)
#'
#' showWaterfall(xgb.model, explainer, xgb.test.data, test.data, 2, type = "binary")
Expand Down
8 changes: 3 additions & 5 deletions R/getTreeBreakdown.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
#' @import data.table
#' @import xgboost
getTreeBreakdown = function(tree, col_names){

####accepts a tree (data table), and column names
####outputs a data table, of the impact of each variable + intercept, for each leaf




tree_breakdown <- vector("list", length(col_names) + 2)
names(tree_breakdown) = c(col_names,'intercept','leaf')

Expand All @@ -21,4 +19,4 @@ getTreeBreakdown = function(tree, col_names){
}

return (tree_breakdown)
}
}
3 changes: 1 addition & 2 deletions R/showWaterfall.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
#' trees = xgb.model.dt.tree(col_names, model = xgb.model)
#'
#' #### The XGBoost Explainer
#' explainer = buildExplainer(xgb.model,xgb.train.data, type="binary", base_score = 0.5, n_first_tree = xgb.model$best_ntreelimit - 1)
#' explainer = buildExplainer(xgb.model,xgb.train.data, type="binary", base_score = 0.5, trees = NULL)
#' pred.breakdown = explainPredictions(xgb.model, explainer, xgb.test.data)
#'
#' showWaterfall(xgb.model, explainer, xgb.test.data, test.data, 2, type = "binary")
Expand Down Expand Up @@ -133,7 +133,6 @@ showWaterfall = function(xgb.model, explainer, DMatrix, data.matrix, idx, type =

ybreaks<-logit(seq(2,98,2)/100)

cat(ybreaks)
waterfalls::waterfall(values = breakdown_summary,
rect_text_labels = round(breakdown_summary, 2),
labels = labels,
Expand Down
Binary file added installed_old.rda
Binary file not shown.
6 changes: 3 additions & 3 deletions man/buildExplainer.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/explainPredictions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/showWaterfall.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 10e61bb

Please sign in to comment.