Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] support customizing Dataset creation in Booster.refit() (fixes #3038) #4894

Merged
merged 12 commits into from
Jan 22, 2022
61 changes: 59 additions & 2 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3503,7 +3503,21 @@ def predict(self, data, start_iteration=0, num_iteration=None,
raw_score, pred_leaf, pred_contrib,
data_has_header, is_reshape)

def refit(self, data, label, decay_rate=0.9, **kwargs):
def refit(
self,
data,
label,
decay_rate=0.9,
reference=None,
weight=None,
group=None,
init_score=None,
feature_name='auto',
categorical_feature='auto',
dataset_params=None,
free_raw_data=True,
**kwargs
):
"""Refit the existing Booster by new data.

Parameters
Expand All @@ -3516,6 +3530,35 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
decay_rate : float, optional (default=0.9)
Decay rate of refit,
will use ``leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output`` to refit trees.
reference : Dataset or None, optional (default=None)
Reference for ``data``.
weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
Weight for each ``data`` instance. Weight should be non-negative values because the Hessian
value multiplied by weight is supposed to be non-negative.
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
group : list, numpy 1-D array, pandas Series or None, optional (default=None)
Group/query size for ``data``.
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
Only used in the learning-to-rank task.
sum(group) = n_samples.
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None, optional (default=None)
Init score for ``data``.
feature_name : list of str or 'auto', optional (default="auto")
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
Feature names for ``data``.
If 'auto' and data is pandas DataFrame, data columns names are used.
categorical_feature : list of str or int, or 'auto', optional (default="auto")
Categorical features for ``data``.
If list of int, interpreted as indices.
If list of str, interpreted as feature names (need to specify ``feature_name`` as well).
If 'auto' and data is pandas DataFrame, pandas unordered categorical columns are used.
All values in categorical features should be less than int32 max value (2147483647).
Large values could be memory consuming. Consider using consecutive integers starting from zero.
All negative values in categorical features will be treated as missing values.
The output cannot be monotonically constrained with respect to a categorical feature.
dataset_params : dict or None, optional (default=None)
Other parameters for Dataset ``data``.
free_raw_data : bool, optional (default=True)
If True, raw data is freed after constructing inner Dataset for ``data``.
**kwargs
Other parameters for refit.
These parameters will be passed to ``predict`` method.
Expand All @@ -3527,6 +3570,8 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
"""
if self.__set_objective_to_none:
raise LightGBMError('Cannot refit due to null objective function.')
if dataset_params is None:
dataset_params = {}
predictor = self._to_predictor(deepcopy(kwargs))
leaf_preds = predictor.predict(data, -1, pred_leaf=True)
nrow, ncol = leaf_preds.shape
Expand All @@ -3540,7 +3585,19 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
default_value=None
)
new_params["linear_tree"] = bool(out_is_linear.value)
train_set = Dataset(data, label, params=new_params)
new_params.update(dataset_params)
train_set = Dataset(
data=data,
label=label,
reference=reference,
weight=weight,
group=group,
init_score=init_score,
feature_name=feature_name,
categorical_feature=categorical_feature,
params=new_params,
free_raw_data=free_raw_data,
)
new_params['refit_decay_rate'] = decay_rate
new_booster = Booster(new_params, train_set)
# Copy models
Expand Down
34 changes: 34 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1545,6 +1545,40 @@ def test_refit():
assert err_pred > new_err_pred


def test_refit_dataset_params():
    """Check that ``Booster.refit()`` accepts ``weight`` and ``dataset_params``.

    Trains a small binary classifier, refits it on the same training data with
    per-row weights and custom Dataset parameters, and then verifies that the
    refitted booster's ``train_set`` reports those parameters and weights and
    that the hold-out log loss differs from the original model's.
    """
    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
    lgb_train = lgb.Dataset(X_train, y_train)
    train_params = {
        'objective': 'binary',
        'verbose': -1,
        'seed': 123
    }
    gbm = lgb.train(train_params, lgb_train, num_boost_round=10)
    non_weight_err_pred = log_loss(y_test, gbm.predict(X_test))
    # Draw the weights from a locally seeded generator so every run refits on
    # the same data; the global ``np.random`` state is left untouched.
    refit_weight = np.random.RandomState(123).rand(y_train.shape[0])
    dataset_params = {
        'max_bin': 260,
        'min_data_in_bin': 5,
        'data_random_seed': 123,
    }
    new_gbm = gbm.refit(
        data=X_train,
        label=y_train,
        weight=refit_weight,
        dataset_params=dataset_params,
    )
    weight_err_pred = log_loss(y_test, new_gbm.predict(X_test))
    train_set_params = new_gbm.train_set.get_params()
    stored_weights = new_gbm.train_set.get_weight()
    # Refitting with weights must change the model's predictions.
    assert weight_err_pred != non_weight_err_pred
    # Every entry of ``dataset_params`` must be visible on the refit Dataset.
    assert train_set_params["max_bin"] == 260
    assert train_set_params["min_data_in_bin"] == 5
    assert train_set_params["data_random_seed"] == 123
    np.testing.assert_allclose(stored_weights, refit_weight)

jameslamb marked this conversation as resolved.
Show resolved Hide resolved

def test_mape_rf():
X, y = load_boston(return_X_y=True)
params = {
Expand Down