[SDP][Minor] Removing an assert which does not always seem accurate #625

Merged 1 commit on Apr 22, 2021
4 changes: 1 addition & 3 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1664,9 +1664,7 @@ def wrap_bn_only_policy(module: nn.Module, recurse: bool, unwrapped_params: int)
        if recurse:
            return not isinstance(module, tuple(default_auto_wrap_policy.FORCE_LEAF_MODULES))  # type: ignore
        else:
-            return is_bn and not isinstance(
-                module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES)
-            )  # type: ignore
+            return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES))  # type: ignore
Contributor Author

Unrelated, but linting on master was broken because of this.


    pg = None
    if single_rank_pg:
8 changes: 4 additions & 4 deletions fairscale/nn/data_parallel/sharded_ddp.py
@@ -269,10 +269,10 @@ def refresh_trainable(self) -> None:
        """ If the module trainability has changed, update all the assumptions """

        # Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
-        assert not functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False), (
-            "Grads waiting to be reduced: {}".format(self._grad_to_be_reduced)
-            + "\nIf this is on purpose (grad accumulation), please use a no_sync() context"
-        )
+        if functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False):
+            logging.warning(
+                "Grads waiting to be reduced. If this is on purpose (grad accumulation), please use a no_sync() context"
+            )

        self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
        self._trainable_params.sort(key=lambda x: x.numel())
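The situation the removed assert used to trap is gradient accumulation through ShardedDDP's no_sync() context, which the new warning still points users to. A minimal single-rank sketch of that pattern, not taken from this PR (the Linear model, tensors, and learning rate are placeholders):

import tempfile

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS

# Single-rank gloo process group, only to keep the sketch self-contained.
dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(8, 8)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)
ddp = ShardedDDP(model, optimizer)
data = torch.ones(4, 8)

# Backward passes inside no_sync() accumulate gradients locally and skip the cross-rank reduce.
with ddp.no_sync():
    ddp(data).sum().backward()

# The next backward outside the context reduces the accumulated gradients, then the optimizer steps.
ddp(data).sum().backward()
optimizer.step()
optimizer.zero_grad()

dist.destroy_process_group()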
12 changes: 10 additions & 2 deletions tests/nn/data_parallel/test_sharded_ddp_features.py
@@ -262,9 +262,9 @@ def test_mixed_types():
    dist.destroy_process_group()


-def test_train_eval_change():
+def run_test_train_eval_change(rank, world_size, file):
    # Check that ShardedDDP handles the switch from training to eval properly
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+    dist.init_process_group(init_method="file://" + file, backend="gloo", rank=rank, world_size=world_size)

    model = _get_mlp()
    model.train()
@@ -288,6 +288,14 @@ def test_train_eval_change():
    dist.destroy_process_group()


+def test_train_eval_change():
+    world_size = 4
Contributor Author

Making the test case truly distributed, but I still cannot reproduce the issue that was reported in a somewhat similar situation. My guess is that there are some changes on the Lightning side which make the switch a little more than just .train()/.eval(), but crashing on users is probably not a good idea in that case. The worst case is losing one step() worth of gradients.

+    temp_file_name = tempfile.mkstemp()[1]
+    mp.spawn(
+        run_test_train_eval_change, args=(world_size, temp_file_name), nprocs=world_size, join=True,
+    )


def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    # Check that the wrapped module can change devices
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
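For context, the train/eval round trip which the spawned test exercises looks roughly like the sketch below; the tiny Linear model and constant inputs are placeholders standing in for the test's _get_mlp() helper and its real tensors, not the exact test body.

import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS


def train_eval_round_trip(rank, world_size, file):
    # One process per rank, gloo backend, file store rendezvous (as in the test above).
    dist.init_process_group(init_method="file://" + file, backend="gloo", rank=rank, world_size=world_size)

    model = torch.nn.Linear(4, 4)  # placeholder for _get_mlp()
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)
    ddp = ShardedDDP(model, optimizer)
    inputs = torch.ones(2, 4)

    # Regular training step.
    ddp.train()
    ddp(inputs).sum().backward()
    optimizer.step()
    optimizer.zero_grad()

    # Switch to eval for inference only.
    ddp.eval()
    with torch.no_grad():
        _ = ddp(inputs)

    # Switching back should at worst log a warning, never assert.
    ddp.train()
    ddp(inputs).sum().backward()
    optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(train_eval_round_trip, args=(world_size, tempfile.mkstemp()[1]), nprocs=world_size, join=True)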