Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] Add print() and summary() methods for Booster #4686

Merged
merged 18 commits into from
Nov 13, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ S3method(dimnames,lgb.Dataset)
S3method(get_field,lgb.Dataset)
S3method(getinfo,lgb.Dataset)
S3method(predict,lgb.Booster)
S3method(print,lgb.Booster)
S3method(set_field,lgb.Dataset)
S3method(setinfo,lgb.Dataset)
S3method(slice,lgb.Dataset)
S3method(summary,lgb.Booster)
export(get_field)
export(getinfo)
export(lgb.Dataset)
Expand Down
48 changes: 48 additions & 0 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,54 @@ predict.lgb.Booster <- function(object,
)
}

#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `x`, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
handle <- x$.__enclos_env__$private$handle
handle_is_null <- lgb.is.null.handle(handle)

if (!handle_is_null) {
cat(sprintf("LightGBM Model (%d trees)\n", x$current_iter()))
} else {
cat("LightGBM Model\n")
}

if (!handle_is_null) {
if (x$.__enclos_env__$private$num_class == 1L) {
cat(sprintf("Objective: %s\n", x$params$objective))
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
} else {
cat(sprintf("Objective: %s (%d classes)\n"
, x$params$objective
, x$.__enclos_env__$private$num_class == 1L))
david-cortes marked this conversation as resolved.
Show resolved Hide resolved
}
} else {
cat("(Booster handle is invalid)\n")
}

if (!handle_is_null) {
ncols <- .Call(LGBM_BoosterGetNumFeatures_R, handle)
cat(sprintf("Fitted to dataset with %d columns\n", ncols))
}

return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{print}).
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `object`, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
print(object)
}

#' @name lgb.load
#' @title Load LightGBM model
#' @description Load LightGBM takes in either a file path or model string.
Expand Down
19 changes: 19 additions & 0 deletions R-package/man/print.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions R-package/man/summary.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,15 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
R_API_END();
}

SEXP LGBM_BoosterGetNumFeatures_R(SEXP handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int out = 0;
CHECK_CALL(LGBM_BoosterGetNumFeatures(R_ExternalPtrAddr(handle), &out));
return Rf_ScalarInteger(out);
R_API_END();
}

SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
Expand Down Expand Up @@ -889,6 +898,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterResetTrainingData_R" , (DL_FUNC) &LGBM_BoosterResetTrainingData_R , 2},
{"LGBM_BoosterResetParameter_R" , (DL_FUNC) &LGBM_BoosterResetParameter_R , 2},
{"LGBM_BoosterGetNumClasses_R" , (DL_FUNC) &LGBM_BoosterGetNumClasses_R , 2},
{"LGBM_BoosterGetNumFeatures_R" , (DL_FUNC) &LGBM_BoosterGetNumFeatures_R , 1},
{"LGBM_BoosterUpdateOneIter_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIter_R , 1},
{"LGBM_BoosterUpdateOneIterCustom_R", (DL_FUNC) &LGBM_BoosterUpdateOneIterCustom_R, 4},
{"LGBM_BoosterRollbackOneIter_R" , (DL_FUNC) &LGBM_BoosterRollbackOneIter_R , 1},
Expand Down
9 changes: 9 additions & 0 deletions R-package/src/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumClasses_R(
SEXP out
);

/*!
* \brief Get number of features
david-cortes marked this conversation as resolved.
Show resolved Hide resolved
* \param handle Booster handle
* \return R integer
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumFeatures_R(
SEXP handle
);

/*!
* \brief update the model in one round
* \param handle Booster handle
Expand Down
9 changes: 9 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,15 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
int leaf_idx,
double val);

/*!
* \brief Get number of features (columns) to which a booster was fit.
* \param handle Handle of booster
* \param[out] out_val Output result from the specified leaf
* \return 0 when succeed, -1 when failure happens
david-cortes marked this conversation as resolved.
Show resolved Hide resolved
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumFeatures(BoosterHandle handle,
int *out_val);

/*!
* \brief Get model feature importance.
* \param handle Handle of booster
Expand Down
7 changes: 7 additions & 0 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2308,6 +2308,13 @@ int LGBM_BoosterSetLeafValue(BoosterHandle handle,
API_END();
}

int LGBM_BoosterGetNumFeatures(BoosterHandle handle, int *out_val) {
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_val = ref_booster->GetBoosting()->MaxFeatureIdx() + 1;
API_END();
}

int LGBM_BoosterFeatureImportance(BoosterHandle handle,
int num_iteration,
int importance_type,
Expand Down