From 560e5837c459106b2ad599b50aaf3ca7de846d33 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 13 Aug 2020 02:01:18 +0000 Subject: [PATCH] fixed subsampling for ddp training --- fastmri/data/mri_data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fastmri/data/mri_data.py b/fastmri/data/mri_data.py index bf789875..10f69ce9 100644 --- a/fastmri/data/mri_data.py +++ b/fastmri/data/mri_data.py @@ -114,7 +114,7 @@ class SliceDataset(Dataset): what fraction of the volumes should be loaded. """ - def __init__(self, root, transform, challenge, sample_rate=1): + def __init__(self, root, transform, challenge, sample_rate=1, seed=0): if challenge not in ("singlecoil", "multicoil"): raise ValueError('challenge should be either "singlecoil" or "multicoil"') @@ -126,6 +126,7 @@ def __init__(self, root, transform, challenge, sample_rate=1): files = list(pathlib.Path(root).iterdir()) if sample_rate < 1: + random.seed(seed) # get the same files in every process random.shuffle(files) num_files = round(len(files) * sample_rate) files = files[:num_files]