-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example for unrolling iterative blocks (#920)
## Description Demonstrate PiPPy's functionality in unrolling iterative blocks. For details, please see [README](https://github.com/pytorch/PiPPy/tree/unroll_example/examples/unrolling). Many thanks to @mortzur 's inspiration!
- Loading branch information
Showing
2 changed files
with
191 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
## What does this example do? | ||
|
||
This is a synthetic example used to demonstrate PiPPy's functionality in unrolling iterative blocks in a model. | ||
|
||
We create a model that runs an iteration block in a for loop: | ||
```python | ||
class IterationBlock(torch.nn.Module): | ||
def __init__(self, d_hid): | ||
super().__init__() | ||
self.lin = torch.nn.Linear(d_hid, d_hid) | ||
|
||
def forward(self, x): | ||
x = self.lin(x) | ||
x = torch.relu(x) | ||
return x | ||
|
||
|
||
class IterativeNetwork(torch.nn.Module): | ||
def __init__(self, d_hid, num_iters): | ||
super().__init__() | ||
self.num_iters = num_iters | ||
self.iter_block = IterationBlock(d_hid) | ||
# 10 output classes | ||
self.output_proj = torch.nn.Linear(d_hid, 10) | ||
|
||
def forward(self, x): | ||
for i in range(self.num_iters): | ||
x = self.iter_block(x) | ||
return self.output_proj(x) | ||
``` | ||
|
||
If we annotate the model as follows, we will create a pipeline stage per | ||
iteration block: | ||
|
||
```python | ||
# Add a split point after each iter_block | ||
annotate_split_points( | ||
model, | ||
{"iter_block": PipeSplitWrapper.SplitPoint.END}, | ||
) | ||
``` | ||
|
||
That is, PiPPy would create a split point every time it sees "self.iter_block". | ||
|
||
Run it with 4 ranks: | ||
``` | ||
$ torchrun --nproc-per-node 4 pippy_unroll.py | ||
``` | ||
|
||
Print-out of the pipe: | ||
``` | ||
************************************* pipe ************************************* | ||
GraphModule( | ||
(submod_0): PipeStageModule( | ||
(L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True) | ||
) | ||
(submod_1): PipeStageModule( | ||
(L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True) | ||
) | ||
(submod_2): PipeStageModule( | ||
(L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True) | ||
) | ||
(submod_3): PipeStageModule( | ||
(L__self___output_proj): Linear(in_features=512, out_features=10, bias=True) | ||
) | ||
) | ||
def forward(self, arg0): | ||
submod_0 = self.submod_0(arg0); arg0 = None | ||
submod_1 = self.submod_1(submod_0); submod_0 = None | ||
submod_2 = self.submod_2(submod_1); submod_1 = None | ||
submod_3 = self.submod_3(submod_2); submod_2 = None | ||
return [submod_3] | ||
``` | ||
We can see 4 stages as expected (3 iterations plus 1 output projection). | ||
|
||
If we print one of the stages, we can see that it contains the code of one iteration: | ||
``` | ||
*********************************** submod0 ************************************ | ||
PipeStageModule( | ||
(L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True) | ||
) | ||
def forward(self, l_x_): | ||
l__self___iter_block_mod_lin = self.L__self___iter_block_mod_lin(l_x_); l_x_ = None | ||
relu = torch.relu(l__self___iter_block_mod_lin); l__self___iter_block_mod_lin = None | ||
return relu | ||
``` | ||
|
||
## How can this functionality help? | ||
Increase throughput of your model. | ||
|
||
Imagine your for loop needs to iterate on the data for `n` times, and it takes time `t` to process 1 sample (yielding a throughput of `1/t`). If we were to unroll the for loop onto `n` devices, then we can push `n` microbatches into the pipeline, each microbatch containing 1 sample. Then at any timeslot, the pipeline is processing `n` samples, yielding a throughput of `n/t`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates | ||
# Minimal effort to run this code: | ||
# $ torchrun --nproc-per-node 4 pippy_unroll.py | ||
|
||
import os | ||
import torch | ||
import torch.distributed as dist | ||
|
||
from pippy.IR import annotate_split_points, Pipe, PipeSplitWrapper | ||
from pippy.PipelineStage import PipelineStage | ||
|
||
|
||
class IterationBlock(torch.nn.Module): | ||
def __init__(self, d_hid): | ||
super().__init__() | ||
self.lin = torch.nn.Linear(d_hid, d_hid) | ||
|
||
def forward(self, x): | ||
x = self.lin(x) | ||
x = torch.relu(x) | ||
return x | ||
|
||
|
||
class IterativeNetwork(torch.nn.Module): | ||
def __init__(self, d_hid, num_iters): | ||
super().__init__() | ||
self.num_iters = num_iters | ||
self.iter_block = IterationBlock(d_hid) | ||
# 10 output classes | ||
self.output_proj = torch.nn.Linear(d_hid, 10) | ||
|
||
def forward(self, x): | ||
for i in range(self.num_iters): | ||
x = self.iter_block(x) | ||
return self.output_proj(x) | ||
|
||
|
||
# We are using `torchrun` to run this example with multiple processes. | ||
# `torchrun` defines two environment variables: `RANK` and `WORLD_SIZE`. | ||
torch.manual_seed(0) | ||
rank = int(os.environ["RANK"]) | ||
world_size = int(os.environ["WORLD_SIZE"]) | ||
|
||
# Figure out device to use | ||
if torch.cuda.is_available(): | ||
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") | ||
else: | ||
device = torch.device("cpu") | ||
|
||
# Create the model | ||
d_hid = 512 | ||
# (n-1) iterations + 1 output projection | ||
num_iters = world_size - 1 | ||
model = IterativeNetwork(d_hid, num_iters).to(device) | ||
|
||
# Add a split point after each iter_block | ||
annotate_split_points( | ||
model, | ||
{"iter_block": PipeSplitWrapper.SplitPoint.END}, | ||
) | ||
|
||
batch_size = 32 | ||
example_input = torch.randn(batch_size, d_hid, device=device) | ||
chunks = world_size | ||
|
||
pipe = Pipe.from_tracing(model, chunks, example_args=(example_input,)) | ||
|
||
if rank == 0: | ||
print(" pipe ".center(80, "*")) | ||
print(pipe) | ||
print(" submod0 ".center(80, "*")) | ||
print(pipe.split_gm.submod_0) | ||
|
||
# Initialize distributed environment | ||
dist.init_process_group(rank=rank, world_size=world_size) | ||
|
||
# Pipeline stage is our main pipeline runtime. It takes in the pipe object, | ||
# the rank of this process, and the device. | ||
stage = PipelineStage(pipe, rank, device) | ||
|
||
# Input data | ||
x = torch.randn(batch_size, d_hid, device=device) | ||
|
||
# Run the pipeline with input `x`. Divide the batch into n micro-batches | ||
# and run them in parallel on the pipeline | ||
if rank == 0: | ||
stage(x) | ||
elif rank == world_size - 1: | ||
output = stage() | ||
else: | ||
stage() | ||
|
||
if rank == world_size - 1: | ||
# Run the original code and get the output for comparison | ||
reference_output = model(x) | ||
# Compare numerics of pipeline and original model | ||
torch.testing.assert_close(output, reference_output) | ||
print(" Pipeline parallel model ran successfully! ".center(80, "*")) |