From 979ffa0bb645ce24ca581ac59063b6e76a1ef593 Mon Sep 17 00:00:00 2001 From: Benjamin Lefaudeux Date: Wed, 7 Apr 2021 18:10:20 +0000 Subject: [PATCH 1/3] Properly handle .train() and .eval() modes --- fairscale/nn/data_parallel/sharded_ddp.py | 32 +++++++++++++------ .../test_sharded_ddp_features.py | 22 +++++++++++++ 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/fairscale/nn/data_parallel/sharded_ddp.py b/fairscale/nn/data_parallel/sharded_ddp.py index a57d3dedc..c20839bbb 100644 --- a/fairscale/nn/data_parallel/sharded_ddp.py +++ b/fairscale/nn/data_parallel/sharded_ddp.py @@ -198,15 +198,10 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: """ # Deferred initialization, or change detection - needs_setup = len(self._grad_hooks) == 0 + needs_setup = len(self._grad_hooks) == 0 and self.training if self.auto_refresh_trainable: - # Optionally check whether the trainable parameters have changed - trainable_mask = list(map(_trainable, self._all_params)) - if trainable_mask != self._reference_trainable_mask: - logging.warning("ShardedDDP detected that the trainable params changed, updating the partitioning") - needs_setup = True - self._reference_trainable_mask = trainable_mask + needs_setup |= self._detect_train_change() if needs_setup: self.refresh_trainable() @@ -278,7 +273,6 @@ def refresh_trainable(self) -> None: self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params)) self._trainable_params.sort(key=lambda x: x.numel()) - self._grad_to_be_reduced = [True for _ in self._trainable_params] self._trainable_param_to_rank = {} for optim in self.sharded_optimizers: @@ -373,7 +367,8 @@ def no_sync(self) -> Generator: @torch.no_grad() def _clear_counters(self) -> None: """Reset all the grad reduce and call counters""" - self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced] + if self.training: + self._grad_to_be_reduced = [True for _ in self._trainable_params] self._bucket_flush_callback_set = False if self.use_buckets: @@ -490,6 +485,9 @@ def _setup_backward_hooks(self) -> None: # Go through the parameters, attach the hook self._grad_accs = [] self._manual_reduce = [] + if not self.training: + return + for index, param in enumerate(self._trainable_params): if param.grad is not None and param.grad.requires_grad: raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad") @@ -624,3 +622,19 @@ def _flush_reduce_calls(self) -> None: bucket.sent = True self._consume_work_handles() + + def _detect_train_change(self) -> bool: + # Optionally check whether the trainable parameters have changed + trainable_mask = list(map(_trainable, self._all_params)) + + # - one or more parameters trainability changed + trainability_changed = trainable_mask != self._reference_trainable_mask + + # - the whole model is not trainable but we still have grad hooks + trainability_changed |= not self.training and len(self._grad_hooks) > 0 + + if trainability_changed: + logging.warning("ShardedDDP detected that the trainable params changed, updating the partitioning") + self._reference_trainable_mask = trainable_mask + + return trainability_changed diff --git a/tests/nn/data_parallel/test_sharded_ddp_features.py b/tests/nn/data_parallel/test_sharded_ddp_features.py index f10a97e27..e7c259700 100644 --- a/tests/nn/data_parallel/test_sharded_ddp_features.py +++ b/tests/nn/data_parallel/test_sharded_ddp_features.py @@ -262,6 +262,28 @@ def test_mixed_types(): dist.destroy_process_group() +def test_train_eval_change(): + # Check that 
ShardedDDP handles the switch from training to eval properly + dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1) + + model = _get_mlp() + model.train() + optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99) + model = ShardedDataParallel(model, optimizer) + input_tensor = torch.rand((2, 2)) + _ = model(input_tensor) + + model = model.eval() + _ = model(input_tensor) + _ = model(input_tensor) + + model = model.train() + _ = model(input_tensor) + _ = model(input_tensor) + + dist.destroy_process_group() + + def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size): # Check that the wrapped module can change devices dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size) From eaf11a287506c847600fdf3467129c6bf2acdcdb Mon Sep 17 00:00:00 2001 From: Benjamin Lefaudeux Date: Wed, 7 Apr 2021 18:40:02 +0000 Subject: [PATCH 2/3] showing that the unit test works, now fixed --- fairscale/nn/data_parallel/sharded_ddp.py | 7 ++++--- tests/nn/data_parallel/test_sharded_ddp_features.py | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/fairscale/nn/data_parallel/sharded_ddp.py b/fairscale/nn/data_parallel/sharded_ddp.py index c20839bbb..7e17ea3e0 100644 --- a/fairscale/nn/data_parallel/sharded_ddp.py +++ b/fairscale/nn/data_parallel/sharded_ddp.py @@ -267,9 +267,10 @@ def refresh_trainable(self) -> None: """ If the module trainability has changed, update all the assumptions """ # Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance) - assert not functools.reduce( - lambda x, y: x or y, self._grad_to_be_reduced, False - ), "Grads waiting to be reduced: {}".format(self._grad_to_be_reduced) + assert not functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False), ( + "Grads waiting to be reduced: {}".format(self._grad_to_be_reduced) + + "\nIf this is on purpose (grad accumulation), please use a no_sync() context" + ) self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params)) self._trainable_params.sort(key=lambda x: x.numel()) diff --git a/tests/nn/data_parallel/test_sharded_ddp_features.py b/tests/nn/data_parallel/test_sharded_ddp_features.py index e7c259700..fc4d079d6 100644 --- a/tests/nn/data_parallel/test_sharded_ddp_features.py +++ b/tests/nn/data_parallel/test_sharded_ddp_features.py @@ -271,7 +271,8 @@ def test_train_eval_change(): optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99) model = ShardedDataParallel(model, optimizer) input_tensor = torch.rand((2, 2)) - _ = model(input_tensor) + loss = model(input_tensor).sum() + loss.backward() # make sure that the gradients are reduced model = model.eval() _ = model(input_tensor) From df5d924888c419e58e58dd8ad8ef4041c2672c9f Mon Sep 17 00:00:00 2001 From: Benjamin Lefaudeux Date: Wed, 7 Apr 2021 20:28:06 +0000 Subject: [PATCH 3/3] code review --- fairscale/nn/data_parallel/sharded_ddp.py | 4 +++- tests/nn/data_parallel/test_sharded_ddp_features.py | 11 +++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/fairscale/nn/data_parallel/sharded_ddp.py b/fairscale/nn/data_parallel/sharded_ddp.py index 7e17ea3e0..2dc70e212 100644 --- a/fairscale/nn/data_parallel/sharded_ddp.py +++ b/fairscale/nn/data_parallel/sharded_ddp.py @@ -635,7 +635,9 @@ def _detect_train_change(self) -> bool: trainability_changed |= 
not self.training and len(self._grad_hooks) > 0 if trainability_changed: - logging.warning("ShardedDDP detected that the trainable params changed, updating the partitioning") + logging.warning( + "ShardedDDP detected that the trainable params changed, either because of eval/train mode or parameter freezing/unfreeze." + ) self._reference_trainable_mask = trainable_mask return trainability_changed diff --git a/tests/nn/data_parallel/test_sharded_ddp_features.py b/tests/nn/data_parallel/test_sharded_ddp_features.py index fc4d079d6..078a20893 100644 --- a/tests/nn/data_parallel/test_sharded_ddp_features.py +++ b/tests/nn/data_parallel/test_sharded_ddp_features.py @@ -274,13 +274,16 @@ def test_train_eval_change(): loss = model(input_tensor).sum() loss.backward() # make sure that the gradients are reduced - model = model.eval() - _ = model(input_tensor) + # Wipe the gradients and switch to eval mode + model.zero_grad() + model.eval() _ = model(input_tensor) + assert next(model.parameters()).grad is None or torch.norm(next(model.parameters()).grad) < 1e-6 + # Get back to training model = model.train() - _ = model(input_tensor) - _ = model(input_tensor) + model(input_tensor).sum().backward() + assert torch.norm(next(model.parameters()).grad) > 0.0 dist.destroy_process_group()
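
For reference, below is a minimal end-to-end sketch of the train/eval round-trip this series is meant to support, adapted from the test_train_eval_change unit test above. It is an illustration only, not part of the patches: the small torch.nn.Sequential model stands in for the test's _get_mlp() helper, and the single-rank "gloo" process group mirrors the test setup rather than a real multi-GPU launch.

import tempfile

import torch
import torch.distributed as dist

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS


def main() -> None:
    # Single-rank process group, as in the unit test (illustrative only)
    dist.init_process_group(
        init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1
    )

    # Stand-in for the test's _get_mlp() helper
    base = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 3))
    optimizer = OSS(params=base.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    model = ShardedDataParallel(base, optimizer)
    inputs = torch.rand((2, 2))

    # Train step: the backward hooks reduce the gradients, the sharded optimizer consumes them
    model.train()
    model(inputs).sum().backward()
    optimizer.step()
    model.zero_grad()

    # Eval: with this series applied, forward passes in eval mode no longer install grad hooks,
    # so they do not require a matching backward
    model.eval()
    with torch.no_grad():
        _ = model(inputs)

    # Back to train: the hooks are re-installed on the next forward
    model.train()
    model(inputs).sum().backward()
    optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()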