Fill value for non-tensor inputs during shape prop (#933)
Fixes #932
kwen2501 committed Jan 26, 2024
1 parent 099f140 commit 3632106
Showing 1 changed file: pippy/IR.py (32 additions, 29 deletions)
@@ -1004,34 +1004,6 @@ def move_param_to_callee(

        split.delete_all_unused_submodules()

-        # Users want the first pipeline stage to accept kwargs if the original
-        # program does. This is controlled by the `_codegen` field of the graph,
-        # so we make a copy here. Note: we only want the input spec and not the
-        # output spec, because the output spec is for the last stage. Maybe a
-        # TODO? Not sure yet.
-        submod0 = list(split.children())[0]
-        model_sign = signature(traced.forward)
-        model_num_args = len(model_sign.parameters)
-        submod0_sign = signature(submod0.forward)
-        submod0_num_args = len(submod0_sign.parameters)
-        if model_num_args != submod0_num_args:
-            # We don't change the signature of the first stage if it takes a
-            # different number of args than the original model
-            logger.info(
-                f"Original model takes {model_num_args} args but the first pipeline stage takes {submod0_num_args}. "
-                "Please provide args to respective pipeline stages."
-            )
-        else:
-            # Support kwargs for the first stage
-            submod0.graph._codegen = copy.deepcopy(traced.graph._codegen)
-            # `_replace` is not actually "private" or internal. Per the
-            # namedtuple docs: "To prevent conflicts with field names, the
-            # method and attribute names start with an underscore."
-            submod0.graph._codegen.pytree_info = (
-                submod0.graph._codegen.pytree_info._replace(out_spec=None)
-            )
-            submod0.recompile()

        split.graph.lint()
        split.recompile()

@@ -1176,6 +1148,33 @@ def from_tracing(
                f"{node.meta['example_value'] if 'example_value' in node.meta else 'None'}",
            )

+        # Users want the first pipeline stage to accept kwargs if the original
+        # program does. This is controlled by the `_codegen` field of the graph,
+        # so we make a copy here. Note: we only want the input spec and not the
+        # output spec, because the output spec is for the last stage. Maybe a
+        # TODO? Not sure yet.
+        split = pipe.split_gm
+        submod0 = list(split.children())[0]
+        submod0_sign = signature(submod0.forward)
+        model_sign = signature(traced.forward)
+        if len(model_sign.parameters) != len(submod0_sign.parameters):
+            # We don't change the signature of the first stage if it takes a
+            # different number of args than the original model
+            logger.info(
+                f"Original model takes {len(model_sign.parameters)} args but the first pipeline stage takes {len(submod0_sign.parameters)}. "
+                "Please provide args to respective pipeline stages."
+            )
+        else:
+            # Support kwargs for the first stage
+            submod0.graph._codegen = copy.deepcopy(traced.graph._codegen)
+            # `_replace` is not actually "private" or internal. Per the
+            # namedtuple docs: "To prevent conflicts with field names, the
+            # method and attribute names start with an underscore."
+            submod0.graph._codegen.pytree_info = (
+                submod0.graph._codegen.pytree_info._replace(out_spec=None)
+            )
+            submod0.recompile()

        return pipe

    def __str__(self):
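For context, the `_replace` call above is the standard-library `namedtuple` method; despite the leading underscore, it is documented as public. A minimal sketch of the pattern, with a hypothetical `SpecInfo` namedtuple standing in for `graph._codegen.pytree_info` (it is not the actual torch.fx type):

from collections import namedtuple

# Hypothetical stand-in for the pytree info record; only the _replace
# pattern itself is meant to mirror the change above.
SpecInfo = namedtuple("SpecInfo", ["orig_args", "in_spec", "out_spec"])

info = SpecInfo(orig_args=["x"], in_spec="<input spec>", out_spec="<output spec>")

# Keep the input spec (so the first stage can accept the original
# kwargs) but drop the output spec, which only applies to the last stage.
info = info._replace(out_spec=None)
assert info.out_spec is None and info.in_spec == "<input spec>"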
@@ -1284,8 +1283,12 @@ def __init__(
        self.stop_prop = False

    def run(self):
+        # Prepare inputs from node.meta, which is filled during tracing when
+        # the input is a tensor. For non-tensor inputs (e.g. constants), the
+        # value has already been burned into the program, so we use an
+        # arbitrary fill value here (None).
        inp = tuple(
-            node.meta["val"]
+            node.meta["val"] if "val" in node.meta else None
            for node in self.module.graph.nodes
            if node.op == "placeholder"
        )
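To see why the fallback matters: `node.meta["val"]` only exists for placeholders that had a tensor value recorded, so unconditional indexing raises KeyError for non-tensor inputs. A minimal, self-contained sketch of the guarded gather (the traced function `f` is hypothetical; a freshly traced graph carries no "val" metadata at all, which also exercises the fallback):

import torch
from torch import fx

def f(x, scale):
    # `scale` can arrive as a plain Python number; during tracing its
    # value is burned into the program rather than recorded as a tensor.
    return x * scale

gm = fx.symbolic_trace(f)

# Guarded gather mirroring the change above: placeholders without a
# recorded tensor value fall back to None instead of raising KeyError.
inp = tuple(
    node.meta["val"] if "val" in node.meta else None
    for node in gm.graph.nodes
    if node.op == "placeholder"
)
print(inp)  # (None, None): no shape propagation has run yet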
