-
Notifications
You must be signed in to change notification settings - Fork 276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[fix][ShardedDDP] Properly handle .eval() mode #587
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -198,15 +198,10 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: | |
""" | ||
|
||
# Deferred initialization, or change detection | ||
needs_setup = len(self._grad_hooks) == 0 | ||
needs_setup = len(self._grad_hooks) == 0 and self.training | ||
|
||
if self.auto_refresh_trainable: | ||
# Optionally check whether the trainable parameters have changed | ||
trainable_mask = list(map(_trainable, self._all_params)) | ||
if trainable_mask != self._reference_trainable_mask: | ||
logging.warning("ShardedDDP detected that the trainable params changed, updating the partitioning") | ||
needs_setup = True | ||
self._reference_trainable_mask = trainable_mask | ||
needs_setup |= self._detect_train_change() | ||
|
||
if needs_setup: | ||
self.refresh_trainable() | ||
|
@@ -272,13 +267,13 @@ def refresh_trainable(self) -> None: | |
""" If the module trainability has changed, update all the assumptions """ | ||
|
||
# Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance) | ||
assert not functools.reduce( | ||
lambda x, y: x or y, self._grad_to_be_reduced, False | ||
), "Grads waiting to be reduced: {}".format(self._grad_to_be_reduced) | ||
assert not functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False), ( | ||
"Grads waiting to be reduced: {}".format(self._grad_to_be_reduced) | ||
+ "\nIf this is on purpose (grad accumulation), please use a no_sync() context" | ||
) | ||
|
||
self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params)) | ||
self._trainable_params.sort(key=lambda x: x.numel()) | ||
self._grad_to_be_reduced = [True for _ in self._trainable_params] | ||
|
||
self._trainable_param_to_rank = {} | ||
for optim in self.sharded_optimizers: | ||
|
@@ -373,7 +368,8 @@ def no_sync(self) -> Generator: | |
@torch.no_grad() | ||
def _clear_counters(self) -> None: | ||
"""Reset all the grad reduce and call counters""" | ||
self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced] | ||
if self.training: | ||
self._grad_to_be_reduced = [True for _ in self._trainable_params] | ||
self._bucket_flush_callback_set = False | ||
|
||
if self.use_buckets: | ||
|
@@ -490,6 +486,9 @@ def _setup_backward_hooks(self) -> None: | |
# Go through the parameters, attach the hook | ||
self._grad_accs = [] | ||
self._manual_reduce = [] | ||
if not self.training: | ||
return | ||
|
||
for index, param in enumerate(self._trainable_params): | ||
if param.grad is not None and param.grad.requires_grad: | ||
raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad") | ||
|
@@ -624,3 +623,19 @@ def _flush_reduce_calls(self) -> None: | |
bucket.sent = True | ||
|
||
self._consume_work_handles() | ||
|
||
def _detect_train_change(self) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! |
||
# Optionally check whether the trainable parameters have changed | ||
trainable_mask = list(map(_trainable, self._all_params)) | ||
|
||
# - one or more parameters trainability changed | ||
trainability_changed = trainable_mask != self._reference_trainable_mask | ||
|
||
# - the whole model is not trainable but we still have grad hooks | ||
trainability_changed |= not self.training and len(self._grad_hooks) > 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does this mean that grad_hooks should be greater than 0 in eval mode? Not sure I understand why this should be the case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it was meant to detect that the trainability changed, i.e. we're in eval() mode but there are grad_hooks in place so we should refresh ? it's tied to the question above, I'm not sure of the reference behavior here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From my offline conversation with @blefaudeux to understand this better:
Thanks for the explanation @blefaudeux ! |
||
|
||
if trainability_changed: | ||
logging.warning("ShardedDDP detected that the trainable params changed, updating the partitioning") | ||
blefaudeux marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._reference_trainable_mask = trainable_mask | ||
|
||
return trainability_changed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to remove existing hooks in eval mode? Just curious, otherwise we could move this to the top of the function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought that it was better for correctness, in that if there's a .backward() left somewhere it still respects the eval() setting ? The documentation is not super clear, to me at least https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=eval#torch.nn.Module.train