
[feature-fix-refactor][ShardedDDP] Make it possible to change trainability graph on the fly #369

Merged
blefaudeux merged 18 commits into master from shardedddp_nosync_fix on Feb 12, 2021

Conversation

blefaudeux
Contributor

@blefaudeux blefaudeux commented Feb 5, 2021

Before submitting

  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #368 and #354, required by VISSL and possibly others

  • make it possible to refresh the grad/reduce plan on the fly, for instance if parts of the model are frozen (see the usage sketch after this list)
  • intercept zero_grad(set_to_none=True), which would otherwise have broken the buckets
  • minor cleanup: the work handles are moved out of the optimizer wrapper (they used to be shared between broadcast and reduce, which is no longer the case), which in turn means that a couple of structures that used to point to each of the optimizers present are now streamlined into per-device ones
  • test all known combinations against DDP: AMP / gradient accumulation / changing trainability over time
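
To illustrate the intended workflow, here is a hedged sketch (not code from this diff; the toy model, the SGD settings, and the assumption that torch.distributed is already initialized are all illustrative):

```python
import torch
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS

# Assumes torch.distributed.init_process_group(...) has already been called and
# that this process owns one GPU. Model and hyper-parameters are toy values.
model = torch.nn.Sequential(torch.nn.Linear(2, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)).cuda()
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
ddp_model = ShardedDDP(model, optimizer)

def train_step(inputs: torch.Tensor, targets: torch.Tensor) -> None:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(ddp_model(inputs), targets)
    loss.backward()
    optimizer.step()

# Later in training: freeze part of the model. With this PR the wrapper picks up the
# new trainability graph (automatically by default, or via ddp_model.refresh_trainable()).
for p in model[0].parameters():
    p.requires_grad = False
```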

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 5, 2021
@blefaudeux blefaudeux requested a review from anj-s February 5, 2021 23:29
@@ -565,24 +564,6 @@ def _broadcast_params(self) -> None:
if last_work_handle:
last_work_handle.wait()

def _consume_work_handles(self) -> None:
Contributor Author

Since the tensor views change (the broadcast buckets use tensor views), this was no longer used in OSS, only in ShardedDDP; I figured it was cleaner to move it there.

@blefaudeux blefaudeux marked this pull request as draft February 5, 2021 23:58
@blefaudeux
Contributor Author

This was not handling multiple optimizers properly; fixing that.

@blefaudeux blefaudeux marked this pull request as ready for review February 6, 2021 00:12
@blefaudeux blefaudeux marked this pull request as draft February 6, 2021 00:15
@blefaudeux blefaudeux marked this pull request as ready for review February 6, 2021 01:21
@blefaudeux blefaudeux changed the title [feature-fix][ShardedDDP] Make it possible to change trainability over time [feature-fix-refactor][ShardedDDP] Make it possible to change trainability over time Feb 6, 2021
@@ -145,55 +171,29 @@ def check_parity(amp: bool):
module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
)

next(model.parameters()).requires_grad = False # Test non-trainable parameters
Contributor Author

The whole point of this PR: check that you can change the training graph after instantiation and still get correct results.

@blefaudeux
Contributor Author

blefaudeux commented Feb 6, 2021

FB only: tested with RegNet256 f249307456 f2578230350

@blefaudeux blefaudeux changed the title [feature-fix-refactor][ShardedDDP] Make it possible to change trainability over time [feature-fix-refactor][ShardedDDP] Make it possible to change trainability graph on the fly Feb 6, 2021
Contributor

@msbaines msbaines left a comment

Are the cost savings we get from making this assumption really worth the complexity we are introducing for users, who now have to think about calling refresh_trainable()?

What is the overhead of checking ourselves whether we need to refresh?

@blefaudeux
Contributor Author

Are the cost savings we get from making this assumption really worth the complexity we are introducing for users, who now have to think about calling refresh_trainable()?

What is the overhead of checking ourselves whether we need to refresh?

I just saw this - I actually just pinged you on another PR because of that! To step back a little: the assumption was always there, it's not new in this PR. I had never thought about it, so it was an oversight (it was also in OSS). The ShardedDDP code was not very flexible on that front, hence this rather big PR to make it easier to update.

Now about the cost: one issue is that the partitioning changes, because we only broadcast or optimize the trainable params (for instance, HF had a fine-tuning job with a huge embedding table, I think, which was frozen; that should not count for OSS partitioning since there is no corresponding optimizer state). If such a parameter becomes trainable, we need to repartition, which means changing all the flat buffers. Doing this check at every step would have a very sizeable speed impact, I think: for big models it means traversing the whole graph and checking that nothing changed compared to before. I can measure that, but it would be model-size dependent (and a trivial implementation would be sequential - each rank checks - which would not scale so well).

Would

  • an opt-in auto-check
  • this manual refresh for people who opt out
    be an acceptable solution?

What do @stas00, @mannatsingh, @vreis, @mrshenli or @SeanNaren think? (picking users)

@stas00
Contributor

stas00 commented Feb 7, 2021

Let me check if I understood the proposal correctly:

The user doesn't ever need to do anything special, unless they freeze/unfreeze some layers after the fairscale components were initialized.

If I got that right and the cost of getting optimized performance is running an additional refresh_trainable() right after freeze/unfreeze - I think this is totally reasonable.

The alternative solution is to somehow have a flag at the model level that, if "dirty", will automatically trigger a refresh. Such a flag would be cleared as soon as fairscale has initialized its machinery. Then fairscale or pytorch (since a lot of this will end up in pytorch) provides a method to freeze/unfreeze layers which, besides doing its normal work, also marks the flag "dirty". As long as the user uses this method to freeze/unfreeze layers, they don't have to do anything else; if they choose to do it their own way, they should invalidate the flag themselves, or it would be the same as calling refresh_trainable(). I'm just thinking long term - whether it'd be a useful feature to have this flag in nn.Module, perhaps for other features that need to detect whether the model has been "tampered with" since it was last looked at.
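
A purely hypothetical sketch of that "dirty flag" idea (none of these helper names or attributes exist in PyTorch or fairscale; they are assumptions for illustration only):

```python
import torch.nn as nn

def set_requires_grad(module: nn.Module, requires_grad: bool) -> None:
    """Hypothetical freeze/unfreeze helper that also marks the module as 'dirty'."""
    for p in module.parameters():
        p.requires_grad = requires_grad
    # Illustrative flag only: a wrapper such as ShardedDDP could check and clear it,
    # refreshing its partitioning/communication plan only when a change actually happened.
    module._trainability_dirty = True

def consume_dirty_flag(module: nn.Module) -> bool:
    """Return True (and clear the flag) if trainability changed since the last check."""
    dirty = getattr(module, "_trainability_dirty", False)
    module._trainability_dirty = False
    return dirty
```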

Would

* an opt-in auto-check
* this manual refresh for people who opt out
  be an acceptable solution?

I think the correct solution here is an opt-out from the auto-check followed by a manual refresh; that is, auto-check should be the default. Otherwise things will sometimes work, and other times sort of work and give bad results that could be missed. So it's better to be slow but correct by default.

The other, riskier approach is to not enable it by default but still run the auto-check anyway, say every 1000 steps, and assert if a change is detected - explaining what needs to be done to make things right. I'm just not sure how to pick an interval large enough not to be taxing, yet small enough to detect such changes.

@SeanNaren

@stas00 brings up really good points, and I have a similar understanding!

We'll exclude PL for now: from a pure PyTorch perspective it would definitely not be recommended to introduce an additional call like this, even for the speed benefits, as users would prefer not to diverge from typical freezing/unfreezing specifically for ShardedDDP.

From a PL perspective (and I think HF transformers too, since you guys define freeze functions, right?) I think it's cool, since we have control over the freeze logic to some extent.

On average I think refresh_trainable or refresh_partitions is the way to go, since this allows us to avoid any large traversals and makes it a less frequent op that runs only when necessary. The performance loss from doing any traversal on really large models isn't worth it (we've stayed away from find_unused_parameters for now!).

@stas00
Contributor

stas00 commented Feb 7, 2021

From a PL perspective (and I think HF transformers since you guys define freeze functions right?)

HF Trainer currently doesn't have freeze functions. Some example scripts like finetune_trainer.py do provide this feature. But these can be added if a need arises.

On average I think refresh_trainable or refresh_partitions is the way to go, since this allows us to avoid any large traversals and makes it a less frequent op that runs only when necessary. The performance loss from doing any traversal on really large models isn't worth it (we've stayed away from find_unused_parameters for now!).

@SeanNaren, I think @blefaudeux is asking how the default should be handled - I'm with you on what you said above, but I don't think this behavior should be the default, since it could lead to potentially undetectable problems. Especially since users may not choose to use the framework trainer's freeze functions, for example if they are porting from a different framework and already have an existing way that works.

On the other hand, since the framework's trainer initializes ShardedDDP, it can be left up to the specific trainer to choose the default behavior for that framework. So perhaps there should be no default at the fairscale level, but the policy flag would be required - which would force the integrator to make a choice and stand behind it. Does that make sense?

detect_model_changes_policy=(always, never, x_steps) (every 10000 steps)

Or, to mimic pytorch's find_unused_parameters more closely, perhaps find_param_unfreeze so it intuitively appears similar?

HF Trainer has find_unused_parameters defined as:

  • False if gradient checkpointing is used, True otherwise.
  • can be overridden by a user via a --ddp_find_unused_parameters flag

So the intention is correctness out of the box over speed.

Do note that pytorch logger.warns the user when find_unused_parameters=True and it detects that the operation was wasteful because there were no unused parameters to be found. That's why I propose identical behavior here: warn the user if they use the default behavior and fairscale detects that it was not needed.

Should pytorch introduce freeze/unfreeze functions with some hooks, and encourage users to use those over manual modification? In that case fairscale could tap into the function's hooks and know whether the model was modified or still intact.

@blefaudeux
Contributor Author

I think the correct solution here is an opt-out from auto-check followed by manual refresh, that is auto-check should be the default.

Agreed, after thinking about it a bit more: the default should be correct no matter what. Makes sense to me.

@blefaudeux
Contributor Author

Thanks @stas00 and @SeanNaren for the comments, very much on point and appreciated. To make sure this is clear: there is a link with "find unused parameters" (in that we're talking about the graph that we're training and its correctness), but it's not the exact same problem we're trying to solve; a closer parallel would be eager evaluation versus static mode. I agree with Stas's conclusion here: getting this wrong would be pretty subtle for a user (no crash or loss going NaN, it would "just" not optimize what's planned - I've recently seen an unrelated bug of the sort in a framework and it went unseen for a long time), so better to err on the side of caution (it was also Mandeep's opinion, I think). I'll implement that for both OSS and ShardedDDP, since they have the same issue on that front.
I also agree that, as frameworks, HF and PL can override that or try to find a better way, same for power users, so being able to turn it off and manually refresh would be nice to have - it's basically an extra optional API call.

@min-xu-ai
Contributor

Sorry, late to the party. Just wondering: why didn't pytorch DDP need this? Is it because they use "find_unused_parameters"?

In VISSL's case though, I don't think it uses find_unused_parameters, but I could be wrong.

@stas00
Contributor

stas00 commented Feb 8, 2021

Whatever solution you guys commit to, once it's documented would you kindly ping us so that we can implement the change and document for users how they can further optimize this behavior? This can probably only be done once a new release is made, so that we can set the deps correctly.

I highly recommend giving users a warning, as pytorch does, if a wasteful operation is done and proves to be unneeded - this would be the easiest way for users to discover that they could do better. The only tricky part would be how to tell the user to override the default when a framework trainer hides the implementation. So if you decide to add this extra signal, I'd say the sharded init function should probably have an extra optional arg that, if supplied, will be used to instruct the user on how to override the default, e.g.:

override_autocheck_note="pass --sharded_ddp_no_auto_refresh=True to disable the auto-check"

and fairscale will provide the first part of the warning, so that part would remain the same.

I hope I'm not over-complicating things.

@blefaudeux blefaudeux marked this pull request as ready for review February 10, 2021 00:57
@blefaudeux
Contributor Author

blefaudeux commented Feb 10, 2021

Update on the current status:

  • automatically check by default whether trainability has changed; if it has, update ShardedDDP and OSS. If the auto-check is turned off, the update can be triggered manually instead. The check is actually not that slow - imperceptible on a smallish model (40M params). A sketch of the check follows this list.
  • added a unit test to cover that, with and without AMP; seems all good
  • not repartitioning on the OSS side, on purpose, because I think it would overcomplicate things (you would have to move the state around to the new recipients, rewrite all the optimizer states, etc.). It's probably possible to find a corner case for this (i.e. if you make a huge parameter trainable and the rank which gets it becomes the bottleneck), but for now I don't think it's worth the extra complexity
  • there are quite a few things to change even without repartitioning, because the communications have to change: we typically sync the trainable params around, so if they change the comms have to adapt. Since the comms can be batched in buckets, the buckets have to be rewired, and since the buckets hold some of the "state" (for OSS, the model parameters are spread across the buckets, so you cannot just wipe them out), rewiring them needs to be done carefully so as not to lose any old state. All in all, this is a fairly big PR, sorry about that
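
A hedged sketch of what such a per-step check could look like (illustrative names only, not the exact fairscale internals; the actual diff excerpts appear further down):

```python
from typing import List, Tuple

import torch

def trainability_changed(
    params: List[torch.nn.Parameter], reference_mask: List[bool]
) -> Tuple[bool, List[bool]]:
    """Compare the current requires_grad mask against a cached reference mask.

    This is one boolean per parameter, cheap enough to run at every forward pass;
    the expensive refresh (bucket rewiring, comm plan update) is only triggered
    when the mask actually differs.
    """
    current_mask = [p.requires_grad for p in params]
    return current_mask != reference_mask, current_mask
```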

@blefaudeux
Contributor Author

(the broken test seems to be related to an ssh misconfig or just a broken network for a sec, unrelated to this PR https://app.circleci.com/pipelines/github/facebookresearch/fairscale/1615/workflows/ce4480bb-289c-4a22-9632-e4cf416aaf5d/jobs/7784)

@blefaudeux
Contributor Author

blefaudeux commented Feb 10, 2021

Alright, now I need to handle differences between pytorch versions...
Edit: it's actually tied to CUDA versions; the same test works perfectly with CUDA 11 :(

@blefaudeux blefaudeux marked this pull request as draft February 10, 2021 23:42
@blefaudeux blefaudeux marked this pull request as ready for review February 11, 2021 01:00
@anj-s anj-s removed their request for review February 11, 2021 18:42
import io
from typing import Any, Callable, Dict, Optional

import torch
from torch._six import container_abcs
Contributor Author

This breaks on fbcode; it seems not to be compatible with all torch versions, and I presume it's not needed given that it's available in collections.

def _setup_flat_buffers(self) -> None:
"""Make all params which are on the same device and tied to the same rank views of a single buffer.
This is used at construction time, and anytime parameter trainability is changed (frozen or unfrozen) and
`refresh_trainability` is called.
"""

for device, per_rank_params in self.per_device_params.items():
self.buckets[device] = []
# Only wipe the existing buckets if there are none
Contributor Author

This part is new / significantly changed: the idea is that the buckets get redeployed when trainability changes (since we only broadcast trainable params), and when redeploying them you need to take care not to lose the previous state.

# Tensor cannot be really empty, even if its size is meaningless
dummy_sync_tensor = torch.tensor([1], device=self._device)
dummy_sync_tensor = torch.tensor([1], device=self._default_device)
Contributor Author

_device -> _default_device was just a rename, no logic change; the guess was that _device did not make it clear what it referred to.

INPUTS = 2
BATCH_SIZE = 32

def check_parity(amp: bool, accumulate: bool, change_train_graph: bool):
Contributor Author

Add a systematic check for parity with DDP, with AMP / gradient accumulation / train-graph changes each flipped on and off.

@@ -406,3 +406,11 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
return False
else:
return a == b


def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module, message: str = "") -> None:
Contributor Author

Minor refactor: this was used in several places and copy-pasted.
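
Only the signature appears in this excerpt; as a hedged sketch, a parity helper of this kind typically looks like the following (the tolerances and the buffer check are assumptions, not taken from the diff):

```python
import torch

def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module, message: str = "") -> None:
    """Assert that two models hold numerically matching parameters and buffers."""
    for pa, pb in zip(model_a.parameters(), model_b.parameters()):
        assert torch.allclose(pa, pb, atol=1e-5), f"Model parameters differ. {message}"
    for ba, bb in zip(model_a.buffers(), model_b.buffers()):
        assert torch.allclose(ba, bb, atol=1e-5), f"Model buffers differ. {message}"
```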

check_same_model_params()

# Check that altering the trainable parameters does not cause DDP and OSS to diverge
if change_train_graph:
Contributor Author

New logic in this test: check that changing trainability is properly taken into account. Commenting out this PR's additions in OSS/ShardedDDP breaks it, as expected.

@@ -71,6 +78,14 @@ class ShardedDataParallel(nn.Module):
handled. In that case ShardedDDP will raise an exception and suggest to either remove the unused parameters from your model
(https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=unused_parameters is helpful)
or set `reduce_buffer_size` to 0

.. warning:
Contributor Author

tentatively explain the two options:

  • auto detect (default)
  • manual + explicit refresh() call

@@ -117,14 +134,19 @@ def __init__(
# several optimizers can be present each working on seperate parameter set which is spread across multiple ranks

# - we build an iterator which goes through all the parameters involved globally
all_param_iterator = chain(
*[sum([sum(p, []) for p in optim.per_device_params.values()], []) for optim in self.sharded_optimizers]
self._all_params = list(
Contributor Author

Cache this list because it is reused at every step if auto_refresh_trainable is set.


# - keep track of the grads which have already been reduced
self._reduced_grads: Dict[OSS, int] = {}
self._reduced_grads_max = {o: len(o.param_to_rank.values()) for o in self.sharded_optimizers}
self._reduced_grads = 0
Contributor Author

Prior to this PR, the trainability graph was kind of baked into these structures. One way to make the comm changes more manageable, I think, is to have all of these be completely flat and handle the partitioning in a single place.

# Optionally check whether the trainable parameters have changed
if self.auto_refresh_trainable:
trainable_mask = list(map(_trainable, self._all_params))
if trainable_mask != self._reference_trainable_mask:
Contributor Author

Just compare the binary trainability masks; any change (a single parameter frozen or unfrozen) will trigger an update.

self._trainable_param_to_rank = {}
for optim in self.sharded_optimizers:
# OSS may need to change the communication pattern
optim.refresh_trainable()
Contributor Author

the OSS broadcast pattern needs to be updated on the fly

for param in filter(lambda x: x.requires_grad, device_params):
self._trainable_param_to_rank[param] = optim.param_to_rank[param]

self._setup_bucket_strategy()
Contributor Author

Same for the buckets (we only reduce the grads) and the hooks (new hooks could be required).

Contributor

@min-xu-ai min-xu-ai left a comment

This is very nice. If you haven't already, I'd suggest testing it on more than 2 GPUs, since there might be corner cases there.

this will impact the long term memory consumption, because these buckets correspond to parameters which will not be sharded.
Set to 0 to remove all bucketing.
auto_refresh_trainable (bool):
Check whether the parameters trainability (`requires_grad`) has changed
Contributor

add default value?

automatically.
If `auto_refresh_trainable` is set to `False`, ShardedDDP will not refresh its assumptions with respect to trainable parameters
for every forward pass, in the hope of saving some time. If some parameters are frozen or unfrozen over time, please refresh
ShardedDDP assumptions by calling `refresh_trainable()` just after said change (before the next forward pass).
Contributor

very nice doc
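
The documented opt-out flow, written out as a hedged usage sketch (same toy setup as the earlier sketch; torch.distributed is assumed to be initialized):

```python
import torch
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS

model = torch.nn.Sequential(torch.nn.Linear(2, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)).cuda()
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

# Opt out of the automatic per-step check to save the per-step mask comparison.
ddp_model = ShardedDDP(model, optimizer, auto_refresh_trainable=False)

# Any later trainability change must then be followed by an explicit refresh,
# before the next forward pass.
for p in model[0].parameters():
    p.requires_grad = False
ddp_model.refresh_trainable()
```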

def refresh_trainable(self) -> None:
""" If the module trainability has changed, update all the assumptions """

self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
Contributor

Anything to assert (assumptions) upon entering this function?

Contributor Author

Good point, I'll try to think of something; all this is tricky (to me at least), so I could indeed do with more asserts.

Contributor

definitely not "just to you"! 🤣

Contributor Author

Pulling my hair out now on issues with state dict loading & custom optimizers (turns out I should not write into their state); distributed training can certainly get a tad complex...

Contributor

feel you!

Contributor Author

I just added an assert; it was a very good call, I think. For instance, somebody could have the 'no_sync' context activated, then refresh trainability while forgetting to send the gradients (which would be lost).
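
Roughly the kind of guard being described, as an illustrative sketch (not the exact assert added in this PR):

```python
def check_can_refresh(grads_pending_reduction: bool) -> None:
    """Illustrative guard: refuse to rebuild the buckets while gradients are still
    waiting to be reduced (e.g. accumulated under a no_sync() window), since the
    rewiring would silently drop them."""
    assert not grads_pending_reduction, (
        "Cannot refresh trainability while gradients are pending reduction; "
        "exit the no_sync() context and reduce them first."
    )
```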

@@ -133,67 +133,66 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
torch.cuda.set_device(rank)
Contributor Author

@min-xu-ai checking with you: the ddp_parity test runs with cuda_count(), so it should be 4 GPUs on CI. Is that ok? (sanity-checking that I'm not missing anything)

Contributor

I don't know CI's GPU limit. If it is >2, then great. Also, in other cases I have seen bugs that only show up with more than 5 GPUs (not for OSS or ShardedDDP).

Contributor Author

Ouch for 5+, I'll test that on the FB cluster again. It seems CI/unit tests run on 4 GPUs (https://app.circleci.com/pipelines/github/facebookresearch/fairscale/1644/workflows/9e1b0fc9-92dd-4d4e-96de-be775cf5634b/jobs/7982)

Contributor Author

@blefaudeux blefaudeux Feb 11, 2021

checking with f250620664 and f250623267 [FB only]

@blefaudeux blefaudeux merged commit 13445c5 into master Feb 12, 2021
@blefaudeux blefaudeux deleted the shardedddp_nosync_fix branch February 12, 2021 00:44
Successfully merging this pull request may close these issues.

[ShardedDDP][feature] Freeze/unfreeze parts of the graph on the fly