[fix][ShardedDDP] Properly handle .eval() mode (#587)
* Properly handle .train() and .eval() modes
* showing that the unit test works, now fixed
* code review
blefaudeux committed Apr 7, 2021
1 parent e89a191 commit ce1f2ce
Showing 2 changed files with 55 additions and 12 deletions.
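
For context before the diff: the pattern this commit makes safe is switching a ShardedDataParallel-wrapped model between train() and eval() without rebuilding the wrapper. Below is a minimal, single-rank sketch of that flow (not part of the commit; the small Sequential model and hyper-parameters are stand-ins for the test helpers used further down, and it assumes fairscale is installed):

    import tempfile

    import torch
    import torch.distributed as dist

    from fairscale.nn.data_parallel import ShardedDataParallel
    from fairscale.optim import OSS

    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    module = torch.nn.Sequential(torch.nn.Linear(2, 4), torch.nn.Linear(4, 2))
    optimizer = OSS(params=module.parameters(), optim=torch.optim.SGD, lr=1e-3)
    model = ShardedDataParallel(module, optimizer)
    batch = torch.rand((2, 2))

    model.train()
    model(batch).sum().backward()  # reduce hooks are installed lazily on this first forward
    optimizer.step()

    model.eval()  # with this fix, no reduce hooks are (re)installed while in eval mode
    with torch.no_grad():
        _ = model(batch)  # pure inference, no gradient bookkeeping

    model.train()  # hooks come back on the next forward()
    model(batch).sum().backward()
    optimizer.step()

    dist.destroy_process_group()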
41 changes: 29 additions & 12 deletions fairscale/nn/data_parallel/sharded_ddp.py
@@ -198,15 +198,10 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         """
 
         # Deferred initialization, or change detection
-        needs_setup = len(self._grad_hooks) == 0
+        needs_setup = len(self._grad_hooks) == 0 and self.training
 
         if self.auto_refresh_trainable:
-            # Optionally check whether the trainable parameters have changed
-            trainable_mask = list(map(_trainable, self._all_params))
-            if trainable_mask != self._reference_trainable_mask:
-                logging.warning("ShardedDDP detected that the trainable params changed, updating the partitioning")
-                needs_setup = True
-                self._reference_trainable_mask = trainable_mask
+            needs_setup |= self._detect_train_change()
 
         if needs_setup:
             self.refresh_trainable()
@@ -272,13 +267,13 @@ def refresh_trainable(self) -> None:
         """ If the module trainability has changed, update all the assumptions """
 
         # Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
-        assert not functools.reduce(
-            lambda x, y: x or y, self._grad_to_be_reduced, False
-        ), "Grads waiting to be reduced: {}".format(self._grad_to_be_reduced)
+        assert not functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False), (
+            "Grads waiting to be reduced: {}".format(self._grad_to_be_reduced)
+            + "\nIf this is on purpose (grad accumulation), please use a no_sync() context"
+        )
 
         self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
         self._trainable_params.sort(key=lambda x: x.numel())
-        self._grad_to_be_reduced = [True for _ in self._trainable_params]
 
         self._trainable_param_to_rank = {}
         for optim in self.sharded_optimizers:
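
The extended assert message above now points users at no_sync() when gradients are deliberately left unreduced (gradient accumulation). A hedged sketch of that pattern, reusing the model/optimizer names from the sketch near the top of this page (micro_batches is a placeholder list of input tensors, not something defined in this diff):

    micro_batches = [torch.rand((2, 2)) for _ in range(4)]

    with model.no_sync():  # grads accumulate locally, nothing is reduced or sharded yet
        for micro_batch in micro_batches[:-1]:
            model(micro_batch).sum().backward()

    model(micro_batches[-1]).sum().backward()  # the last backward reduces the accumulated grads
    optimizer.step()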
@@ -373,7 +368,8 @@ def no_sync(self) -> Generator:
     @torch.no_grad()
     def _clear_counters(self) -> None:
         """Reset all the grad reduce and call counters"""
-        self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
+        if self.training:
+            self._grad_to_be_reduced = [True for _ in self._trainable_params]
         self._bucket_flush_callback_set = False
 
         if self.use_buckets:
@@ -490,6 +486,9 @@ def _setup_backward_hooks(self) -> None:
         # Go through the parameters, attach the hook
         self._grad_accs = []
         self._manual_reduce = []
+        if not self.training:
+            return
+
         for index, param in enumerate(self._trainable_params):
             if param.grad is not None and param.grad.requires_grad:
                 raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
@@ -624,3 +623,21 @@ def _flush_reduce_calls(self) -> None:
                     bucket.sent = True
 
         self._consume_work_handles()
+
+    def _detect_train_change(self) -> bool:
+        # Optionally check whether the trainable parameters have changed
+        trainable_mask = list(map(_trainable, self._all_params))
+
+        # - one or more parameters trainability changed
+        trainability_changed = trainable_mask != self._reference_trainable_mask
+
+        # - the whole model is not trainable but we still have grad hooks
+        trainability_changed |= not self.training and len(self._grad_hooks) > 0
+
+        if trainability_changed:
+            logging.warning(
+                "ShardedDDP detected that the trainable params changed, either because of eval/train mode or parameter freezing/unfreeze."
+            )
+            self._reference_trainable_mask = trainable_mask
+
+        return trainability_changed
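
Besides the eval/train switch, _detect_train_change() also fires when parameters are frozen or unfrozen between steps, because the snapshot of the trainable mask changes. A short illustrative sketch, again reusing the module/model/batch names from the first sketch (which layer gets frozen is an arbitrary choice for illustration):

    # Freeze the first layer between two training steps
    for p in module[0].parameters():
        p.requires_grad = False

    # The next forward() sees a different trainable mask, logs the warning above,
    # and calls refresh_trainable() so the sharded assumptions are rebuilt
    model(batch).sum().backward()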
26 changes: 26 additions & 0 deletions tests/nn/data_parallel/test_sharded_ddp_features.py
@@ -262,6 +262,32 @@ def test_mixed_types():
     dist.destroy_process_group()
 
 
+def test_train_eval_change():
+    # Check that ShardedDDP handles the switch from training to eval properly
+    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+
+    model = _get_mlp()
+    model.train()
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+    model = ShardedDataParallel(model, optimizer)
+    input_tensor = torch.rand((2, 2))
+    loss = model(input_tensor).sum()
+    loss.backward()  # make sure that the gradients are reduced
+
+    # Wipe the gradients and switch to eval mode
+    model.zero_grad()
+    model.eval()
+    _ = model(input_tensor)
+    assert next(model.parameters()).grad is None or torch.norm(next(model.parameters()).grad) < 1e-6
+
+    # Get back to training
+    model = model.train()
+    model(input_tensor).sum().backward()
+    assert torch.norm(next(model.parameters()).grad) > 0.0
+
+    dist.destroy_process_group()
+
+
 def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
     # Check that the wrapped module can change devices
     dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
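Finally, the change detection above only runs when auto_refresh_trainable is enabled on the wrapper (see the self.auto_refresh_trainable check in forward()). A hedged sketch of the manual alternative, reusing module and optimizer from the first sketch and assuming auto_refresh_trainable is exposed as a constructor argument, as its use in forward() suggests:

    model = ShardedDataParallel(module, optimizer, auto_refresh_trainable=False)

    # After freezing/unfreezing parameters or toggling train()/eval(), tell
    # ShardedDDP explicitly that its trainability assumptions have changed
    model.refresh_trainable()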
