Gather CV Results as Completed (#433)
Paul Vecchio authored and TomAugspurger committed May 20, 2019
1 parent 3b6ea18 commit 954d41f
Showing 4 changed files with 240 additions and 81 deletions.
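In brief, this commit splits build_graph into build_cv_graph and build_refit_graph, moves the cv_results aggregation out of the task graph and into GridSearchCV.fit, and, when fit runs against a dask.distributed Client, submits the graph asynchronously and gathers score futures as they complete, retrying any that fail rather than aborting the whole search. A minimal standalone sketch of that gather-and-retry pattern (illustrative only, not part of the commit; square is a stand-in for a fit-and-score task):

    from dask.distributed import Client, as_completed

    client = Client()  # local cluster, for illustration

    def square(x):
        return x ** 2

    futures = client.map(square, range(8))
    results = {}
    ac = as_completed(futures, with_results=True, raise_errors=False)
    for batch in ac.batches():
        for future, result in batch:
            if future.status == "finished":
                results[future.key] = result
            else:
                future.retry()  # resubmit the failed task
                ac.add(future)  # and keep waiting on it

    ordered = [results[f.key] for f in futures]  # restore submission order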
227 changes: 156 additions & 71 deletions dask_ml/model_selection/_search.py
@@ -1,5 +1,6 @@
from __future__ import absolute_import, division, print_function

import logging
import numbers
from collections import defaultdict
from itertools import repeat
@@ -11,6 +12,7 @@
import packaging.version
from dask.base import tokenize
from dask.delayed import delayed
from dask.distributed import as_completed
from dask.utils import derived_from
from sklearn import model_selection
from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone, is_classifier
@@ -42,19 +44,19 @@
cv_extract_params,
cv_n_samples,
cv_split,
decompress_params,
feature_union,
feature_union_concat,
fit,
fit_and_score,
fit_best,
fit_transform,
get_best_params,
pipeline,
score,
)
from .utils import DeprecationDict, is_dask_collection, to_indexable, to_keys, unzip

logger = logging.getLogger(__name__)

try:
from cytoolz import get, pluck
except ImportError: # pragma: no cover
@@ -63,7 +65,6 @@

__all__ = ["GridSearchCV", "RandomizedSearchCV"]


if SK_VERSION <= packaging.version.parse("0.21.dev0"):

_RETURN_TRAIN_SCORE_DEFAULT = "warn"
@@ -102,6 +103,19 @@ def __call__(self, est):
return self.token if c == 0 else self.token + str(c)


def map_fit_params(dsk, fit_params):
if fit_params:
# A mapping of {name: (name, graph-key)}
param_values = to_indexable(*fit_params.values(), allow_scalars=True)
fit_params = {
k: (k, v) for (k, v) in zip(fit_params, to_keys(dsk, *param_values))
}
else:
fit_params = {}

return fit_params
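
For illustration, a rough sketch of how this extracted helper behaves (hypothetical values; per the code above, to_keys stores each parameter value in dsk and returns its graph key):

    import numpy as np

    dsk = {}
    mapped = map_fit_params(dsk, {"sample_weight": np.ones(100)})
    # mapped == {"sample_weight": ("sample_weight", "<graph-key>")}
    # and dsk now carries the weight array under "<graph-key>"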


def build_graph(
estimator,
cv,
@@ -118,7 +132,68 @@
cache_cv=True,
multimetric=False,
):
# This is provided for compatibility with TPOT. Remove
# once TPOT is updated and requires a dask-ml>=0.13.0
def decompress_params(fields, params):
return [{k: v for k, v in zip(fields, p) if v is not MISSING} for p in params]

fields, tokens, params = normalize_params(candidate_params)
dsk, keys, n_splits, main_token = build_cv_graph(
estimator,
cv,
scorer,
candidate_params,
X,
y=y,
groups=groups,
fit_params=fit_params,
iid=iid,
error_score=error_score,
return_train_score=return_train_score,
cache_cv=cache_cv,
)
cv_name = "cv-split-" + main_token
if iid:
weights = "cv-n-samples-" + main_token
dsk[weights] = (cv_n_samples, cv_name)
scores = keys[1:]
else:
scores = keys

cv_results = "cv-results-" + main_token
candidate_params_name = "cv-parameters-" + main_token
dsk[candidate_params_name] = (decompress_params, fields, params)
if multimetric:
metrics = list(scorer.keys())
else:
metrics = None
dsk[cv_results] = (
create_cv_results,
scores,
candidate_params_name,
n_splits,
error_score,
weights,
metrics,
)
keys = [cv_results]
return dsk, keys, n_splits
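
So build_graph is now a thin compatibility wrapper: build_cv_graph assembles the scoring tasks and build_graph appends the create_cv_results aggregation. The (dsk, keys) pair is an ordinary dask graph, so a sketch of running it directly (assuming estimator, cv, scorer, candidate_params, X, and y are already defined) could be:

    from dask.threaded import get

    dsk, keys, n_splits = build_graph(
        estimator, cv, scorer, candidate_params, X, y=y
    )
    (cv_results,) = get(dsk, keys)  # keys holds the single cv-results node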


def build_cv_graph(
estimator,
cv,
scorer,
candidate_params,
X,
y=None,
groups=None,
fit_params=None,
iid=True,
error_score="raise",
return_train_score=_RETURN_TRAIN_SCORE_DEFAULT,
cache_cv=True,
):
X, y, groups = to_indexable(X, y, groups)
cv = check_cv(cv, y, is_classifier(estimator))
# "pairwise" estimators require a different graph for CV splitting
@@ -128,14 +203,7 @@ def build_graph(
X_name, y_name, groups_name = to_keys(dsk, X, y, groups)
n_splits = compute_n_splits(cv, X, y, groups)

if fit_params:
# A mapping of {name: (name, graph-key)}
param_values = to_indexable(*fit_params.values(), allow_scalars=True)
fit_params = {
k: (k, v) for (k, v) in zip(fit_params, to_keys(dsk, *param_values))
}
else:
fit_params = {}
fit_params = map_fit_params(dsk, fit_params)

fields, tokens, params = normalize_params(candidate_params)
main_token = tokenize(
@@ -176,50 +244,33 @@
scorer,
return_train_score,
)
keys = [weights] + scores if weights else scores
return dsk, keys, n_splits, main_token

cv_results = "cv-results-" + main_token
candidate_params_name = "cv-parameters-" + main_token
dsk[candidate_params_name] = (decompress_params, fields, params)
if multimetric:
metrics = list(scorer.keys())
else:
metrics = None
dsk[cv_results] = (
create_cv_results,
scores,
candidate_params_name,
n_splits,
error_score,
weights,
metrics,
)
keys = [cv_results]

if refit:
if multimetric:
scorer = refit
else:
scorer = "score"
def build_refit_graph(estimator, X, y, best_params, fit_params):
X, y = to_indexable(X, y)
dsk = {}
X_name, y_name = to_keys(dsk, X, y)

best_params = "best-params-" + main_token
dsk[best_params] = (get_best_params, candidate_params_name, cv_results, scorer)
best_estimator = "best-estimator-" + main_token
if fit_params:
fit_params = (
dict,
(zip, list(fit_params.keys()), list(pluck(1, fit_params.values()))),
)
dsk[best_estimator] = (
fit_best,
clone(estimator),
best_params,
X_name,
y_name,
fit_params,
)
keys.append(best_estimator)
fit_params = map_fit_params(dsk, fit_params)
main_token = tokenize(normalize_estimator(estimator), X_name, y_name, fit_params)

return dsk, keys, n_splits
best_estimator = "best-estimator-" + main_token
if fit_params:
fit_params = (
dict,
(zip, list(fit_params.keys()), list(pluck(1, fit_params.values()))),
)
dsk[best_estimator] = (
fit_best,
clone(estimator),
best_params,
X_name,
y_name,
fit_params,
)
return dsk, [best_estimator]
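
The reworked fit() below chains the two builders: score the candidates, pick the best row from the assembled results, then fit the winner on the full data. A condensed sketch of that flow (simplified; assumes iid=False so there is no weights key, single-metric scoring, and a plain scheduler get function):

    dsk, keys, n_splits, _ = build_cv_graph(
        estimator, cv, scorer, candidate_params, X, y=y
    )
    scores = scheduler(dsk, keys)
    results = create_cv_results(
        scores, candidate_params, n_splits, "raise", None, None
    )
    best_index = np.flatnonzero(results["rank_test_score"] == 1)[0]
    dsk, keys = build_refit_graph(
        estimator, X, y, candidate_params[best_index], {}
    )
    (best_estimator,) = scheduler(dsk, keys)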


def normalize_params(params):
@@ -1166,24 +1217,21 @@ def fit(self, X, y=None, groups=None, **fit_params):
"error_score must be the string 'raise' or a" " numeric value."
)

dsk, keys, n_splits = build_graph(
candidate_params = list(self._get_param_iterator())
dsk, keys, n_splits, _ = build_cv_graph(
estimator,
self.cv,
self.scorer_,
list(self._get_param_iterator()),
candidate_params,
X,
y,
groups,
fit_params,
y=y,
groups=groups,
fit_params=fit_params,
iid=self.iid,
refit=self.refit,
error_score=error_score,
return_train_score=self.return_train_score,
cache_cv=self.cache_cv,
multimetric=multimetric,
)
self.dask_graph_ = dsk
self.n_splits_ = n_splits

n_jobs = _normalize_n_jobs(self.n_jobs)
scheduler = dask.base.get_scheduler(scheduler=self.scheduler)
@@ -1193,21 +1241,59 @@ def fit(self, X, y=None, groups=None, **fit_params):
if scheduler is dask.threaded.get and n_jobs == 1:
scheduler = dask.local.get_sync

out = scheduler(dsk, keys, num_workers=n_jobs)
if "Client" in type(getattr(scheduler, "__self__", None)).__name__:
futures = scheduler(
dsk, keys, allow_other_workers=True, num_workers=n_jobs, sync=False
)

results = handle_deprecated_train_score(out[0], self.return_train_score)
result_map = {}
ac = as_completed(futures, with_results=True, raise_errors=False)
for batch in ac.batches():
for future, result in batch:
if future.status == "finished":
result_map[future.key] = result
else:
logger.warning("{} has failed... retrying".format(future.key))
future.retry()
ac.add(future)

out = [result_map[k] for k in keys]
else:
out = scheduler(dsk, keys, num_workers=n_jobs)

if self.iid:
weights = out[0]
scores = out[1:]
else:
weights = None
scores = out

if multimetric:
metrics = list(scorer.keys())
scorer = self.refit
else:
metrics = None
scorer = "score"

cv_results = create_cv_results(
scores, candidate_params, n_splits, error_score, weights, metrics
)

results = handle_deprecated_train_score(cv_results, self.return_train_score)
self.dask_graph_ = dsk
self.n_splits_ = n_splits
self.cv_results_ = results

if self.refit:
if self.multimetric_:
key = self.refit
else:
key = "score"
self.best_index_ = np.flatnonzero(results["rank_test_{}".format(key)] == 1)[
0
]
self.best_index_ = np.flatnonzero(
results["rank_test_{}".format(scorer)] == 1
)[0]

best_params = candidate_params[self.best_index_]
dsk, keys = build_refit_graph(estimator, X, y, best_params, fit_params)

self.best_estimator_ = out[1]
out = scheduler(dsk, keys, num_workers=n_jobs)
self.best_estimator_ = out[0]

return self

@@ -1584,7 +1670,6 @@ def __init__(
n_jobs=-1,
cache_cv=True,
):

super(RandomizedSearchCV, self).__init__(
estimator=estimator,
scoring=scoring,
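The public API is unchanged; pointing the search at a distributed scheduler is what activates the new as-completed gathering. A hedged end-to-end example (the toy data and parameter grid are hypothetical):

    from dask.distributed import Client
    from sklearn.datasets import make_classification
    from sklearn.linear_model import SGDClassifier
    from dask_ml.model_selection import GridSearchCV

    client = Client()  # fit() detects the Client-backed scheduler
    X, y = make_classification(n_samples=1000, random_state=0)
    search = GridSearchCV(
        SGDClassifier(tol=1e-3),
        {"alpha": [1e-4, 1e-3, 1e-2]},
        cv=3,
    )
    search.fit(X, y)  # failed tasks are logged and retried, not fatal
    print(search.best_params_)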
9 changes: 0 additions & 9 deletions dask_ml/model_selection/methods.py
@@ -171,10 +171,6 @@ def cv_extract_params(cvs, keys, vals, n):
return {k: cvs.extract_param(tok, v, n) for (k, tok), v in zip(keys, vals)}


def decompress_params(fields, params):
return [{k: v for k, v in zip(fields, p) if v is not MISSING} for p in params]


def _maybe_timed(x):
"""Unpack (est, fit_time) tuples if provided"""
return x if isinstance(x, tuple) and len(x) == 2 else (x, 0.0)
@@ -452,11 +448,6 @@ def create_cv_results(
return results


def get_best_params(candidate_params, cv_results, scorer):
best_index = np.flatnonzero(cv_results["rank_test_{}".format(scorer)] == 1)[0]
return candidate_params[best_index]


def fit_best(estimator, params, X, y, fit_params):
estimator = copy_estimator(estimator).set_params(**params)
estimator.fit(X, y, **fit_params)
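With get_best_params gone, rank-based selection happens inline in fit() (see _search.py above) and fit_best remains the only refit task. Roughly, and only as an illustration with hypothetical arguments (assuming fit_best returns the fitted estimator; the tail of the function is truncated above):

    fitted = fit_best(SGDClassifier(), {"alpha": 1e-3}, X, y, {})
    # copies the estimator, applies set_params(alpha=1e-3),
    # then calls fit(X, y) with the given fit_params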
(diffs for the remaining two of the 4 changed files were not captured)
