Extend auto shard capabilities to work around torch.fx edge cases. #817
Conversation
sharded_model = shard_model(model, 3)
# TODO(ehotaj): There might be a bug in our split code because we shard the
# model into 10 shards even though we specify 3 shards above.
assert len(sharded_model) == 10
Can you print out the original model and sharded model?
Here is the full model:
BranchedNetwork(
(net): ModuleList(
(0): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
(1): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
(2): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
(3): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
(4): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
(5): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
(6): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
(7): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
(8): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
(9): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
)
)
And the sharded model:
[GraphModule(
(net): Module(
(0): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
)
), GraphModule(
(net): Module(
(1): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
)
), GraphModule(
(net): Module(
(2): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
)
), GraphModule(
(net): Module(
(3): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
)
), GraphModule(
(net): Module(
(4): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
)
), GraphModule(
(net): Module(
(5): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
)
), GraphModule(
(net): Module(
(6): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
)
), GraphModule(
(net): Module(
(7): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
)
), GraphModule(
(net): Module(
(8): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
)
), GraphModule(
(net): Module(
(9): Branch(
(left): Linear(in_features=10, out_features=10, bias=True)
(right): Linear(in_features=10, out_features=10, bias=True)
)
)
)]
Looks reasonable to me, but I'm not sure why the shard_count is not being respected.
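For reference, here is a hypothetical reconstruction of a module definition consistent with the printout above (the actual test model may differ; the data-dependent conditional in Branch.forward is an assumption based on this PR's stated focus on dynamic conditionals):

from torch import nn

class Branch(nn.Module):
    # Two parallel Linear layers selected by data-dependent control flow.
    # torch.fx cannot trace through the conditional, which is why this PR's
    # preprocessing step would wrap Branch as an opaque unit.
    def __init__(self):
        super().__init__()
        self.left = nn.Linear(10, 10)
        self.right = nn.Linear(10, 10)

    def forward(self, x):
        if x.sum() > 0:
            return self.left(x)
        return self.right(x)

class BranchedNetwork(nn.Module):
    # Ten Branch modules applied in sequence, matching the printout.
    def __init__(self):
        super().__init__()
        self.net = nn.ModuleList([Branch() for _ in range(10)])

    def forward(self, x):
        for branch in self.net:
            x = branch(x)
        return x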
I'm not sure why the shard_count is not being respected. BTW, this also happens without the changes in this PR, i.e. if we just use torch.fx.symbolic_trace directly, so I think something might be going wrong in the _split_nodes logic. (I could try looking into it in a follow-up PR.) A possible starting point is sketched below.
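One way to probe this, reusing the hypothetical Branch/BranchedNetwork reconstruction above: mark Branch as a leaf while tracing (mirroring what the preprocessing in this PR does) and count the call_module nodes in the resulting graph. If _split_nodes places a shard boundary at every such node, that would explain the 10 shards regardless of the requested shard_count. LeafBranchTracer is an invented name for illustration.

import torch.fx as fx

class LeafBranchTracer(fx.Tracer):
    # Treat Branch as an opaque leaf so its conditional is never traced.
    def is_leaf_module(self, m, qualname):
        return isinstance(m, Branch) or super().is_leaf_module(m, qualname)

model = BranchedNetwork()
graph = LeafBranchTracer().trace(model)
calls = [n for n in graph.nodes if n.op == "call_module"]
print(len(calls))  # 10: one call_module node per Branch call site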
auto_shard.py currently uses torch.fx to create a symbolic DAG of
operations and linearizes that DAG into an nn.Sequential so it can later
be used for model offloading. This works in most cases but runs into
issues for certain eager-mode features, such as dynamic conditionals,
shape-dependent computation, etc.
This PR extends auto_shard.py to first run a preprocessing step that wraps
any nn.Module which cannot be traced through. It adds a test for dynamic
conditionals and updates existing failing test code.
There are some immediate extensions to this approach which are marked as
TODO in the code.
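The preprocessing step described above could be sketched as a custom fx.Tracer that probe-traces each submodule and treats anything untraceable as a leaf. This is a minimal sketch of the idea, not this PR's actual implementation; WrapUntraceableTracer and trace_with_wrapping are invented names.

import torch.fx as fx
from torch import nn

class WrapUntraceableTracer(fx.Tracer):
    # Treat any submodule that torch.fx cannot trace through (e.g. one with
    # data-dependent control flow) as a leaf, so the rest of the model can
    # still be traced and later linearized into shards.
    def is_leaf_module(self, m: nn.Module, qualname: str) -> bool:
        if super().is_leaf_module(m, qualname):
            return True
        try:
            # Probe-trace the submodule in isolation; if this fails, keep
            # the module opaque as a single call_module node.
            fx.symbolic_trace(m)
            return False
        except Exception:
            return True

def trace_with_wrapping(model: nn.Module) -> fx.GraphModule:
    tracer = WrapUntraceableTracer()
    return fx.GraphModule(model, tracer.trace(model))

Under this sketch, the BranchedNetwork above traces to ten opaque Branch nodes instead of failing on their conditionals.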