From b1facf50502fb51a1e60c05b3ec83f68289df497 Mon Sep 17 00:00:00 2001
From: Zhiyuan He <362583303@qq.com>
Date: Mon, 8 Nov 2021 10:06:50 +0800
Subject: [PATCH] Suppress categorical warning (fixes #3379)

---
Review notes (commentary only; text between the "---" line and the diff is
not part of the commit message):
 * _lazy_init: the "<alias> in param dict is overridden." warning is skipped
   only when params[<alias>] is a list whose set equals categorical_indices;
   the alias is still popped from params either way.
 * construct: the "Overriding the parameters from Reference Dataset." warning
   is skipped when the two param dicts differ only in keys returned by
   _ConfigAliases.get("categorical_feature"); _update_params() still runs
   whenever the dicts differ at all.
 * set_categorical_feature: the 'auto' branch no longer warns, and the
   "categorical_feature in Dataset is overridden." warning is emitted only
   when the previous self.categorical_feature was not 'auto'.

 python-package/lightgbm/basic.py            | 42 ++++++++++++++++++---
 tests/python_package_test/test_utilities.py |  2 -
 2 files changed, 36 insertions(+), 8 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 83a4b5c071da..812fd82a5d97 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -1510,7 +1510,9 @@ def _lazy_init(self, data, label=None, reference=None,
             if categorical_indices:
                 for cat_alias in _ConfigAliases.get("categorical_feature"):
                     if cat_alias in params:
-                        _log_warning(f'{cat_alias} in param dict is overridden.')
+                        # If the params[cat_alias] is equal to categorical_indices, do not report the warning.
+                        if not(isinstance(params[cat_alias], list) and set(params[cat_alias]) == categorical_indices):
+                            _log_warning(f'{cat_alias} in param dict is overridden.')
                         params.pop(cat_alias, None)
                 params['categorical_column'] = sorted(categorical_indices)
 
@@ -1765,6 +1767,32 @@ def __init_from_csc(self, csc, params_str, ref_dataset):
             ctypes.byref(self.handle)))
         return self
 
+    @staticmethod
+    def _compare_params_for_warning(params, other_params):
+        """Compare params.
+
+        It is only for the warning purpose. Thus some keys are ignored.
+
+        Returns
+        -------
+        compare_result: bool
+            If they are equal, return True; Otherwise, return False.
+        """
+        ignore_keys = _ConfigAliases.get("categorical_feature")
+        if params is None:
+            params = {}
+        if other_params is None:
+            other_params = {}
+        for k in other_params:
+            if k not in ignore_keys:
+                if k not in params or params[k] != other_params[k]:
+                    return False
+        for k in params:
+            if k not in ignore_keys:
+                if k not in other_params or params[k] != other_params[k]:
+                    return False
+        return True
+
     def construct(self):
         """Lazy init.
 
@@ -1776,8 +1804,10 @@ def construct(self):
         if self.handle is None:
             if self.reference is not None:
                 reference_params = self.reference.get_params()
-                if self.get_params() != reference_params:
-                    _log_warning('Overriding the parameters from Reference Dataset.')
+                params = self.get_params()
+                if params != reference_params:
+                    if self._compare_params_for_warning(params, reference_params) is False:
+                        _log_warning('Overriding the parameters from Reference Dataset.')
                     self._update_params(reference_params)
                 if self.used_indices is None:
                     # create valid
@@ -2062,11 +2092,11 @@ def set_categorical_feature(self, categorical_feature):
                 self.categorical_feature = categorical_feature
                 return self._free_handle()
             elif categorical_feature == 'auto':
-                _log_warning('Using categorical_feature in Dataset.')
                 return self
             else:
-                _log_warning('categorical_feature in Dataset is overridden.\n'
-                             f'New categorical_feature is {sorted(list(categorical_feature))}')
+                if self.categorical_feature != 'auto':
+                    _log_warning('categorical_feature in Dataset is overridden.\n'
+                                 f'New categorical_feature is {sorted(list(categorical_feature))}')
                 self.categorical_feature = categorical_feature
                 return self._free_handle()
         else:
diff --git a/tests/python_package_test/test_utilities.py b/tests/python_package_test/test_utilities.py
index 02da77fc56e6..be57d585f695 100644
--- a/tests/python_package_test/test_utilities.py
+++ b/tests/python_package_test/test_utilities.py
@@ -43,8 +43,6 @@ def dummy_metric(_, __):
     lgb.plot_metric(eval_records)
 
     expected_log = r"""
-WARNING | categorical_feature in Dataset is overridden.
-New categorical_feature is [1]
 INFO | [LightGBM] [Warning] There are no meaningful features, as all feature values are constant.
 INFO | [LightGBM] [Info] Number of positive: 2, number of negative: 2
 INFO | [LightGBM] [Info] Total Bins 0