diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index e848e09725e8a..0ca90e5962728 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -91,15 +91,12 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: ) dl_args['shuffle'] = False else: - if train: - sampler = DistributedSampler(dataloader.dataset) - dl_args['shuffle'] = False - else: - sampler = SequentialSampler(dataloader.dataset) + sampler = DistributedSampler(dataloader.dataset) + dl_args['shuffle'] = False dl_args['sampler'] = sampler - dataloader = DataLoader(**dl_args) + return dataloader def reset_train_dataloader(self, model: LightningModule) -> None: