Skip to content

Commit

Permalink
[python] add type hints for custom objective and metric functions in …
Browse files Browse the repository at this point in the history
…scikit-learn interface (#4547)

* [python] add type hints for custom objective and metric functions in scikit-learn interface

* update type hints

* remote unnecessary input

* Update python-package/lightgbm/sklearn.py

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* remove type hint on objective being callable

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
  • Loading branch information
jameslamb and StrikerRUS authored Nov 15, 2021
1 parent bfb346c commit 843d380
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 16 deletions.
22 changes: 11 additions & 11 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from copy import deepcopy
from enum import Enum, auto
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
from urllib.parse import urlparse

import numpy as np
Expand All @@ -21,8 +21,8 @@
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
default_client, delayed, pd_DataFrame, pd_Series, wait)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _lgbmmodel_doc_custom_eval_note,
_lgbmmodel_doc_fit, _lgbmmodel_doc_predict)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomEvalFunction,
_lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit, _lgbmmodel_doc_predict)

_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
_DaskMatrixLike = Union[dask_Array, dask_DataFrame]
Expand Down Expand Up @@ -400,7 +400,7 @@ def _train(
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Optional[Iterable[int]] = None,
**kwargs: Any
) -> LGBMModel:
Expand Down Expand Up @@ -1029,7 +1029,7 @@ def _lgb_dask_fit(
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Optional[Iterable[int]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
Expand Down Expand Up @@ -1096,7 +1096,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
objective: Optional[str] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
Expand Down Expand Up @@ -1165,7 +1165,7 @@ def fit(
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
) -> "DaskLGBMClassifier":
Expand Down Expand Up @@ -1281,7 +1281,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
objective: Optional[str] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
Expand Down Expand Up @@ -1348,7 +1348,7 @@ def fit(
eval_names: Optional[List[str]] = None,
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_init_score: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
) -> "DaskLGBMRegressor":
Expand Down Expand Up @@ -1446,7 +1446,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
objective: Optional[str] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
Expand Down Expand Up @@ -1516,7 +1516,7 @@ def fit(
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_init_score: Optional[List[_DaskVectorLike]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Iterable[int] = (1, 2, 3, 4, 5),
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
Expand Down
38 changes: 33 additions & 5 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""Scikit-learn wrapper interface for LightGBM."""
import copy
from inspect import signature
from typing import Callable, Dict, Optional, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np

Expand All @@ -11,14 +11,42 @@
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
_LGBMComputeSampleWeight, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, dt_DataTable,
pd_DataFrame)
pd_DataFrame, pd_Series)
from .engine import train

_ArrayLike = Union[List, np.ndarray, pd_Series]
_EvalResultType = Tuple[str, float, bool]

_LGBM_ScikitCustomObjectiveFunction = Union[
Callable[
[np.ndarray, np.ndarray],
Tuple[_ArrayLike, _ArrayLike]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray],
Tuple[_ArrayLike, _ArrayLike]
],
]
_LGBM_ScikitCustomEvalFunction = Union[
Callable[
[np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
],
]


class _ObjectiveFunctionWrapper:
"""Proxy class for objective function."""

def __init__(self, func):
def __init__(self, func: _LGBM_ScikitCustomObjectiveFunction):
"""Construct a proxy class.
This class transforms objective function to match objective function with signature ``new_func(preds, dataset)``
Expand Down Expand Up @@ -107,7 +135,7 @@ def __call__(self, preds, dataset):
class _EvalFunctionWrapper:
"""Proxy class for evaluation function."""

def __init__(self, func):
def __init__(self, func: _LGBM_ScikitCustomEvalFunction):
"""Construct a proxy class.
This class transforms evaluation function to match evaluation function with signature ``new_func(preds, dataset)``
Expand Down Expand Up @@ -358,7 +386,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[str, Callable]] = None,
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[Dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
Expand Down

0 comments on commit 843d380

Please sign in to comment.