Skip to content

Commit

Permalink
[python-package] fix mypy errors in plotting.py (#4838)
Browse files Browse the repository at this point in the history
* [python-package] fix mypy errors in plotting.py

* empty commit
  • Loading branch information
jameslamb authored Dec 2, 2021
1 parent 8f4126d commit f8bab7f
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions python-package/lightgbm/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,12 +236,12 @@ def plot_split_value_histogram(
elif not isinstance(booster, Booster):
raise TypeError('booster must be Booster or LGBMModel.')

hist, bins = booster.get_split_value_histogram(feature=feature, bins=bins, xgboost_style=False)
hist, split_bins = booster.get_split_value_histogram(feature=feature, bins=bins, xgboost_style=False)
if np.count_nonzero(hist) == 0:
raise ValueError('Cannot plot split value histogram, '
f'because feature {feature} was not used in splitting')
width = width_coef * (bins[1] - bins[0])
centred = (bins[:-1] + bins[1:]) / 2
width = width_coef * (split_bins[1] - split_bins[0])
centred = (split_bins[:-1] + split_bins[1:]) / 2

if ax is None:
if figsize is not None:
Expand All @@ -253,8 +253,8 @@ def plot_split_value_histogram(
if xlim is not None:
_check_not_tuple_of_2_elements(xlim, 'xlim')
else:
range_result = bins[-1] - bins[0]
xlim = (bins[0] - range_result * 0.2, bins[-1] + range_result * 0.2)
range_result = split_bins[-1] - split_bins[0]
xlim = (split_bins[0] - range_result * 0.2, split_bins[-1] + range_result * 0.2)
ax.set_xlim(xlim)

ax.yaxis.set_major_locator(MaxNLocator(integer=True))
Expand Down Expand Up @@ -358,13 +358,13 @@ def plot_metric(
_, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

if dataset_names is None:
dataset_names = iter(eval_results.keys())
dataset_names_iter = iter(eval_results.keys())
elif not isinstance(dataset_names, (list, tuple, set)) or not dataset_names:
raise ValueError('dataset_names should be iterable and cannot be empty')
else:
dataset_names = iter(dataset_names)
dataset_names_iter = iter(dataset_names)

name = next(dataset_names) # take one as sample
name = next(dataset_names_iter) # take one as sample
metrics_for_one = eval_results[name]
num_metric = len(metrics_for_one)
if metric is None:
Expand All @@ -381,7 +381,7 @@ def plot_metric(
x_ = range(num_iteration)
ax.plot(x_, results, label=name)

for name in dataset_names:
for name in dataset_names_iter:
metrics_for_one = eval_results[name]
results = metrics_for_one[metric]
max_result = max(max(results), max_result)
Expand Down

0 comments on commit f8bab7f

Please sign in to comment.