Skip to content

Commit

Permalink
🍱 Extra data and pre-batch shuffle on train datapipe (#14)
Browse files Browse the repository at this point in the history
* 🍱 Extra datasets california_3.hdf5 and california_4.hdf5

More sample imagery datasets for training, added in https://huggingface.co/datasets/chabud-team/chabud-extra/commit/7da36fcb240ef39beed1f877acc837b98746f35b.

* 👔 Shuffle chips before batching instead of in-batch shuffling

Randomize the order of the chips before creating mini-batches, because the train_eval.hdf5 contains all the non-zero labels while california_*.hdf5 contain all zero labels. The shuffling causes a roughly 2x slowdown, from 2 it/s to 1 it/s.
  • Loading branch information
weiji14 committed Jun 1, 2023
1 parent 4f2f232 commit 6ca3381
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions chabud/datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ def __init__(
# "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_0.hdf5",
# "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_1.hdf5",
# "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_2.hdf5",
# "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_3.hdf5",
# "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_4.hdf5",
],
batch_size: int = 8,
):
Expand Down Expand Up @@ -196,6 +198,8 @@ def setup(
"data/california_0.hdf5": "f2036e129849263b66cdb9fd4769742c499a879f91a364c42bb5c953052787fc", # 3.38GB
"data/california_1.hdf5": "cdb13d720fcb3115c9e1c096e22a9d652ac122c93adfcbf271d4e3684a7679af", # 3.7GB
"data/california_2.hdf5": "0af569c8930348109b495a5f2768758a52a6deec85768fd70c0efd9370f84578", # 368MB
"data/california_3.hdf5": "7f2856a3cda3161c555736cf2421197ef01f33b1135406983feea8e9a3ff4c06", # 3.45GB
"data/california_4.hdf5": "fed59626b70cd7dfb1c78bb736f74ca7c5883372450ec72e389beff6b200ec9d", # 1.83GB
},
hash_type="sha256",
)
Expand All @@ -216,11 +220,11 @@ def setup(
)

# Step 4 - Convert from xarray.Dataset to tuple of torch.Tensor objects
# Also do batching, shuffling (for train set only) and tensor stacking
# Also do shuffling (for train set only), batching, and tensor stacking
self.datapipe_train = (
dp_train.map(fn=_pre_post_mask_tuple)
dp_train.shuffle(buffer_size=100)
.map(fn=_pre_post_mask_tuple)
.batch(batch_size=self.batch_size)
.in_batch_shuffle()
.collate(collate_fn=_stack_tensor_collate_fn)
)
self.datapipe_val = (
Expand Down

0 comments on commit 6ca3381

Please sign in to comment.