# Re-support kwargs at run time (#929)
## Description

Implements #928 

Users want the first pipeline stage to accept kwargs if the original
program does.
As @angelayi suggested, this is controlled by the `_codegen` field of the
graph, so we copy that field from the traced program to `submod0`.
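
For context, the "input spec" here is the pytree structure that fx's generated `forward` uses to unflatten `(args, kwargs)` before running the graph. A tiny standalone sketch of that machinery (illustrative only, not part of this patch):

```python
import torch.utils._pytree as pytree

# Flatten an (args, kwargs) pair the way fx pytree codegen does: the
# returned TreeSpec records the structure, including kwarg names, needed
# to rebuild the original call from a flat list of leaves.
flat, in_spec = pytree.tree_flatten(((), {"input_ids": 0, "attention_mask": 1}))
print(flat)     # the leaf values, in the spec's canonical order
print(in_spec)  # TreeSpec recording tuple/dict structure and kwarg names
```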


## Feature/Issue validation/testing

Added kwargs to test_fwd.py.
Also changed a few HF examples to pass kwargs directly.
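
Concretely, each example moves from positional-only to kwargs-based calls along these lines (a schematic before/after; `model`, `stage`, and `example_inputs` as in the files below):

```python
# Before: a single positional tensor threaded through tracing and runtime
pipe = Pipe.from_tracing(model, num_chunks=args.chunks, example_args=(input_ids,))
if args.rank == 0:
    stage(input_ids)

# After: the whole kwargs dict flows through tracing and runtime
pipe = Pipe.from_tracing(
    model,
    num_chunks=args.chunks,
    example_args=(),
    example_kwargs=example_inputs,
)
if args.rank == 0:
    stage(**example_inputs)
```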
kwen2501 committed Jan 19, 2024
1 parent e9e2d5f commit 5025063
Showing 10 changed files with 91 additions and 48 deletions.
6 changes: 3 additions & 3 deletions examples/hf/pippy_albert.py
@@ -46,7 +46,6 @@ def run(args):
     # Input configs
     example_inputs = generate_inputs_for_model(
         model_class, albert, model_name, args.batch_size, args.device)
-    input_ids = example_inputs["input_ids"]
 
     # Annotate split points
     add_split_points(albert, args.world_size)
@@ -55,7 +54,8 @@ def run(args):
     albert_pipe = Pipe.from_tracing(
         albert,
         num_chunks=args.chunks,
-        example_args=(input_ids, ),
+        example_args=(),
+        example_kwargs=example_inputs,
     )
     nstages = len(list(albert_pipe.split_gm.children()))
     assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}"
@@ -72,7 +72,7 @@ def run(args):
 
     # Run
     if args.rank == 0:
-        stage(input_ids)
+        stage(**example_inputs)
     elif args.rank == args.world_size - 1:
         out = stage()
     else:
6 changes: 3 additions & 3 deletions examples/hf/pippy_bart.py
@@ -43,7 +43,6 @@ def run(args):
     # Input configs
     example_inputs = generate_inputs_for_model(
         model_class, bart, model_name, args.batch_size, args.device)
-    input_ids = example_inputs["input_ids"]
 
     # Annotate split points
     add_split_points(bart, args.world_size)
@@ -52,7 +51,8 @@ def run(args):
     bart_pipe = Pipe.from_tracing(
         bart,
         num_chunks=args.chunks,
-        example_args=(input_ids, ),
+        example_args=(),
+        example_kwargs=example_inputs,
     )
     nstages = len(list(bart_pipe.split_gm.children()))
     assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}"
@@ -69,7 +69,7 @@ def run(args):
 
     # Run
     if args.rank == 0:
-        stage(input_ids)
+        stage(**example_inputs)
     elif args.rank == args.world_size - 1:
         out = stage()
     else:
6 changes: 3 additions & 3 deletions examples/hf/pippy_bert.py
@@ -43,7 +43,6 @@ def run(args):
     # Input configs
     example_inputs = generate_inputs_for_model(
         model_class, bert, model_name, args.batch_size, args.device)
-    input_ids = example_inputs["input_ids"]
 
     # Annotate split points
     add_split_points(bert, args.world_size)
@@ -52,7 +51,8 @@ def run(args):
     bert_pipe = Pipe.from_tracing(
         bert,
         num_chunks=args.chunks,
-        example_args=(input_ids, ),
+        example_args=(),
+        example_kwargs=example_inputs,
     )
     nstages = len(list(bert_pipe.split_gm.children()))
     assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}"
@@ -69,7 +69,7 @@ def run(args):
 
     # Run
     if args.rank == 0:
-        stage(input_ids)
+        stage(**example_inputs)
     elif args.rank == args.world_size - 1:
         out = stage()
     else:
6 changes: 3 additions & 3 deletions examples/hf/pippy_camemBert.py
@@ -43,7 +43,6 @@ def run(args):
     # Input configs
     example_inputs = generate_inputs_for_model(
         model_class, camembert, model_name, args.batch_size, args.device)
-    input_ids = example_inputs["input_ids"]
 
     # Annotate split points
     add_split_points(camembert, args.world_size)
@@ -52,7 +51,8 @@ def run(args):
     camembert_pipe = Pipe.from_tracing(
         camembert,
         num_chunks=args.chunks,
-        example_args=(input_ids, ),
+        example_args=(),
+        example_kwargs=example_inputs,
     )
     nstages = len(list(camembert_pipe.split_gm.children()))
     assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}"
@@ -69,7 +69,7 @@ def run(args):
 
     # Run
     if args.rank == 0:
-        stage(input_ids)
+        stage(**example_inputs)
     elif args.rank == args.world_size - 1:
         out = stage()
     else:
6 changes: 3 additions & 3 deletions examples/hf/pippy_gpt2.py
@@ -52,7 +52,6 @@ def run(args):
     # Input configs
     example_inputs = generate_inputs_for_model(
         model_class, gpt2, model_name, args.batch_size, args.device)
-    input_ids = example_inputs["input_ids"]
 
     # Annotate split points
     add_split_points(gpt2, args.world_size)
@@ -61,7 +60,8 @@ def run(args):
     gpt2_pipe = Pipe.from_tracing(
         gpt2,
         num_chunks=args.chunks,
-        example_args=(input_ids, ),
+        example_args=(),
+        example_kwargs=example_inputs,
     )
     assert len(list(gpt2_pipe.split_gm.children())) == args.world_size
     if args.rank == 0:
@@ -77,7 +77,7 @@ def run(args):
 
     # Run
     if args.rank == 0:
-        stage(input_ids)
+        stage(**example_inputs)
     elif args.rank == args.world_size - 1:
         out = stage()
     else:
6 changes: 3 additions & 3 deletions examples/hf/pippy_gptNeo.py
@@ -43,7 +43,6 @@ def run(args):
     # Input configs
     example_inputs = generate_inputs_for_model(
         model_class, gptneo, model_name, args.batch_size, args.device)
-    input_ids = example_inputs["input_ids"]
 
     # Annotate split points
     add_split_points(gptneo, args.world_size)
@@ -52,7 +51,8 @@ def run(args):
     gptneo_pipe = Pipe.from_tracing(
         gptneo,
         num_chunks=args.chunks,
-        example_args=(input_ids, ),
+        example_args=(),
+        example_kwargs=example_inputs,
     )
     nstages = len(list(gptneo_pipe.split_gm.children()))
     assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}"
@@ -69,7 +69,7 @@ def run(args):
 
     # Run
     if args.rank == 0:
-        stage(input_ids)
+        stage(**example_inputs)
     elif args.rank == args.world_size - 1:
         out = stage()
     else:
6 changes: 3 additions & 3 deletions examples/hf/pippy_opt.py
@@ -43,7 +43,6 @@ def run(args):
     # Input configs
     example_inputs = generate_inputs_for_model(
         model_class, opt, model_name, args.batch_size, args.device)
-    input_ids = example_inputs["input_ids"]
 
     # Annotate split points
     add_split_points(opt, args.world_size)
@@ -52,7 +51,8 @@ def run(args):
     opt_pipe = Pipe.from_tracing(
         opt,
         num_chunks=args.chunks,
-        example_args=(input_ids, ),
+        example_args=(),
+        example_kwargs=example_inputs,
     )
     nstages = len(list(opt_pipe.split_gm.children()))
     assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}"
@@ -69,7 +69,7 @@ def run(args):
 
     # Run
    if args.rank == 0:
-        stage(input_ids)
+        stage(**example_inputs)
     elif args.rank == args.world_size - 1:
         out = stage()
     else:
32 changes: 30 additions & 2 deletions pippy/IR.py
@@ -3,6 +3,7 @@
 import logging
 import operator
 from enum import Enum
+from inspect import Parameter, signature, Signature
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -655,8 +656,6 @@ def throw(self, *args, **kwargs):
     def forward(self, *args, **kwargs):
         executor_args = args
         if len(kwargs) > 0:
-            from inspect import Parameter, Signature
-
             parameters = []
             for node in self.split_gm.graph.nodes:
                 if node.op == "placeholder":
@@ -1005,6 +1004,34 @@ def move_param_to_callee(
 
         split.delete_all_unused_submodules()
 
+        # Users want the first pipeline stage to accept kwargs if the original
+        # program does. This is controlled by the `_codegen` field of the graph,
+        # so we make a copy here. Note: we only want the input spec and not the
+        # output spec, because the output spec is for the last stage. Maybe a
+        # TODO? Not sure yet.
+        submod0 = list(split.children())[0]
+        model_sign = signature(traced.forward)
+        model_num_args = len(model_sign.parameters)
+        submod0_sign = signature(submod0.forward)
+        submod0_num_args = len(submod0_sign.parameters)
+        if model_num_args != submod0_num_args:
+            # We don't change the signature of the first stage if it takes a
+            # different number of args than the original model
+            logger.info(
+                f"Original model takes {model_num_args} args but the first pipeline stage takes {submod0_num_args}. "
+                "Please provide args to respective pipeline stages."
+            )
+        else:
+            # Support kwargs for the first stage
+            submod0.graph._codegen = copy.deepcopy(traced.graph._codegen)
+            # `_replace` is not actually "private" or internal. Per the namedtuple
+            # docs: "To prevent conflicts with field names, the method and
+            # attribute names start with an underscore."
+            submod0.graph._codegen.pytree_info = (
+                submod0.graph._codegen.pytree_info._replace(out_spec=None)
+            )
+            submod0.recompile()
+
         split.graph.lint()
         split.recompile()
 
@@ -1071,6 +1098,7 @@ def _trace_with_export(
             example_kwargs,
             constraints,
         )
+        logger.debug(f"Traced model: {traced}")
         if split_policy is not None:
             traced = split_policy(traced)
     finally:
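
The arity guard added above reduces to an `inspect.signature` comparison; a minimal standalone equivalent (hypothetical callables, for illustration only):

```python
from inspect import signature

def same_arity(model_fwd, stage_fwd) -> bool:
    # Adopt the model's signature (and thus kwargs support) only when the
    # first stage takes the same number of parameters as the original forward.
    return len(signature(model_fwd).parameters) == len(signature(stage_fwd).parameters)

print(same_arity(lambda input_ids, mask: None, lambda a, b: None))  # True
print(same_arity(lambda input_ids, mask: None, lambda a: None))     # False
```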
56 changes: 35 additions & 21 deletions pippy/PipelineStage.py
@@ -7,6 +7,7 @@
 import torch.distributed as dist
 import torch.fx as fx
 from torch._subclasses.fake_tensor import FakeTensor
+from torch.fx.node import map_aggregate, map_arg
 from torch.nn.parallel import DistributedDataParallel
 
 from pippy.backward import stage_backward
@@ -47,6 +48,10 @@ class StageArgPlaceholder:
     pass
 
 
+class StageKwargPlaceholder:
+    pass
+
+
 class PipelineStage(torch.nn.Module):
     def __init__(
         self,
@@ -269,14 +274,15 @@ def create_recv_tensor(
 
         # `args` is a Tuple, hence we will have:
         # Tuple[RecvInfo]
-        args_recv_info = fx.node.map_arg(self.node.args, create_recv_tensor)
+        args_recv_info = map_arg(self.node.args, create_recv_tensor)
 
         # `kwargs` is a Dict, hence we will have:
         # Dict[keyword, RecvInfo]
-        kwargs_recv_info = fx.node.map_arg(self.node.kwargs, create_recv_tensor)
+        kwargs_recv_info = map_arg(self.node.kwargs, create_recv_tensor)
 
         logger.info(
-            f"[{self.group_rank}] " f"Activation recv info: {args_recv_info}"
+            f"[{self.group_rank}] "
+            f"Activation recv / args info: {args_recv_info}"
         )
         return args_recv_info, kwargs_recv_info
 
@@ -370,9 +376,9 @@ def map_recv_to_send(a):
                 grad_send_info.append(None)
                 return None
 
-        fx.node.map_aggregate(args_recv_info, map_recv_to_send)
+        map_aggregate(args_recv_info, map_recv_to_send)
 
-        fx.node.map_aggregate(kwargs_recv_info, map_recv_to_send)
+        map_aggregate(kwargs_recv_info, map_recv_to_send)
 
         logger.info(f"[{self.group_rank}] " f"Grad send info: {grad_send_info}")
         return grad_send_info
@@ -422,35 +428,43 @@ def _recv_and_fill_inputs(
 
         act_recv = self.recv_tensor_fn(recv_reqs)
 
+        chunk_args_list: List = []
         if self.args_split:
             chunk_args = self.args_split[chunk]
             chunk_args_list = list(chunk_args)
 
         def recv_args(info):
             if isinstance(info, RecvInfo):
+                # This is an activation to receive
                 return act_recv(info)
             else:
-                return chunk_args_list.pop(0)  # type: ignore[has-type]
+                # This is a pass-in argument
+                if len(chunk_args_list):
+                    return chunk_args_list.pop(0)  # type: ignore[has-type]
+                else:
+                    # kwargs were treated as args in the graph phase. That's
+                    # why there are extra placeholders here. We mark them and
+                    # filter them out later.
+                    return StageKwargPlaceholder()
 
-        composite_args = fx.node.map_aggregate(
+        composite_args = map_aggregate(
             self.args_recv_info[chunk],
             recv_args,
         )
+        # Filter out kwarg placeholders
+        composite_args = tuple(
+            x
+            for x in composite_args
+            if not isinstance(x, StageKwargPlaceholder)
+        )
 
+        # Middle stages won't have incoming activations in kwargs form. So if
+        # kwargs_split is not empty, it must be model inputs for stage 0. We
+        # hence pass it as is to the internal submodule, without performing
+        # `recv_args` on it.
         composite_kwargs: Dict = {}
         if self.kwargs_split:
-            chunk_kwargs = self.kwargs_split[chunk]
-
-            def recv_kwargs(info):
-                if isinstance(info, RecvInfo):
-                    return act_recv(info)
-                else:
-                    k = next(iter(chunk_kwargs))  # type: ignore[has-type]
-                    return chunk_kwargs.pop(k)  # type: ignore[has-type]
-
-            composite_kwargs = fx.node.map_aggregate(
-                self.kwargs_recv_info[chunk],
-                recv_kwargs,
-            )
+            composite_kwargs = self.kwargs_split[chunk]
 
         # Wait for all recvs to finish
         for work in recv_reqs:
@@ -496,7 +510,7 @@ def _recv_grads(
         recv_grad = self.recv_tensor_fn(grad_recv_reqs)
 
         # Receive gradients
-        grads = fx.node.map_aggregate(
+        grads = map_aggregate(
             self.grad_recv_info[bwd_chunk],
             recv_grad,
         )
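
The stage-0 runtime path above is a mark-and-filter pass over the receive infos; a self-contained toy version (simplified names and values, no distributed pieces):

```python
from torch.fx.node import map_aggregate

class StageKwargPlaceholder:
    pass

# Tracing produced three placeholder slots, but only two positional args
# arrive at run time; the third slot was really a kwarg.
args_recv_info = (None, None, None)  # stand-ins for per-arg recv metadata
chunk_args_list = [10, 20]

def recv_args(info):
    # The real recv_args fetches activations for RecvInfo entries; here every
    # slot is either a pass-in argument or a kwarg-turned-placeholder.
    return chunk_args_list.pop(0) if chunk_args_list else StageKwargPlaceholder()

composite_args = map_aggregate(args_recv_info, recv_args)
# Drop the placeholders, exactly as _recv_and_fill_inputs does above.
composite_args = tuple(
    x for x in composite_args if not isinstance(x, StageKwargPlaceholder)
)
print(composite_args)  # (10, 20)
```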