Skip to content

Commit

Permalink
Wrap tf.data.Dataset in CPU context to avoid GPU OOM errors.
Browse files Browse the repository at this point in the history
  • Loading branch information
owenvallis committed Sep 6, 2023
1 parent 7dcb23d commit 5774152
Showing 1 changed file with 22 additions and 21 deletions.
43 changes: 22 additions & 21 deletions benchmark/supervised/hyper_parameter_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,27 +86,28 @@ def run(cfg: Mapping[str, Any], filter_pattern: str) -> None:

fold_ds = ds.get_fold_ds(fid)

cprint("\n|-building train dataset\n", "blue")
train_x = tf.constant(np.array(fold_ds["train"][0]))
train_y = tf.constant(np.array(fold_ds["train"][1]))
train_ds = datasets.utils.make_sampler(
train_x,
train_y,
exp.training.params["train"],
train_aug_fns,
)
exp.training.params["train"]["num_examples"] = train_x.shape[0]

cprint("\n|-building val dataset\n", "blue")
val_x = tf.constant(np.array(fold_ds["val"][0]))
val_y = tf.constant(np.array(fold_ds["val"][1]))
val_ds = datasets.utils.make_sampler(
val_x,
val_y,
exp.training.params["val"],
train_aug_fns,
)
exp.training.params["val"]["num_examples"] = val_x.shape[0]
with tf.device("CPU"):
cprint("\n|-building train dataset\n", "blue")
train_x = tf.constant(np.array(fold_ds["train"][0]))
train_y = tf.constant(np.array(fold_ds["train"][1]))
train_ds = datasets.utils.make_sampler(
train_x,
train_y,
exp.training.params["train"],
train_aug_fns,
)
exp.training.params["train"]["num_examples"] = train_x.shape[0]

cprint("\n|-building val dataset\n", "blue")
val_x = tf.constant(np.array(fold_ds["val"][0]))
val_y = tf.constant(np.array(fold_ds["val"][1]))
val_ds = datasets.utils.make_sampler(
val_x,
val_y,
exp.training.params["val"],
train_aug_fns,
)
exp.training.params["val"]["num_examples"] = val_x.shape[0]

# Training params
callbacks = [
Expand Down

0 comments on commit 5774152

Please sign in to comment.