diff --git a/pippy/IR.py b/pippy/IR.py index 48f32fcac..23c0db526 100644 --- a/pippy/IR.py +++ b/pippy/IR.py @@ -1004,34 +1004,6 @@ def move_param_to_callee( split.delete_all_unused_submodules() - # Users want the first pipeline stage to accept kwargs if the original - # program does. This is controlled by the `_codegen` field of the graph, - # so we make a copy here. Note: we only want the input spec and not the - # output spec, because the output spec is for the last stage. Maybe a - # TODO? Not sure yet. - submod0 = list(split.children())[0] - model_sign = signature(traced.forward) - model_num_args = len(model_sign.parameters) - submod0_sign = signature(submod0.forward) - submod0_num_args = len(submod0_sign.parameters) - if model_num_args != submod0_num_args: - # We don't change the signature of the first stage if it takes - # different number of args than original model - logger.info( - f"Original model takes {model_num_args} args but the first pipeline stage takes {submod0_num_args}. " - "Please provide args to respective pipeline stages." - ) - else: - # Support kwargs for the first stage - submod0.graph._codegen = copy.deepcopy(traced.graph._codegen) - # `_replace` is actually not "private" or internal. based on this doc: - # To prevent conflicts with field names, the method and attribute names - # start with an underscore - submod0.graph._codegen.pytree_info = ( - submod0.graph._codegen.pytree_info._replace(out_spec=None) - ) - submod0.recompile() - split.graph.lint() split.recompile() @@ -1176,6 +1148,33 @@ def from_tracing( f"{node.meta['example_value'] if 'example_value' in node.meta else 'None'}", ) + # Users want the first pipeline stage to accept kwargs if the original + # program does. This is controlled by the `_codegen` field of the graph, + # so we make a copy here. Note: we only want the input spec and not the + # output spec, because the output spec is for the last stage. Maybe a + # TODO? Not sure yet. + split = pipe.split_gm + submod0 = list(split.children())[0] + submod0_sign = signature(submod0.forward) + model_sign = signature(traced.forward) + if len(model_sign.parameters) != len(submod0_sign.parameters): + # We don't change the signature of the first stage if it takes + # different number of args than original model + logger.info( + f"Original model takes {len(model_sign.parameters)} args but the first pipeline stage takes {len(submod0_sign.parameters)}. " + "Please provide args to respective pipeline stages." + ) + else: + # Support kwargs for the first stage + submod0.graph._codegen = copy.deepcopy(traced.graph._codegen) + # `_replace` is actually not "private" or internal. based on this doc: + # To prevent conflicts with field names, the method and attribute names + # start with an underscore + submod0.graph._codegen.pytree_info = ( + submod0.graph._codegen.pytree_info._replace(out_spec=None) + ) + submod0.recompile() + return pipe def __str__(self): @@ -1284,8 +1283,12 @@ def __init__( self.stop_prop = False def run(self): + # Prepare input from node.meta, which will be filled during tracing if + # input is a tensor. For non-tensor inputs, e.g. constants, its value + # would have been burned into the program, so we use an arbitrary value + # here (None). inp = tuple( - node.meta["val"] + node.meta["val"] if "val" in node.meta else None for node in self.module.graph.nodes if node.op == "placeholder" )