
Extend auto shard capabilities to work around torch.fx edge cases. #817

Merged: 1 commit into main from autoshard on Oct 22, 2021

Conversation

EugenHotaj (Contributor) commented on Oct 19, 2021

auto_shard.py currently uses torch.fx to create a symbolic DAG of
operations and linearizes that DAG into an nn.Sequential so it can later
be used for model offloading. This works in most cases but runs into
issues for certain eager mode features, such as dynamic conditionals,
shape-dependent computation, etc.

This PR extends auto_shard.py to first run a preprocessing step that wraps
any nn.Module that cannot be traced through. It adds a test for dynamic
conditionals and updates existing test code that was failing.

There are some immediate extensions to this approach which are marked as
TODO in the code.
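To make the failure mode concrete, here is a minimal sketch of the kind of module that defeats symbolic tracing and of the general wrapping idea, using a custom torch.fx.Tracer that treats such modules as leaves. DynamicBranch and WrappingTracer are illustrative names; this is not the PR's actual implementation.

import torch
import torch.nn as nn
import torch.fx

class DynamicBranch(nn.Module):
    # Illustrative module with a data-dependent conditional.
    def __init__(self):
        super().__init__()
        self.left = nn.Linear(10, 10)
        self.right = nn.Linear(10, 10)

    def forward(self, x):
        # torch.fx.symbolic_trace fails here: a traced Proxy has no concrete
        # boolean value, so the `if` raises a TraceError.
        if x.sum() > 0:
            return self.left(x)
        return self.right(x)

class WrappingTracer(torch.fx.Tracer):
    # Treat untraceable modules as opaque leaf calls instead of tracing into them.
    def is_leaf_module(self, m, module_qualified_name):
        if isinstance(m, DynamicBranch):
            return True
        return super().is_leaf_module(m, module_qualified_name)

model = nn.Sequential(DynamicBranch(), nn.ReLU())
graph = WrappingTracer().trace(model)
gm = torch.fx.GraphModule(model, graph)
# DynamicBranch appears as a single call_module node in gm.graph, so the DAG
# can still be linearized around it for offloading.
print(gm(torch.randn(2, 10)).shape)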

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 19, 2021
@EugenHotaj EugenHotaj requested a review from anj-s October 19, 2021 19:27
sharded_model = shard_model(model, 3)
# TODO(ehotaj): There might be a bug in our split code because we shard the
# model into 10 shards even though we specify 3 shards above.
assert len(sharded_model) == 10
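For context, a hedged reconstruction of the fixture the snippet above operates on, inferred from the model printout further down in this thread; the exact definitions in the test file may differ. The value-dependent conditional inside Branch is what exercises the new wrapping path.

import torch
import torch.nn as nn

class Branch(nn.Module):
    # Assumed shape of the test module: two parallel Linear layers with a
    # value-dependent choice between them, which torch.fx cannot trace through.
    def __init__(self):
        super().__init__()
        self.left = nn.Linear(10, 10)
        self.right = nn.Linear(10, 10)

    def forward(self, x):
        return self.left(x) if x.sum() > 0 else self.right(x)

class BranchedNetwork(nn.Module):
    # Ten sequential Branch blocks, matching the printed module structure.
    def __init__(self, num_branches=10):
        super().__init__()
        self.net = nn.ModuleList(Branch() for _ in range(num_branches))

    def forward(self, x):
        for branch in self.net:
            x = branch(x)
        return x

model = BranchedNetwork()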
Contributor:
Can you print out the original model and sharded model?

EugenHotaj (Contributor, Author) commented on Oct 21, 2021:

Here is the full model:

BranchedNetwork(
  (net): ModuleList(
    (0): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
    (1): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
    (2): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
    (3): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
    (4): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
    (5): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
    (6): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
    (7): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
    (8): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
    (9): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
  )
)

And the sharded model:

[GraphModule(
  (net): Module(
    (0): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
  )
), GraphModule(
  (net): Module(
    (1): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
  )
), GraphModule(
  (net): Module(
    (2): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
  )
), GraphModule(
  (net): Module(
    (3): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
  )
), GraphModule(
  (net): Module(
    (4): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
  )
), GraphModule(
  (net): Module(
    (5): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
  )
), GraphModule(
  (net): Module(
    (6): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
  )
), GraphModule(
  (net): Module(
    (7): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
  )
), GraphModule(
  (net): Module(
    (8): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
  )
), GraphModule(
  (net): Module(
    (9): Branch(
      (left): Linear(in_features=10, out_features=10, bias=True)
      (right): Linear(in_features=10, out_features=10, bias=True)
    )
  )
)]

Looks reasonable to me but I'm not sure why the shard_count is not being respected.

EugenHotaj (Contributor, Author):

> I'm not sure why the shard_count is not being respected.

BTW, this also happens without the changes in this PR, i.e. if we just use torch.fx.symbolic_trace directly, so I think something might be going wrong in the _split_nodes logic.

(I could try looking into it in a follow-up PR.)
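For anyone following up on this, a small hypothetical debugging snippet: shard_model returns a list of GraphModules, so each shard's traced graph can be printed directly to see where the split boundaries landed.

# Hypothetical debugging aid, assuming `sharded_model` from the test above.
for i, shard in enumerate(sharded_model):
    print(f"--- shard {i} ---")
    print(shard.graph)  # lists the call_module / call_function nodes in this shard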

@anj-s anj-s merged commit 7bdf50a into main Oct 22, 2021
@anj-s anj-s deleted the autoshard branch October 22, 2021 12:24
vtantia pushed a commit that referenced this pull request Oct 29, 2021