Skip to content

Commit

Permalink
MAINT: lint, remove extra files
Browse files Browse the repository at this point in the history
  • Loading branch information
stsievert committed Jun 7, 2019
1 parent cd38d7c commit e5c8c60
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 277 deletions.
19 changes: 10 additions & 9 deletions dask_ml/model_selection/_hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
from warnings import warn

import numpy as np
from sklearn.metrics.scorer import check_scoring
from sklearn.utils import check_random_state
from tornado import gen

from ._incremental import BaseIncrementalSearchCV
from ._successive_halving import SuccessiveHalvingSearchCV, _get_max_iter
from ._successive_halving import SuccessiveHalvingSearchCV

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -471,13 +470,15 @@ def _get_meta(hists, brackets, SHAs, key):
decisions = {hi["partial_fit_calls"] for h in hist.values() for hi in h}
if bracket != max(brackets):
decisions.discard(1)
meta_.append({
"decisions": sorted(list(decisions)),
"n_models": len(hist),
"bracket": bracket,
"partial_fit_calls": sum(calls.values()),
"SuccessiveHalvingSearchCV params": _get_SHA_params(SHAs[bracket]),
})
meta_.append(
{
"decisions": sorted(list(decisions)),
"n_models": len(hist),
"bracket": bracket,
"partial_fit_calls": sum(calls.values()),
"SuccessiveHalvingSearchCV params": _get_SHA_params(SHAs[bracket]),
}
)
meta_ = sorted(meta_, key=lambda x: x["bracket"])
return meta_, history_

Expand Down
63 changes: 0 additions & 63 deletions tests/model_selection/test_dist_hyperband.py

This file was deleted.

6 changes: 1 addition & 5 deletions tests/model_selection/test_hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,7 @@ def _test_mirrors_paper(c, s, a, b):
assert alg.metadata == alg.metadata_

assert isinstance(alg.metadata["brackets"], list)
assert set(alg.metadata.keys()) == {
"n_models",
"partial_fit_calls",
"brackets",
}
assert set(alg.metadata.keys()) == {"n_models", "partial_fit_calls", "brackets"}
for bracket in alg.metadata["brackets"]:
assert set(bracket.keys()) == {
"n_models",
Expand Down
198 changes: 0 additions & 198 deletions tests/model_selection/test_split.py

This file was deleted.

5 changes: 3 additions & 2 deletions tests/model_selection/test_successive_halving.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import numpy as np
import pytest
from distributed.utils_test import gen_cluster # noqa: F401
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
import pytest

from dask_ml.model_selection import SuccessiveHalvingSearchCV
from dask_ml.model_selection._successive_halving import (
_get_n_initial_calls,
_get_max_iter,
_get_n_initial_calls,
)


Expand Down Expand Up @@ -40,6 +40,7 @@ def test_sha_max_iter(n, r):
(so only one model is obtained at the end, as per the last assert)
"""

@gen_cluster(client=True)
def _test_sha_max_iter(c, s, a, b):
model = SGDClassifier(tol=1e-3)
Expand Down

0 comments on commit e5c8c60

Please sign in to comment.