diff --git a/chabud/datapipe.py b/chabud/datapipe.py
index fe157ea..1ecb322 100644
--- a/chabud/datapipe.py
+++ b/chabud/datapipe.py
@@ -142,6 +142,8 @@ def __init__(
         # "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_0.hdf5",
         # "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_1.hdf5",
         # "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_2.hdf5",
+        # "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_3.hdf5",
+        # "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_4.hdf5",
     ],
     batch_size: int = 8,
 ):
@@ -196,6 +198,8 @@ def setup(
             "data/california_0.hdf5": "f2036e129849263b66cdb9fd4769742c499a879f91a364c42bb5c953052787fc",  # 3.38GB
             "data/california_1.hdf5": "cdb13d720fcb3115c9e1c096e22a9d652ac122c93adfcbf271d4e3684a7679af",  # 3.7GB
             "data/california_2.hdf5": "0af569c8930348109b495a5f2768758a52a6deec85768fd70c0efd9370f84578",  # 368MB
+            "data/california_3.hdf5": "7f2856a3cda3161c555736cf2421197ef01f33b1135406983feea8e9a3ff4c06",  # 3.45GB
+            "data/california_4.hdf5": "fed59626b70cd7dfb1c78bb736f74ca7c5883372450ec72e389beff6b200ec9d",  # 1.83GB
         },
         hash_type="sha256",
     )
@@ -216,11 +220,11 @@ def setup(
     )
 
     # Step 4 - Convert from xarray.Dataset to tuple of torch.Tensor objects
-    # Also do batching, shuffling (for train set only) and tensor stacking
+    # Also do shuffling (for train set only), batching, and tensor stacking
     self.datapipe_train = (
-        dp_train.map(fn=_pre_post_mask_tuple)
+        dp_train.shuffle(buffer_size=100)
+        .map(fn=_pre_post_mask_tuple)
         .batch(batch_size=self.batch_size)
-        .in_batch_shuffle()
         .collate(collate_fn=_stack_tensor_collate_fn)
     )
     self.datapipe_val = (
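
Note on the shuffling change above: `in_batch_shuffle()` only permutes samples within each already-formed batch, so a given batch is always built from the same group of consecutive samples; calling `shuffle(buffer_size=100)` before `.batch(...)` instead mixes samples across a 100-element buffer, so batch composition itself varies. Below is a minimal sketch of the difference on a toy pipeline, not part of the patch: `toy_samples` and the sizes are made up for illustration, and only the torchdata calls mirror the ones in the diff.

```python
from torchdata.datapipes.iter import IterableWrapper

toy_samples = list(range(16))  # stand-in for the (pre, post, mask) tuples

# Old behaviour: batch first, then permute within each batch, so a batch
# always contains the same group of neighbouring samples.
old_style = (
    IterableWrapper(toy_samples)
    .batch(batch_size=4)
    .in_batch_shuffle()
)

# New behaviour: shuffle across a buffer before batching, so samples from
# different parts of the stream can land in the same batch.
new_style = (
    IterableWrapper(toy_samples)
    .shuffle(buffer_size=100)
    .batch(batch_size=4)
)

print(list(old_style))  # e.g. [[2, 0, 3, 1], [6, 4, 5, 7], ...]
print(list(new_style))  # e.g. [[9, 3, 0, 12], [7, 1, 14, 2], ...]
```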