DDP sampler #1513

Merged (6 commits) on Apr 19, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added flag `replace_sampler_ddp` to manually disable sampler replacement in DDP ([#1513](https://github.com/PyTorchLightning/pytorch-lightning/pull/1513))
- Added `auto_select_gpus` flag to trainer that enables automatic selection of available GPUs on exclusive mode systems.
- Added learning rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))

7 changes: 3 additions & 4 deletions pytorch_lightning/trainer/data_loading.py
@@ -61,6 +61,7 @@ class TrainerDataLoadingMixin(ABC):
train_percent_check: float
val_percent_check: float
test_percent_check: float
replace_sampler_ddp: bool

@abstractmethod
def is_overriden(self, *args):
@@ -88,10 +89,8 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
# don't do anything if it's not a dataloader
if not isinstance(dataloader, DataLoader):
return dataloader

- need_dist_sampler = self.use_ddp or self.use_ddp2 or self.use_tpu
- if need_dist_sampler:
+ need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_tpu)
+ if self.replace_sampler_ddp and need_dist_sampler:
Contributor:

Rather than adding a new flag for this, is it possible to just check for the default samplers:
`(isinstance(dataloader.sampler, RandomSampler) or isinstance(dataloader.sampler, SequentialSampler) or isinstance(dataloader.sampler, _InfiniteConstantSampler)) and need_dist_sampler`?

I had a hard-to-find regression that happened because my custom sampler was overridden after #1425. Considering `replace_sampler_ddp` is set to `True` by default, I think a lot of users will run into similar issues, both in terms of regressions and in new projects. E.g. users are going to expect that when they write a new sampler and pass it to the dataloader, it'll be used without having to change a setting somewhere.
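
(For reference, a rough sketch of the check being suggested, not code from this PR: the helper name `_has_default_sampler` is hypothetical, and `_InfiniteConstantSampler` is assumed to live in the private `torch.utils.data.dataloader` module, which may vary between torch versions.)

```python
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
# Private class DataLoader uses for iterable-style datasets; location may differ by torch version.
from torch.utils.data.dataloader import _InfiniteConstantSampler


def _has_default_sampler(dataloader: DataLoader) -> bool:
    """Return True only if the DataLoader is using a sampler it created on its own."""
    return isinstance(
        dataloader.sampler,
        (RandomSampler, SequentialSampler, _InfiniteConstantSampler),
    )


# The suggested guard would then read roughly:
# if _has_default_sampler(dataloader) and need_dist_sampler:
#     ...  # replace with a DistributedSampler
```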

Member Author:

I understand that, but I don't think this is a good idea. If I explicitly set a standard sampler and don't want it to be replaced, I can't do anything this way.

Contributor:

Fair enough. I'd prefer the trainer flag default to False at least, but let's leave it up to the opinion of others @PyTorchLightning/core-contributors


skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']

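For context, a simplified sketch of what the guarded branch above does when `replace_sampler_ddp` stays at its default of `True`: the loader's sampler is swapped for a `DistributedSampler` and the `DataLoader` is rebuilt around it. This is not the merged code (the real implementation copies the loader's arguments generically, hence the `skip_keys` list); `replace_sampler` below is a hypothetical helper for illustration only.

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def replace_sampler(dataloader: DataLoader, num_replicas: int, rank: int) -> DataLoader:
    # Shard the same dataset across processes so each rank sees a distinct subset.
    dist_sampler = DistributedSampler(dataloader.dataset, num_replicas=num_replicas, rank=rank)

    # Rebuild the DataLoader around the new sampler, carrying over the common
    # constructor arguments; the original sampler/batch_sampler are dropped
    # because they cannot coexist with the distributed one.
    return DataLoader(
        dataloader.dataset,
        batch_size=dataloader.batch_size,
        sampler=dist_sampler,
        num_workers=dataloader.num_workers,
        collate_fn=dataloader.collate_fn,
        pin_memory=dataloader.pin_memory,
        drop_last=dataloader.drop_last,
    )
```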
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -127,6 +127,7 @@ def __init__(
benchmark: bool = False,
reload_dataloaders_every_epoch: bool = False,
auto_lr_find: Union[bool, str] = False,
replace_sampler_ddp: bool = True,
default_save_path=None, # backward compatible, todo: remove in v0.8.0
gradient_clip=None, # backward compatible, todo: remove in v0.8.0
nb_gpu_nodes=None, # backward compatible, todo: remove in v0.8.0
@@ -282,6 +283,9 @@ def __init__(
rate in self.hparams.lr | self.hparams.learning_rate in the lightning module.
To use a different key, set a string instead of True with the key name.

replace_sampler_ddp: Explicitly enables or disables sampler replacement.
    If not specified, this will be toggled automatically when DDP is used.

benchmark: If true enables cudnn.benchmark.

terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
@@ -362,6 +366,7 @@ def __init__(
self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

self.auto_lr_find = auto_lr_find
self.replace_sampler_ddp = replace_sampler_ddp

self.truncated_bptt_steps = truncated_bptt_steps
self.resume_from_checkpoint = resume_from_checkpoint
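
As a usage sketch (my own example, not taken from the PR): with the new flag set to `False`, the sampler passed to the `DataLoader` is kept as-is, so a custom or manually configured `DistributedSampler` survives. The `gpus` and `distributed_backend` arguments reflect the Trainer API of this release; `num_replicas`/`rank` are written out explicitly here for clarity, whereas in real DDP training they would come from the initialized process group.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from pytorch_lightning import Trainer

# Dummy dataset purely for illustration.
dataset = TensorDataset(torch.randn(256, 32), torch.randint(0, 10, (256,)))

# With replace_sampler_ddp=False the sampler below is used as-is instead of
# being swapped for an automatically constructed DistributedSampler.
trainer = Trainer(gpus=2, distributed_backend='ddp', replace_sampler_ddp=False)

train_loader = DataLoader(
    dataset,
    batch_size=32,
    sampler=DistributedSampler(dataset, num_replicas=2, rank=0),
)
```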