Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] introduce Dataset methods set_field() and get_field() #4571

Merged
merged 11 commits into from
Sep 25, 2021
2 changes: 1 addition & 1 deletion R-package/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ Imports:
utils
SystemRequirements:
C++11
RoxygenNote: 7.1.1
RoxygenNote: 7.1.2
4 changes: 4 additions & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
S3method("dimnames<-",lgb.Dataset)
S3method(dim,lgb.Dataset)
S3method(dimnames,lgb.Dataset)
S3method(get_field,lgb.Dataset)
S3method(getinfo,lgb.Dataset)
S3method(predict,lgb.Booster)
S3method(set_field,lgb.Dataset)
S3method(setinfo,lgb.Dataset)
S3method(slice,lgb.Dataset)
export(get_field)
export(getinfo)
export(lgb.Dataset)
export(lgb.Dataset.construct)
Expand All @@ -30,6 +33,7 @@ export(lgb.unloader)
export(lightgbm)
export(readRDS.lgb.Booster)
export(saveRDS.lgb.Booster)
export(set_field)
export(setinfo)
export(slice)
import(methods)
Expand Down
187 changes: 159 additions & 28 deletions R-package/R/lgb.Dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -335,14 +335,17 @@ Dataset <- R6::R6Class(
for (i in seq_along(private$info)) {

p <- private$info[i]
self$setinfo(name = names(p), info = p[[1L]])
self$set_field(
field_name = names(p)
, data = p[[1L]]
)

}

}

# Get label information existence
if (is.null(self$getinfo(name = "label"))) {
if (is.null(self$get_field(field_name = "label"))) {
stop("lgb.Dataset.construct: label should be set")
}

Expand Down Expand Up @@ -452,27 +455,41 @@ Dataset <- R6::R6Class(

},

# Get information
getinfo = function(name) {
warning(paste0(
"Dataset$getinfo() is deprecated and will be removed in a future release. "
, "Use Dataset$get_field() instead."
))
return(
self$get_field(
field_name = name
)
)
},

get_field = function(field_name) {

# Check if attribute key is in the known attribute list
if (!is.character(name) || length(name) != 1L || !name %in% .INFO_KEYS()) {
stop("getinfo: name must one of the following: ", paste0(sQuote(.INFO_KEYS()), collapse = ", "))
if (!is.character(field_name) || length(field_name) != 1L || !field_name %in% .INFO_KEYS()) {
stop(
"Dataset$get_field(): field_name must one of the following: "
, paste0(sQuote(.INFO_KEYS()), collapse = ", ")
)
}

# Check for info name and handle
if (is.null(private$info[[name]])) {
if (is.null(private$info[[field_name]])) {

if (lgb.is.null.handle(x = private$handle)) {
stop("Cannot perform getinfo before constructing Dataset.")
stop("Cannot perform Dataset$get_field() before constructing Dataset.")
}

# Get field size of info
info_len <- 0L
.Call(
LGBM_DatasetGetFieldSize_R
, private$handle
, name
, field_name
, info_len
)

Expand All @@ -481,7 +498,7 @@ Dataset <- R6::R6Class(

# Get back fields
ret <- NULL
ret <- if (name == "group") {
ret <- if (field_name == "group") {
integer(info_len) # Integer
} else {
numeric(info_len) # Numeric
Expand All @@ -490,47 +507,62 @@ Dataset <- R6::R6Class(
.Call(
LGBM_DatasetGetField_R
, private$handle
, name
, field_name
, ret
)

private$info[[name]] <- ret
private$info[[field_name]] <- ret

}
}

return(private$info[[name]])
return(private$info[[field_name]])

},

# Set information
setinfo = function(name, info) {
warning(paste0(
"Dataset$setinfo() is deprecated and will be removed in a future release. "
, "Use Dataset$set_field() instead."
))
return(
self$set_field(
field_name = name
, data = info
)
)
},

set_field = function(field_name, data) {

# Check if attribute key is in the known attribute list
if (!is.character(name) || length(name) != 1L || !name %in% .INFO_KEYS()) {
stop("setinfo: name must one of the following: ", paste0(sQuote(.INFO_KEYS()), collapse = ", "))
if (!is.character(field_name) || length(field_name) != 1L || !field_name %in% .INFO_KEYS()) {
stop(
"Dataset$set_field(): field_name must one of the following: "
, paste0(sQuote(.INFO_KEYS()), collapse = ", ")
)
}

# Check for type of information
info <- if (name == "group") {
as.integer(info) # Integer
data <- if (field_name == "group") {
as.integer(data) # Integer
} else {
as.numeric(info) # Numeric
as.numeric(data) # Numeric
}

# Store information privately
private$info[[name]] <- info
private$info[[field_name]] <- data

if (!lgb.is.null.handle(x = private$handle) && !is.null(info)) {
if (!lgb.is.null.handle(x = private$handle) && !is.null(data)) {

if (length(info) > 0L) {
if (length(data) > 0L) {

.Call(
LGBM_DatasetSetField_R
, private$handle
, name
, info
, length(info)
, field_name
, data
, length(data)
)

private$version <- private$version + 1L
Expand All @@ -554,7 +586,7 @@ Dataset <- R6::R6Class(
, paste(names(additional_keyword_args), collapse = ", ")
, ". These are ignored and should be removed. "
, "To change the parameters of a Dataset produced by Dataset$slice(), use Dataset$set_params(). "
, "To modify attributes like 'init_score', use Dataset$setinfo(). "
, "To modify attributes like 'init_score', use Dataset$set_field(). "
, "In future releases of lightgbm, this warning will become an error."
))
}
Expand Down Expand Up @@ -1110,7 +1142,7 @@ dimnames.lgb.Dataset <- function(x) {
#'
#' dsub <- lightgbm::slice(dtrain, seq_len(42L))
#' lgb.Dataset.construct(dsub)
#' labels <- lightgbm::getinfo(dsub, "label")
#' labels <- lightgbm::get_field(dsub, "label")
#' }
#' @export
slice <- function(dataset, ...) {
Expand Down Expand Up @@ -1173,6 +1205,8 @@ getinfo <- function(dataset, ...) {
#' @export
getinfo.lgb.Dataset <- function(dataset, name, ...) {

warning("Calling getinfo() on a lgb.Dataset is deprecated. Use get_field() instead.")

additional_args <- list(...)
if (length(additional_args) > 0L) {
warning(paste0(
Expand All @@ -1187,7 +1221,7 @@ getinfo.lgb.Dataset <- function(dataset, name, ...) {
stop("getinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")
}

return(dataset$getinfo(name = name))
return(dataset$get_field(field_name = name))

}

Expand Down Expand Up @@ -1236,6 +1270,8 @@ setinfo <- function(dataset, ...) {
#' @export
setinfo.lgb.Dataset <- function(dataset, name, info, ...) {

warning("Calling setinfo() on a lgb.Dataset is deprecated. Use set_field() instead.")

additional_args <- list(...)
if (length(additional_args) > 0L) {
warning(paste0(
Expand All @@ -1250,7 +1286,102 @@ setinfo.lgb.Dataset <- function(dataset, name, info, ...) {
stop("setinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")
}

return(invisible(dataset$setinfo(name = name, info = info)))
return(invisible(dataset$set_field(field_name = name, data = info)))
}

#' @name get_field
#' @title Get one attribute of a \code{lgb.Dataset}
#' @description Get one attribute of a \code{lgb.Dataset}
#' @param dataset Object of class \code{lgb.Dataset}
#' @param field_name String with the name of the attribute to get. One of the following.
#' \itemize{
#' \item \code{label}: label lightgbm learns from ;
#' \item \code{weight}: to do a weight rescale ;
#' \item{\code{group}: used for learning-to-rank tasks. An integer vector describing how to
#' group rows together as ordered results from the same set of candidate results to be ranked.
#' For example, if you have a 100-document dataset with \code{group = c(10, 20, 40, 10, 10, 10)},
#' that means that you have 6 groups, where the first 10 records are in the first group,
#' records 11-30 are in the second group, etc.}
#' \item \code{init_score}: initial score is the base prediction lightgbm will boost from.
#' }
#' @return requested attribute
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.construct(dtrain)
#'
#' labels <- lightgbm::get_field(dtrain, "label")
#' lightgbm::set_field(dtrain, "label", 1 - labels)
#'
#' labels2 <- lightgbm::get_field(dtrain, "label")
#' stopifnot(all(labels2 == 1 - labels))
#' }
#' @export
get_field <- function(dataset, field_name) {
UseMethod("get_field")
}

#' @rdname get_field
#' @export
get_field.lgb.Dataset <- function(dataset, field_name) {

# Check if dataset is not a dataset
if (!lgb.is.Dataset(x = dataset)) {
stop("get_field.lgb.Dataset(): input dataset should be an lgb.Dataset object")
}

return(dataset$get_field(field_name = field_name))

}

#' @name set_field
#' @title Set one attribute of a \code{lgb.Dataset} object
#' @description Set one attribute of a \code{lgb.Dataset}
#' @param dataset Object of class \code{lgb.Dataset}
#' @param field_name String with the name of the attribute to set. One of the following.
#' \itemize{
#' \item \code{label}: label lightgbm learns from ;
#' \item \code{weight}: to do a weight rescale ;
#' \item{\code{group}: used for learning-to-rank tasks. An integer vector describing how to
#' group rows together as ordered results from the same set of candidate results to be ranked.
#' For example, if you have a 100-document dataset with \code{group = c(10, 20, 40, 10, 10, 10)},
#' that means that you have 6 groups, where the first 10 records are in the first group,
#' records 11-30 are in the second group, etc.}
#' \item \code{init_score}: initial score is the base prediction lightgbm will boost from.
#' }
#' @param data The data for the field. See examples.
#' @return The \code{lgb.Dataset} you passed in.
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.construct(dtrain)
#'
#' labels <- lightgbm::get_field(dtrain, "label")
#' lightgbm::set_field(dtrain, "label", 1 - labels)
#'
#' labels2 <- lightgbm::get_field(dtrain, "label")
#' stopifnot(all.equal(labels2, 1 - labels))
#' }
#' @export
set_field <- function(dataset, field_name, data) {
UseMethod("set_field")
}

#' @rdname set_field
#' @export
set_field.lgb.Dataset <- function(dataset, field_name, data) {

if (!lgb.is.Dataset(x = dataset)) {
stop("set_field.lgb.Dataset: input dataset should be an lgb.Dataset object")
}

return(invisible(dataset$set_field(field_name = field_name, data = data)))
}

#' @name lgb.Dataset.set.categorical
Expand Down
Loading