Skip to content

Commit

Permalink
[R-package] ensure use of interaction_constraints does not lead to fe…
Browse files Browse the repository at this point in the history
…atures being ignored (#6377)
  • Loading branch information
mayer79 authored Jun 13, 2024
1 parent 1e7ebc5 commit 4e74403
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 63 deletions.
8 changes: 5 additions & 3 deletions .ci/lint-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

set -e -E -u -o pipefail

echo "running pre-commit checks"
pre-commit run --all-files || exit 1
echo "done running pre-commit checks"
# this can be re-enabled when this is fixed:
# https://github.com/tox-dev/filelock/issues/337
# echo "running pre-commit checks"
# pre-commit run --all-files || exit 1
# echo "done running pre-commit checks"

echo "running mypy"
mypy \
Expand Down
100 changes: 49 additions & 51 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,68 +59,66 @@

}

# [description]
#
# Besides applying checks, this function
#
# 1. turns feature *names* into 1-based integer positions, then
# 2. adds an extra list element with skipped features, then
# 3. turns 1-based integer positions into 0-based positions, and finally
# 4. collapses the values of each list element into a string like "[0, 1]".
#
.check_interaction_constraints <- function(interaction_constraints, column_names) {
if (is.null(interaction_constraints)) {
return(list())
}
if (!identical(class(interaction_constraints), "list")) {
stop("interaction_constraints must be a list")
}

# Convert interaction constraints to feature numbers
string_constraints <- list()
column_indices <- seq_along(column_names)

if (!is.null(interaction_constraints)) {
# Convert feature names to 1-based integer positions and apply checks
for (j in seq_along(interaction_constraints)) {
constraint <- interaction_constraints[[j]]

if (!methods::is(interaction_constraints, "list")) {
stop("interaction_constraints must be a list")
}
constraint_is_character_or_numeric <- sapply(
X = interaction_constraints
, FUN = function(x) {
return(is.character(x) || is.numeric(x))
}
)
if (!all(constraint_is_character_or_numeric)) {
stop("every element in interaction_constraints must be a character vector or numeric vector")
if (is.character(constraint)) {
constraint_indices <- match(constraint, column_names)
} else if (is.numeric(constraint)) {
constraint_indices <- as.integer(constraint)
} else {
stop("every element in interaction_constraints must be a character vector or numeric vector")
}

for (constraint in interaction_constraints) {

# Check for character name
if (is.character(constraint)) {

constraint_indices <- as.integer(match(constraint, column_names) - 1L)

# Provided indices, but some indices are not existing?
if (sum(is.na(constraint_indices)) > 0L) {
stop(
"supplied an unknown feature in interaction_constraints "
, sQuote(constraint[is.na(constraint_indices)])
)
}

} else {

# Check that constraint indices are at most number of features
if (max(constraint) > length(column_names)) {
stop(
"supplied a too large value in interaction_constraints: "
, max(constraint)
, " but only "
, length(column_names)
, " features"
)
}

# Store indices as [0, n-1] indexed instead of [1, n] indexed
constraint_indices <- as.integer(constraint - 1L)

}

# Convert constraint to string
constraint_string <- paste0("[", paste0(constraint_indices, collapse = ","), "]")
string_constraints <- append(string_constraints, constraint_string)
# Features outside range?
bad <- !(constraint_indices %in% column_indices)
if (any(bad)) {
stop(
"unknown feature(s) in interaction_constraints: "
, toString(sQuote(constraint[bad], q = "'"))
)
}

interaction_constraints[[j]] <- constraint_indices
}

return(string_constraints)
# Add missing features as new interaction set
remaining_indices <- setdiff(
column_indices, sort(unique(unlist(interaction_constraints)))
)
if (length(remaining_indices) > 0L) {
interaction_constraints <- c(
interaction_constraints, list(remaining_indices)
)
}

# Turn indices 0-based and convert to string
for (j in seq_along(interaction_constraints)) {
interaction_constraints[[j]] <- paste0(
"[", paste0(interaction_constraints[[j]] - 1L, collapse = ","), "]"
)
}
return(interaction_constraints)
}


Expand Down
45 changes: 37 additions & 8 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -2776,14 +2776,12 @@ test_that(paste0("lgb.train() throws an informative error if the members of inte
test_that("lgb.train() throws an informative error if interaction_constraints contains a too large index", {
dtrain <- lgb.Dataset(train$data, label = train$label)
params <- list(objective = "regression",
interaction_constraints = list(c(1L, length(colnames(train$data)) + 1L), 3L))
expect_error({
bst <- lightgbm(
data = dtrain
, params = params
, nrounds = 2L
)
}, "supplied a too large value in interaction_constraints")
interaction_constraints = list(c(1L, ncol(train$data) + 1L:2L), 3L))
expect_error(
lightgbm(data = dtrain, params = params, nrounds = 2L)
, "unknown feature(s) in interaction_constraints: '127', '128'"
, fixed = TRUE
)
})

test_that(paste0("lgb.train() gives same result when interaction_constraints is specified as a list of ",
Expand Down Expand Up @@ -2876,6 +2874,37 @@ test_that(paste0("lgb.train() gives same results when using interaction_constrai

})

test_that("Interaction constraints add missing features correctly as new group", {
dtrain <- lgb.Dataset(
train$data[, 1L:6L] # Pick only some columns
, label = train$label
, params = list(num_threads = .LGB_MAX_THREADS)
)

list_of_constraints <- list(
list(3L, 1L:2L)
, list("cap-shape=convex", c("cap-shape=bell", "cap-shape=conical"))
)

for (constraints in list_of_constraints) {
params <- list(
objective = "regression"
, interaction_constraints = constraints
, verbose = .LGB_VERBOSITY
, num_threads = .LGB_MAX_THREADS
)
bst <- lightgbm(data = dtrain, params = params, nrounds = 10L)

expected_list <- list("[2]", "[0,1]", "[3,4,5]")
expect_equal(bst$params$interaction_constraints, expected_list)

expected_string <- "[interaction_constraints: [2],[0,1],[3,4,5]]"
expect_true(
grepl(expected_string, bst$save_model_to_string(), fixed = TRUE)
)
}
})

.generate_trainset_for_monotone_constraints_tests <- function(x3_to_categorical) {
n_samples <- 3000L
x1_positively_correlated_with_y <- runif(n = n_samples, min = 0.0, max = 1.0)
Expand Down
2 changes: 1 addition & 1 deletion R-package/tests/testthat/test_lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ test_that("Loading a Booster from a text file works", {
, bagging_freq = 1L
, boost_from_average = FALSE
, categorical_feature = c(1L, 2L)
, interaction_constraints = list(c(1L, 2L), 1L)
, interaction_constraints = list(1L:2L, 3L, 4L:ncol(train$data))
, feature_contri = rep(0.5, ncol(train$data))
, metric = c("mape", "average_precision")
, learning_rate = 1.0
Expand Down
18 changes: 18 additions & 0 deletions R-package/tests/testthat/test_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,21 @@ test_that(".equal_or_both_null produces expected results", {
expect_false(.equal_or_both_null(10.0, 1L))
expect_true(.equal_or_both_null(0L, 0L))
})

test_that(".check_interaction_constraints() adds skipped features", {
ref <- letters[1L:5L]
ic_num <- list(1L, c(2L, 3L))
ic_char <- list("a", c("b", "c"))
expected <- list("[0]", "[1,2]", "[3,4]")

ic_checked_num <- .check_interaction_constraints(
interaction_constraints = ic_num, column_names = ref
)

ic_checked_char <- .check_interaction_constraints(
interaction_constraints = ic_char, column_names = ref
)

expect_equal(ic_checked_num, expected)
expect_equal(ic_checked_char, expected)
})

0 comments on commit 4e74403

Please sign in to comment.