Gather CV Results as Completed #433
Changes from all commits
@@ -1,5 +1,6 @@
 from __future__ import absolute_import, division, print_function

+import logging
 import numbers
 from collections import defaultdict
 from itertools import repeat
@@ -11,6 +12,7 @@
 import packaging.version
 from dask.base import tokenize
 from dask.delayed import delayed
+from dask.distributed import as_completed
 from dask.utils import derived_from
 from sklearn import model_selection
 from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone, is_classifier
@@ -42,19 +44,19 @@
     cv_extract_params,
     cv_n_samples,
     cv_split,
-    decompress_params,
     feature_union,
     feature_union_concat,
     fit,
     fit_and_score,
     fit_best,
     fit_transform,
     get_best_params,
     pipeline,
     score,
 )
 from .utils import DeprecationDict, is_dask_collection, to_indexable, to_keys, unzip

+logger = logging.getLogger(__name__)
+
 try:
     from cytoolz import get, pluck
 except ImportError:  # pragma: no cover

@@ -63,7 +65,6 @@

 __all__ = ["GridSearchCV", "RandomizedSearchCV"]

-
 if SK_VERSION <= packaging.version.parse("0.21.dev0"):

     _RETURN_TRAIN_SCORE_DEFAULT = "warn"
@@ -102,6 +103,19 @@ def __call__(self, est):
         return self.token if c == 0 else self.token + str(c)


+def map_fit_params(dsk, fit_params):
+    if fit_params:
+        # A mapping of {name: (name, graph-key)}
+        param_values = to_indexable(*fit_params.values(), allow_scalars=True)
+        fit_params = {
+            k: (k, v) for (k, v) in zip(fit_params, to_keys(dsk, *param_values))
+        }
+    else:
+        fit_params = {}
+
+    return fit_params
+
+
 def build_graph(
     estimator,
     cv,
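As a quick illustration of the helper extracted above, here is a rough sketch of what map_fit_params produces. It is illustrative only: the import path is the private _search module, and the exact graph key is a tokenized name.

# Sketch of what map_fit_params produces; the graph key shown in the
# comment below is a placeholder for the real tokenized name.
import numpy as np
from dask_ml.model_selection._search import map_fit_params

dsk = {}
mapped = map_fit_params(dsk, {"sample_weight": np.ones(3)})

# `dsk` gained one task holding the weights, and `mapped` pairs each
# parameter name with its graph key, roughly:
# {"sample_weight": ("sample_weight", "<token>")}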
@@ -118,7 +132,68 @@ def build_graph(
     cache_cv=True,
     multimetric=False,
 ):
+    # This is provided for compatibility with TPOT. Remove
+    # once TPOT is updated and requires a dask-ml>=0.13.0
+    def decompress_params(fields, params):
+        return [{k: v for k, v in zip(fields, p) if v is not MISSING} for p in params]
+
+    fields, tokens, params = normalize_params(candidate_params)
+    dsk, keys, n_splits, main_token = build_cv_graph(
+        estimator,
+        cv,
+        scorer,
+        candidate_params,
+        X,
+        y=y,
+        groups=groups,
+        fit_params=fit_params,
+        iid=iid,
+        error_score=error_score,
+        return_train_score=return_train_score,
+        cache_cv=cache_cv,
+    )
+    cv_name = "cv-split-" + main_token
+    if iid:
+        weights = "cv-n-samples-" + main_token
+        dsk[weights] = (cv_n_samples, cv_name)
+        scores = keys[1:]
+    else:
+        weights = None
+        scores = keys
+
+    cv_results = "cv-results-" + main_token
+    candidate_params_name = "cv-parameters-" + main_token
+    dsk[candidate_params_name] = (decompress_params, fields, params)
+    if multimetric:
+        metrics = list(scorer.keys())
+    else:
+        metrics = None
+    dsk[cv_results] = (
+        create_cv_results,
+        scores,
+        candidate_params_name,
+        n_splits,
+        error_score,
+        weights,
+        metrics,
+    )
+    keys = [cv_results]
+    return dsk, keys, n_splits
+
+
+def build_cv_graph(
+    estimator,
+    cv,
+    scorer,
+    candidate_params,
+    X,
+    y=None,
+    groups=None,
+    fit_params=None,
+    iid=True,
+    error_score="raise",
+    return_train_score=_RETURN_TRAIN_SCORE_DEFAULT,
+    cache_cv=True,
+):
     X, y, groups = to_indexable(X, y, groups)
     cv = check_cv(cv, y, is_classifier(estimator))
     # "pairwise" estimators require a different graph for CV splitting
@@ -128,14 +203,7 @@ def build_graph(
     X_name, y_name, groups_name = to_keys(dsk, X, y, groups)
     n_splits = compute_n_splits(cv, X, y, groups)

-    if fit_params:
-        # A mapping of {name: (name, graph-key)}
-        param_values = to_indexable(*fit_params.values(), allow_scalars=True)
-        fit_params = {
-            k: (k, v) for (k, v) in zip(fit_params, to_keys(dsk, *param_values))
-        }
-    else:
-        fit_params = {}
+    fit_params = map_fit_params(dsk, fit_params)

     fields, tokens, params = normalize_params(candidate_params)
     main_token = tokenize(
@@ -176,50 +244,33 @@ def build_graph(
         scorer,
         return_train_score,
     )
+    keys = [weights] + scores if weights else scores
+    return dsk, keys, n_splits, main_token

-    cv_results = "cv-results-" + main_token
-    candidate_params_name = "cv-parameters-" + main_token
-    dsk[candidate_params_name] = (decompress_params, fields, params)
-    if multimetric:
-        metrics = list(scorer.keys())
-    else:
-        metrics = None
-    dsk[cv_results] = (
-        create_cv_results,
-        scores,
-        candidate_params_name,
-        n_splits,
-        error_score,
-        weights,
-        metrics,
-    )
-    keys = [cv_results]
-
-    if refit:
-        if multimetric:
-            scorer = refit
-        else:
-            scorer = "score"
+def build_refit_graph(estimator, X, y, best_params, fit_params):
+    X, y = to_indexable(X, y)
+    dsk = {}
+    X_name, y_name = to_keys(dsk, X, y)

-        best_params = "best-params-" + main_token
-        dsk[best_params] = (get_best_params, candidate_params_name, cv_results, scorer)
-        best_estimator = "best-estimator-" + main_token
-        if fit_params:
-            fit_params = (
-                dict,
-                (zip, list(fit_params.keys()), list(pluck(1, fit_params.values()))),
-            )
-        dsk[best_estimator] = (
-            fit_best,
-            clone(estimator),
-            best_params,
-            X_name,
-            y_name,
-            fit_params,
-        )
-        keys.append(best_estimator)
+    fit_params = map_fit_params(dsk, fit_params)
+    main_token = tokenize(normalize_estimator(estimator), X_name, y_name, fit_params)

-    return dsk, keys, n_splits
+    best_estimator = "best-estimator-" + main_token
+    if fit_params:
+        fit_params = (
+            dict,
+            (zip, list(fit_params.keys()), list(pluck(1, fit_params.values()))),
+        )
+    dsk[best_estimator] = (
+        fit_best,
+        clone(estimator),
+        best_params,
+        X_name,
+        y_name,
+        fit_params,
+    )
+    return dsk, [best_estimator]


 def normalize_params(params):
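To see how the new split composes, here is a hedged sketch that drives build_refit_graph directly with a synchronous scheduler; the estimator, data, and best_params are stand-ins, and the import is from a private module.

# Hypothetical direct use of build_refit_graph (not part of the diff).
import dask.local
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from dask_ml.model_selection._search import build_refit_graph

X, y = make_classification(n_samples=100, random_state=0)
dsk, keys = build_refit_graph(
    SGDClassifier(), X, y, best_params={"alpha": 0.001}, fit_params={}
)
# Materialize the one-task graph; the single key yields the fitted estimator.
best_estimator = dask.local.get_sync(dsk, keys)[0]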
@@ -1166,24 +1217,21 @@ def fit(self, X, y=None, groups=None, **fit_params):
                 "error_score must be the string 'raise' or a" " numeric value."
             )

-        dsk, keys, n_splits = build_graph(
+        candidate_params = list(self._get_param_iterator())
+        dsk, keys, n_splits, _ = build_cv_graph(
             estimator,
             self.cv,
             self.scorer_,
-            list(self._get_param_iterator()),
+            candidate_params,
             X,
-            y,
-            groups,
-            fit_params,
+            y=y,
+            groups=groups,
+            fit_params=fit_params,
             iid=self.iid,
-            refit=self.refit,
             error_score=error_score,
             return_train_score=self.return_train_score,
             cache_cv=self.cache_cv,
-            multimetric=multimetric,
         )
-        self.dask_graph_ = dsk
-        self.n_splits_ = n_splits

         n_jobs = _normalize_n_jobs(self.n_jobs)
         scheduler = dask.base.get_scheduler(scheduler=self.scheduler)
@@ -1193,21 +1241,59 @@ def fit(self, X, y=None, groups=None, **fit_params):
         if scheduler is dask.threaded.get and n_jobs == 1:
             scheduler = dask.local.get_sync

-        out = scheduler(dsk, keys, num_workers=n_jobs)
+        if "Client" in type(getattr(scheduler, "__self__", None)).__name__:
Review comment: Will this approach be heavier on communication? Or is it the same amount, just spread out over time? My real question is: should we provide an option to disable this behavior?

Review comment: From what I can see, I don't see why this would add additional overhead. This "should" submit effectively the same task graph under the traditional dask scheduler as under distributed. The difference here is that the distributed scheduler allows you to submit the graph asynchronously and get back a set of futures from which you can then yield results as they arrive. Given this, I don't see why it would be preferable to disable this feature when a distributed client is present.
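For readers unfamiliar with the pattern, below is a minimal, self-contained sketch of the same gather-and-retry loop outside dask-ml. The in-process cluster, square task, and inputs are illustrative, not part of this PR.

# Standalone sketch of gathering results as they complete, with retry.
from dask.distributed import Client, as_completed

client = Client(processes=False)  # small in-process cluster for illustration

def square(x):
    return x ** 2

futures = client.map(square, range(8))
result_map = {}
ac = as_completed(futures, with_results=True, raise_errors=False)
for batch in ac.batches():
    for future, result in batch:
        if future.status == "finished":
            result_map[future.key] = result
        else:
            future.retry()   # resubmit the failed task
            ac.add(future)   # and keep waiting for its result

out = [result_map[f.key] for f in futures]  # results in submission order
client.close()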
+            futures = scheduler(
+                dsk, keys, allow_other_workers=True, num_workers=n_jobs, sync=False
+            )

-        results = handle_deprecated_train_score(out[0], self.return_train_score)
+            result_map = {}
+            ac = as_completed(futures, with_results=True, raise_errors=False)
+            for batch in ac.batches():
+                for future, result in batch:
+                    if future.status == "finished":
+                        result_map[future.key] = result
+                    else:
+                        logger.warning("{} has failed... retrying".format(future.key))
+                        future.retry()
+                        ac.add(future)
+
+            out = [result_map[k] for k in keys]
+        else:
+            out = scheduler(dsk, keys, num_workers=n_jobs)
+
+        if self.iid:
+            weights = out[0]
+            scores = out[1:]
+        else:
+            weights = None
+            scores = out
+
+        if multimetric:
+            metrics = list(scorer.keys())
+            scorer = self.refit
+        else:
+            metrics = None
+            scorer = "score"
+
+        cv_results = create_cv_results(
+            scores, candidate_params, n_splits, error_score, weights, metrics
+        )
+
+        results = handle_deprecated_train_score(cv_results, self.return_train_score)
+        self.dask_graph_ = dsk
+        self.n_splits_ = n_splits
         self.cv_results_ = results

         if self.refit:
-            if self.multimetric_:
-                key = self.refit
-            else:
-                key = "score"
-            self.best_index_ = np.flatnonzero(results["rank_test_{}".format(key)] == 1)[
-                0
-            ]
+            self.best_index_ = np.flatnonzero(
+                results["rank_test_{}".format(scorer)] == 1
+            )[0]

+            best_params = candidate_params[self.best_index_]
+            dsk, keys = build_refit_graph(estimator, X, y, best_params, fit_params)

-            self.best_estimator_ = out[1]
+            out = scheduler(dsk, keys, num_workers=n_jobs)
+            self.best_estimator_ = out[0]

         return self
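An end-to-end sketch of what this change means for users follows; the dataset, estimator, and grid are illustrative. With an active Client, fit() submits the graph asynchronously and gathers CV scores as they complete, retrying failed tasks; without one, it falls back to the synchronous scheduler path.

# Hedged usage sketch: grid search on a distributed client.
from dask.distributed import Client
from sklearn.datasets import make_classification
from sklearn.svm import SVC
from dask_ml.model_selection import GridSearchCV

client = Client(processes=False)  # picked up as the default scheduler

X, y = make_classification(n_samples=200, random_state=0)
param_grid = {"C": [0.1, 1.0, 10.0], "kernel": ["rbf", "linear"]}

search = GridSearchCV(SVC(gamma="auto"), param_grid)
search.fit(X, y)  # scores gathered via as_completed on the Client
print(search.best_params_)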
|
@@ -1584,7 +1670,6 @@ def __init__( | |
n_jobs=-1, | ||
cache_cv=True, | ||
): | ||
|
||
super(RandomizedSearchCV, self).__init__( | ||
estimator=estimator, | ||
scoring=scoring, | ||
|
Review comment: Can we provide a compatibility shim with the old name? I believe that TPOT is using build_graph: https://github.com/EpistasisLab/tpot/blob/b626271e6b5896a73fb9d7d29bebc7aa9100772e/tpot/gp_deap.py#L429

Review comment: I think we can just add back build_graph with a comment that it's going to be removed (not even a deprecation warning). This is a private module, but I was lazy when adding dask support to TPOT. Once this is in, I can cut a release and make a PR to tpot updating their call.
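For reference, a minimal version of the kind of shim being discussed might look like the sketch below. This is hypothetical: _build_graph stands in for wherever the implementation would have moved, and the merged diff above actually keeps a complete build_graph rather than a bare alias.

# Sketch of a minimal compatibility shim (hypothetical).
def build_graph(*args, **kwargs):
    # Old private entry point kept so released TPOT versions keep working.
    # Remove once TPOT requires dask-ml>=0.13.0.
    return _build_graph(*args, **kwargs)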