Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support the List/Tuple Construct/Unpack operation for TorchModuleGraph #2609

Merged
merged 5 commits into from
Jul 24, 2020

Conversation

zheng-ningxin
Copy link
Contributor

Support the list/tuple construct/unpack operation for the TorchModuleGraph.
Fix the bug mentioned in #2581.

In the original version, we take the list/tuple construct/unpack operation nodes as unimportant nodes and merge them with the adjacent important nodes. However, merging the unpack nodes will lead to a graph construct error.

Ningxin added 4 commits June 28, 2020 09:03
Signed-off-by: Ningxin <Ningxin.Zheng@microsoft.com>
Signed-off-by: Ningxin <Ningxin.Zheng@microsoft.com>
@zheng-ningxin zheng-ningxin changed the title Support the List/Tuple Construct/Unpack operation for TorchModuleGraph(Bugfix for issue 2581) Support the List/Tuple Construct/Unpack operation for TorchModuleGraph Jun 29, 2020
@zheng-ningxin zheng-ningxin marked this pull request as draft June 29, 2020 08:03
Signed-off-by: Ningxin <Ningxin.Zheng@microsoft.com>
@zheng-ningxin zheng-ningxin reopened this Jul 1, 2020
@zheng-ningxin zheng-ningxin marked this pull request as ready for review July 1, 2020 08:27
@ultmaster ultmaster requested a review from chicm-ms July 17, 2020 09:31
@scarlett2018 scarlett2018 mentioned this pull request Jul 22, 2020
66 tasks
@QuanluZhang QuanluZhang self-requested a review July 22, 2020 10:53
@@ -199,6 +203,8 @@ def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None
All the inputs of this node, each element is debugName of one input
outputs: list of str
All the outputs of this node, each element is debugName of one output
key_node: torch._C.Node
The key node of this NodePyGroup.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the meaning of key node?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Key nodes are the nodes that should not be merged into other nodes. In the past, we only take the aten:: nodes as the important(key) nodes. In this pr, we also take the list/tuple unpack nodes as the key nodes.

# the nodes that start with 'aten' are key function
# nodes
return True
if node_cpp.kind() in [LIST_UNPACK_KIND, TUPLE_UNPACK_KIND]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not construct type here?

if node_cpp.kind() in [LIST_UNPACK_KIND, TUPLE_UNPACK_KIND]:
# We cannot merge the List/Tuple
# Construct/Unpack func into other nodes, else it
# may lead to a graph construction error.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this error like?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Take the shufflenet as an example:
#2581

for node in self.nodes_py.nodes_op:
if node.op_type in [TUPLE_UNPACK_KIND, LIST_UNPACK_KIND]:
unpack_cpp = node.key_node
last_cpp = list(unpack_cpp.inputs())[0].node()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index is 0, why call it last?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The last node actually refers to the previous(last) visited node. In most scenarios, this last_cpp is the corresponding construct node of the tuple/list.

# or list manunally.
_logger.debug('List/Tuple Construct Node(cpp) %s', str(last_cpp))
_logger.debug('List/Tuple Unpack Node(cpp) %s', str(unpack_cpp))
assert len(list(unpack_cpp.outputs())) == len(list(last_cpp.inputs()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a little confused about this assert, i think the main reason is i don't understand what is last_cpp

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants