diff --git a/benchmark/supervised/hyper_parameter_search.py b/benchmark/supervised/hyper_parameter_search.py index a2778a01..eadcd14b 100644 --- a/benchmark/supervised/hyper_parameter_search.py +++ b/benchmark/supervised/hyper_parameter_search.py @@ -86,27 +86,28 @@ def run(cfg: Mapping[str, Any], filter_pattern: str) -> None: fold_ds = ds.get_fold_ds(fid) - cprint("\n|-building train dataset\n", "blue") - train_x = tf.constant(np.array(fold_ds["train"][0])) - train_y = tf.constant(np.array(fold_ds["train"][1])) - train_ds = datasets.utils.make_sampler( - train_x, - train_y, - exp.training.params["train"], - train_aug_fns, - ) - exp.training.params["train"]["num_examples"] = train_x.shape[0] - - cprint("\n|-building val dataset\n", "blue") - val_x = tf.constant(np.array(fold_ds["val"][0])) - val_y = tf.constant(np.array(fold_ds["val"][1])) - val_ds = datasets.utils.make_sampler( - val_x, - val_y, - exp.training.params["val"], - train_aug_fns, - ) - exp.training.params["val"]["num_examples"] = val_x.shape[0] + with tf.device("CPU"): + cprint("\n|-building train dataset\n", "blue") + train_x = tf.constant(np.array(fold_ds["train"][0])) + train_y = tf.constant(np.array(fold_ds["train"][1])) + train_ds = datasets.utils.make_sampler( + train_x, + train_y, + exp.training.params["train"], + train_aug_fns, + ) + exp.training.params["train"]["num_examples"] = train_x.shape[0] + + cprint("\n|-building val dataset\n", "blue") + val_x = tf.constant(np.array(fold_ds["val"][0])) + val_y = tf.constant(np.array(fold_ds["val"][1])) + val_ds = datasets.utils.make_sampler( + val_x, + val_y, + exp.training.params["val"], + train_aug_fns, + ) + exp.training.params["val"]["num_examples"] = val_x.shape[0] # Training params callbacks = [