[fix][ShardedDDP] Properly handle .eval() mode #587

Merged · 3 commits · Apr 7, 2021
Changes from 2 commits
39 changes: 27 additions & 12 deletions fairscale/nn/data_parallel/sharded_ddp.py
@@ -198,15 +198,10 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
"""

# Deferred initialization, or change detection
needs_setup = len(self._grad_hooks) == 0
needs_setup = len(self._grad_hooks) == 0 and self.training

if self.auto_refresh_trainable:
# Optionally check whether the trainable parameters have changed
trainable_mask = list(map(_trainable, self._all_params))
if trainable_mask != self._reference_trainable_mask:
logging.warning("ShardedDDP detected that the trainable params changed, updating the partitioning")
needs_setup = True
self._reference_trainable_mask = trainable_mask
needs_setup |= self._detect_train_change()

if needs_setup:
self.refresh_trainable()
@@ -272,13 +267,13 @@ def refresh_trainable(self) -> None:
""" If the module trainability has changed, update all the assumptions """

# Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
assert not functools.reduce(
lambda x, y: x or y, self._grad_to_be_reduced, False
), "Grads waiting to be reduced: {}".format(self._grad_to_be_reduced)
assert not functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False), (
"Grads waiting to be reduced: {}".format(self._grad_to_be_reduced)
+ "\nIf this is on purpose (grad accumulation), please use a no_sync() context"
)

self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
self._trainable_params.sort(key=lambda x: x.numel())
self._grad_to_be_reduced = [True for _ in self._trainable_params]

self._trainable_param_to_rank = {}
for optim in self.sharded_optimizers:
@@ -373,7 +368,8 @@ def no_sync(self) -> Generator:
@torch.no_grad()
def _clear_counters(self) -> None:
"""Reset all the grad reduce and call counters"""
self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
if self.training:
self._grad_to_be_reduced = [True for _ in self._trainable_params]
self._bucket_flush_callback_set = False

if self.use_buckets:
@@ -490,6 +486,9 @@ def _setup_backward_hooks(self) -> None:
# Go through the parameters, attach the hook
self._grad_accs = []
self._manual_reduce = []
if not self.training:
Reviewer (Contributor): do we need to remove the existing hooks in eval mode? Just curious; otherwise we could move this to the top of the function.

Author (blefaudeux): I thought that it was better for correctness: if there is a .backward() left somewhere, it still respects the eval() setting. The documentation is not super clear, to me at least: https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=eval#torch.nn.Module.train

return

for index, param in enumerate(self._trainable_params):
if param.grad is not None and param.grad.requires_grad:
raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
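
For context on the eval() discussion above, here is a minimal, hypothetical sketch of the behavior the early return is meant to guarantee: with the wrapper in eval() mode, forward() attaches no reduce hooks, so a stray backward() leaves the gradients local instead of reducing them across ranks. This snippet is not part of the diff; the import paths, single-process gloo setup and learning rate mirror the new unit test further down, and the two-layer Sequential is a stand-in for the test's _get_mlp() helper.

# Illustrative sketch only, not part of the PR; assumes fairscale's public
# ShardedDataParallel / OSS entry points as used in its test suite.
import tempfile

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Single-process "distributed" setup, same trick as the unit test below.
dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

base = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 2))  # stand-in for _get_mlp()
optimizer = OSS(params=base.parameters(), optim=torch.optim.SGD, lr=1e-3)
model = ShardedDataParallel(base, optimizer)

model.eval()                     # self.training is now False, recursively
out = model(torch.rand((2, 2)))  # forward() skips _setup_backward_hooks()
out.sum().backward()             # gradients stay local, no reduce hook fires

dist.destroy_process_group()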
@@ -624,3 +623,19 @@ def _flush_reduce_calls(self) -> None:
bucket.sent = True

self._consume_work_handles()

def _detect_train_change(self) -> bool:
Reviewer (Contributor): Nice!

# Optionally check whether the trainable parameters have changed
trainable_mask = list(map(_trainable, self._all_params))

# - one or more parameters trainability changed
trainability_changed = trainable_mask != self._reference_trainable_mask

# - the whole model is not trainable but we still have grad hooks
trainability_changed |= not self.training and len(self._grad_hooks) > 0
Reviewer (Contributor): does this mean that grad_hooks should be greater than 0 in eval mode? Not sure I understand why this should be the case.

Author (blefaudeux): it was meant to detect that the trainability changed, i.e. we're in eval() mode but there are grad hooks in place, so we should refresh. It's tied to the question above; I'm not sure of the reference behavior here.

Reviewer (Contributor): From my offline conversation with @blefaudeux to understand this better:

  • We can't detect when a module switches from train to eval unless we use the presence of hooks as an indicator.
  • We refresh the trainable partitioning 1) at the beginning, 2) when params change their requires_grad property, and 3) on a train<->eval switch.

Thanks for the explanation @blefaudeux!


if trainability_changed:
logging.warning("ShardedDDP detected that the trainable params changed, updating the partitioning")
self._reference_trainable_mask = trainable_mask

return trainability_changed
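
To make the three refresh triggers summarized in the thread above concrete, here is another hypothetical, self-contained sketch (again not part of the diff, same assumptions as the earlier snippet): the first forward performs the deferred setup, flipping a parameter's requires_grad causes a re-partition on the next forward, and a train()/eval() switch drops or re-installs the reduce hooks on the next forward.

# Illustrative sketch only, not part of the PR; same assumptions as above.
import tempfile

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

base = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 2))
optimizer = OSS(params=base.parameters(), optim=torch.optim.SGD, lr=1e-3)
model = ShardedDataParallel(base, optimizer)
x = torch.rand((2, 2))

# 1) Deferred initialization: the first forward attaches the reduce hooks.
model(x).sum().backward()

# 2) Trainability change: freezing a parameter changes the requires_grad mask,
#    so the next forward logs a warning and re-partitions.
base[0].weight.requires_grad_(False)
model(x).sum().backward()

# 3) train <-> eval switch: eval() drops the hooks at the next forward,
#    train() re-installs them.
model.eval()
_ = model(x)
model.train()
model(x).sum().backward()

dist.destroy_process_group()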
23 changes: 23 additions & 0 deletions tests/nn/data_parallel/test_sharded_ddp_features.py
@@ -262,6 +262,29 @@ def test_mixed_types():
dist.destroy_process_group()


def test_train_eval_change():
# Check that ShardedDDP handles the switch from training to eval properly
dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

model = _get_mlp()
model.train()
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
model = ShardedDataParallel(model, optimizer)
input_tensor = torch.rand((2, 2))
loss = model(input_tensor).sum()
loss.backward() # make sure that the gradients are reduced

model = model.eval()
_ = model(input_tensor)
_ = model(input_tensor)

model = model.train()
_ = model(input_tensor)
_ = model(input_tensor)

dist.destroy_process_group()


def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
# Check that the wrapped module can change devices
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)