Add example for unrolling iterative blocks (#920)
## Description

Demonstrate PiPPy's functionality in unrolling iterative blocks.

For details, please see
[README](https://github.com/pytorch/PiPPy/tree/unroll_example/examples/unrolling).

Many thanks to @mortzur for the inspiration!
kwen2501 committed Jan 2, 2024
1 parent 169892c commit bb90773
Showing 2 changed files with 191 additions and 0 deletions.
93 changes: 93 additions & 0 deletions examples/unrolling/README.md
@@ -0,0 +1,93 @@
## What does this example do?

This is a synthetic example that demonstrates PiPPy's ability to unroll iterative blocks in a model.

We create a model that runs an iteration block in a for loop:
```python
class IterationBlock(torch.nn.Module):
    def __init__(self, d_hid):
        super().__init__()
        self.lin = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.lin(x)
        x = torch.relu(x)
        return x


class IterativeNetwork(torch.nn.Module):
    def __init__(self, d_hid, num_iters):
        super().__init__()
        self.num_iters = num_iters
        self.iter_block = IterationBlock(d_hid)
        # 10 output classes
        self.output_proj = torch.nn.Linear(d_hid, 10)

    def forward(self, x):
        for i in range(self.num_iters):
            x = self.iter_block(x)
        return self.output_proj(x)
```

If we annotate the model as follows, we will create a pipeline stage per
iteration block:

```python
# Add a split point after each iter_block
annotate_split_points(
    model,
    {"iter_block": PipeSplitWrapper.SplitPoint.END},
)
```

That is, PiPPy will create a split point every time it encounters a call to `self.iter_block`.
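
For reference, PiPPy also exposes a `pipe_split()` marker in `pippy.IR`. Below is a minimal sketch (not part of this example, and assuming that API) of the same model, reusing `IterationBlock` from above, with the split point written directly into the loop body instead of annotated from the outside:

```python
from pippy.IR import pipe_split  # assumed API; not used in this example


class IterativeNetworkManualSplit(torch.nn.Module):
    def __init__(self, d_hid, num_iters):
        super().__init__()
        self.num_iters = num_iters
        self.iter_block = IterationBlock(d_hid)
        self.output_proj = torch.nn.Linear(d_hid, 10)

    def forward(self, x):
        for i in range(self.num_iters):
            x = self.iter_block(x)
            # When tracing unrolls the loop, each call below marks a stage boundary.
            pipe_split()
        return self.output_proj(x)
```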

Run it with 4 ranks:
```
$ torchrun --nproc-per-node 4 pippy_unroll.py
```

Print-out of the pipe:
```
************************************* pipe *************************************
GraphModule(
  (submod_0): PipeStageModule(
    (L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True)
  )
  (submod_1): PipeStageModule(
    (L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True)
  )
  (submod_2): PipeStageModule(
    (L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True)
  )
  (submod_3): PipeStageModule(
    (L__self___output_proj): Linear(in_features=512, out_features=10, bias=True)
  )
)
def forward(self, arg0):
    submod_0 = self.submod_0(arg0); arg0 = None
    submod_1 = self.submod_1(submod_0); submod_0 = None
    submod_2 = self.submod_2(submod_1); submod_1 = None
    submod_3 = self.submod_3(submod_2); submod_2 = None
    return [submod_3]
```
We can see 4 stages as expected (3 iterations plus 1 output projection).
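
The stage count can also be checked programmatically. A quick sketch (it relies only on `pipe.split_gm` being a regular `GraphModule`, as in the print-out above):

```python
# Count the pipeline stages created by tracing (expected: world_size, i.e. 4 here).
num_stages = len(list(pipe.split_gm.children()))
assert num_stages == 4, f"expected 4 stages, got {num_stages}"
```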

If we print one of the stages, we can see that it contains the code of one iteration:
```
*********************************** submod0 ************************************
PipeStageModule(
  (L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True)
)
def forward(self, l_x_):
    l__self___iter_block_mod_lin = self.L__self___iter_block_mod_lin(l_x_); l_x_ = None
    relu = torch.relu(l__self___iter_block_mod_lin); l__self___iter_block_mod_lin = None
    return relu
```
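
To inspect every stage rather than just `submod_0`, one can iterate over the children of `pipe.split_gm` (a sketch using the standard `nn.Module.named_children()` API):

```python
# Print each pipeline stage in turn.
for name, submod in pipe.split_gm.named_children():
    print(f" {name} ".center(80, "*"))
    print(submod)
```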

## How can this functionality help?
It can increase the throughput of your model.

Imagine your for loop needs to iterate on the data `n` times, and it takes time `t` to process one sample end to end (a throughput of `1/t`). If we unroll the for loop onto `n` devices, each pipeline stage takes roughly `t/n`, and we can keep `n` microbatches in flight, each containing one sample. In steady state the pipeline is then processing `n` samples at any time and finishes one every `t/n`, yielding a throughput of `n/t`.
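
As a back-of-the-envelope check of this claim (illustrative numbers only, not measurements from this example):

```python
n = 3            # loop iterations, unrolled onto n devices
t = 3e-3         # seconds for one sample to pass through all n iterations
per_stage = t / n                     # each unrolled stage takes ~t/n seconds
throughput_sequential = 1 / t         # ~333 samples/s on a single device
throughput_pipelined = 1 / per_stage  # = n/t = 1000 samples/s in steady state
print(throughput_sequential, throughput_pipelined)
```
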
98 changes: 98 additions & 0 deletions examples/unrolling/pippy_unroll.py
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Minimal effort to run this code:
# $ torchrun --nproc-per-node 4 pippy_unroll.py

import os
import torch
import torch.distributed as dist

from pippy.IR import annotate_split_points, Pipe, PipeSplitWrapper
from pippy.PipelineStage import PipelineStage


class IterationBlock(torch.nn.Module):
    def __init__(self, d_hid):
        super().__init__()
        self.lin = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.lin(x)
        x = torch.relu(x)
        return x


class IterativeNetwork(torch.nn.Module):
    def __init__(self, d_hid, num_iters):
        super().__init__()
        self.num_iters = num_iters
        self.iter_block = IterationBlock(d_hid)
        # 10 output classes
        self.output_proj = torch.nn.Linear(d_hid, 10)

    def forward(self, x):
        for i in range(self.num_iters):
            x = self.iter_block(x)
        return self.output_proj(x)


# We are using `torchrun` to run this example with multiple processes.
# `torchrun` defines two environment variables: `RANK` and `WORLD_SIZE`.
torch.manual_seed(0)
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# Figure out device to use
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
else:
    device = torch.device("cpu")

# Create the model
d_hid = 512
# (n-1) iterations + 1 output projection
num_iters = world_size - 1
model = IterativeNetwork(d_hid, num_iters).to(device)

# Add a split point after each iter_block
annotate_split_points(
    model,
    {"iter_block": PipeSplitWrapper.SplitPoint.END},
)

batch_size = 32
example_input = torch.randn(batch_size, d_hid, device=device)
chunks = world_size

pipe = Pipe.from_tracing(model, chunks, example_args=(example_input,))

if rank == 0:
    print(" pipe ".center(80, "*"))
    print(pipe)
    print(" submod0 ".center(80, "*"))
    print(pipe.split_gm.submod_0)

# Initialize distributed environment
dist.init_process_group(rank=rank, world_size=world_size)

# Pipeline stage is our main pipeline runtime. It takes in the pipe object,
# the rank of this process, and the device.
stage = PipelineStage(pipe, rank, device)

# Input data
x = torch.randn(batch_size, d_hid, device=device)

# Run the pipeline with input `x`. Divide the batch into n micro-batches
# and run them in parallel on the pipeline
if rank == 0:
    stage(x)
elif rank == world_size - 1:
    output = stage()
else:
    stage()

if rank == world_size - 1:
    # Run the original code and get the output for comparison
    reference_output = model(x)
    # Compare numerics of pipeline and original model
    torch.testing.assert_close(output, reference_output)
    print(" Pipeline parallel model ran successfully! ".center(80, "*"))
