Skip to content

Commit

Permalink
Add back support for stage.remap_qualname() (#934)
Browse files Browse the repository at this point in the history
## Description
`stage.remap_qualname(key)` maps a stage's parameter name (`key`) back
to the original model's parameter name.

This now works:

```
    # Stage module's state dict
    sd = stage.submod.state_dict()
    remapped_keys = [stage.remap_qualname(k) for k in sd.keys()]

    # Original model's state dict
    old_keys = mod.state_dict().keys()

    # Confirm they match
    assert all(rk in old_keys for rk in remapped_keys)
```
  • Loading branch information
kwen2501 committed Jan 26, 2024
1 parent 3632106 commit e51c8b9
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 71 deletions.
64 changes: 6 additions & 58 deletions pippy/IR.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pippy.backward import _null_coalesce_accumulate, stage_backward
from pippy.debug import PIPPY_VERBOSITY
from pippy.microbatch import LossReducer, split_args_kwargs_into_chunks
from pippy.utils import QualnameMapMixin


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -498,54 +499,6 @@ def _direct_serialization_reduce(self):
)


class QualnameMapMixin:
    """
    Mixin that lets an object translate a qualified name produced after
    splitting (and, optionally, tracing) back to the name it had before.

    Two lookup tables are held on the instance:

    * ``new_to_old_qualname_mapping`` -- post-split name -> pre-split name.
    * ``tracer_qualname_map`` -- optional second table applied to the
      pre-split name; skipped entirely when it is ``None``.
    """

    def __init__(
        self,
        splitter_qualname_map: Dict[str, str] = None,
        tracer_qualname_map: Dict[str, str] = None,
    ):
        # A missing (or empty) splitter map is replaced by a fresh empty dict.
        if not splitter_qualname_map:
            splitter_qualname_map = {}
        self.new_to_old_qualname_mapping: Dict[str, str] = splitter_qualname_map
        self.tracer_qualname_map = tracer_qualname_map

    def remap_qualname(self, qualname: str):
        """
        Return the original name corresponding to ``qualname``.

        A leading ``"split_gm."`` is dropped first. The remaining name is
        looked up directly in the splitter map; if absent, progressively
        shorter dotted prefixes are tried (longest first) and the unmatched
        tail is re-attached to the mapped prefix.

        Raises:
            RuntimeError: neither the name nor any of its dotted prefixes is
                present in the splitter map.
            KeyError: a tracer map is set but has no entry for the pre-split
                name.
        """
        root_prefix = "split_gm."
        if qualname.startswith(root_prefix):
            qualname = qualname[len(root_prefix) :]

        split_map = self.new_to_old_qualname_mapping
        resolved = None
        if qualname in split_map:
            resolved = split_map[qualname]
        else:
            # The map holds no entries for nested leaves, so walk the cut
            # point from right to left: the first prefix that is a key is
            # the longest matching one.
            parts = qualname.split(".")
            for cut in range(len(parts) - 1, 0, -1):
                head = ".".join(parts[:cut])
                if head in split_map:
                    resolved = ".".join([split_map[head], *parts[cut:]])
                    break

        if resolved is None:
            raise RuntimeError(f"Could not find mapping for {qualname}")

        tracer_map = self.tracer_qualname_map
        return resolved if tracer_map is None else tracer_map[resolved]


class Pipe(QualnameMapMixin, torch.nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -615,6 +568,10 @@ def __init__(
)

# Create qualname mapping for each submodule
# Dict looks like this:
# {submod_name : Dict{new_qualname : old_qualname}}
# We save this information here for use during pipeline stage creation.
self.submod_qualname_mappings: Dict[str, Dict[str, str]] = {}
for m_qualname, mod in self.split_gm.named_children():
# "submod_x." prefix
mod_prefix = m_qualname + "."
Expand All @@ -624,16 +581,7 @@ def __init__(
# Remove prefix
new_key = k[len(mod_prefix) :]
mod_qualname_mapping.setdefault(new_key, v)
# Add a remap mixin to submodule instance
# TODO: this class change is commented out because it breaks
# recompilation if we want to recompile mod after. For example, we
# may recompile mod to modify the "device" kwarg of a `torch.ones`
# node (trace on cpu/meta, run on cuda).
# See: https://github.com/pytorch/vision/issues/5826
# mod.__class__ = type(
# "PipeStageModule", (QualnameMapMixin, mod.__class__), {}
# )
setattr(mod, "new_to_old_qualname_mapping", mod_qualname_mapping)
self.submod_qualname_mappings[m_qualname] = mod_qualname_mapping

def throw(self, *args, **kwargs):
raise RuntimeError(
Expand Down
11 changes: 9 additions & 2 deletions pippy/PipelineStage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pippy.debug import map_debug_info
from pippy.IR import Pipe
from pippy.microbatch import merge_chunks, split_args_kwargs_into_chunks
from pippy.utils import flatten_args, modify_graph_op_device
from pippy.utils import flatten_args, modify_graph_op_device, QualnameMapMixin


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -52,7 +52,7 @@ class StageKwargPlaceholder:
pass


class PipelineStage(torch.nn.Module):
class PipelineStage(torch.nn.Module, QualnameMapMixin):
def __init__(
self,
pipe: Pipe,
Expand Down Expand Up @@ -98,6 +98,13 @@ def __init__(
f"{self.submod}"
)

# Enable `remap_qualname` method
QualnameMapMixin.__init__(
self,
pipe.submod_qualname_mappings[self.name],
pipe.tracer_qualname_map,
)

# Find my forward node in graph
found_node = False
for node in self.split_gm.graph.nodes:
Expand Down
49 changes: 49 additions & 0 deletions pippy/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from typing import Dict

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -92,3 +93,51 @@ def modify_graph_op_device(

if modified:
gm.recompile()


class QualnameMapMixin:
    """
    Mixin that lets an object translate a qualified name produced after
    splitting (and, optionally, tracing) back to the name it had before.

    Two lookup tables are held on the instance:

    * ``new_to_old_qualname_mapping`` -- post-split name -> pre-split name.
    * ``tracer_qualname_map`` -- optional second table applied to the
      pre-split name; skipped entirely when it is ``None``.
    """

    def __init__(
        self,
        splitter_qualname_map: Dict[str, str] = None,
        tracer_qualname_map: Dict[str, str] = None,
    ):
        # A missing (or empty) splitter map is replaced by a fresh empty dict.
        if not splitter_qualname_map:
            splitter_qualname_map = {}
        self.new_to_old_qualname_mapping: Dict[str, str] = splitter_qualname_map
        self.tracer_qualname_map = tracer_qualname_map

    def remap_qualname(self, qualname: str):
        """
        Return the original name corresponding to ``qualname``.

        A leading ``"split_gm."`` is dropped first. The remaining name is
        looked up directly in the splitter map; if absent, progressively
        shorter dotted prefixes are tried (longest first) and the unmatched
        tail is re-attached to the mapped prefix.

        Raises:
            RuntimeError: neither the name nor any of its dotted prefixes is
                present in the splitter map.
            KeyError: a tracer map is set but has no entry for the pre-split
                name.
        """
        root_prefix = "split_gm."
        if qualname.startswith(root_prefix):
            qualname = qualname[len(root_prefix) :]

        split_map = self.new_to_old_qualname_mapping
        resolved = None
        if qualname in split_map:
            resolved = split_map[qualname]
        else:
            # The map holds no entries for nested leaves, so walk the cut
            # point from right to left: the first prefix that is a key is
            # the longest matching one.
            parts = qualname.split(".")
            for cut in range(len(parts) - 1, 0, -1):
                head = ".".join(parts[:cut])
                if head in split_map:
                    resolved = ".".join([split_map[head], *parts[cut:]])
                    break

        if resolved is None:
            raise RuntimeError(f"Could not find mapping for {qualname}")

        tracer_map = self.tracer_qualname_map
        return resolved if tracer_map is None else tracer_map[resolved]
10 changes: 10 additions & 0 deletions test/test_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ def run_worker(args):
f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}"
)

# Test qualname mapping
sd = stage.submod.state_dict()
print(f"Rank {args.rank} state dict keys: {sd.keys()}")
remapped_keys = [stage.remap_qualname(k) for k in sd.keys()]
print(f"Rank {args.rank} remapped keys: {remapped_keys}")
# Confirm remapped keys are consistent with original model
old_keys = mod.state_dict().keys()
assert all(rk in old_keys for rk in remapped_keys)
print(f"Qualname test passed")


def main(args=None):
parser = argparse.ArgumentParser()
Expand Down
11 changes: 0 additions & 11 deletions test/test_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,17 +791,6 @@ def test_remap_qualname(self):
old_name in old_names
), f"Remapped parameter {old_name} not found in {old_names}"

# Check qualname mapping for submodule
# Not supported at the moment
"""
for _, stage_mod in ec_pipe.split_gm.named_children():
for new_name, _ in stage_mod.named_parameters():
old_name = stage_mod.remap_qualname(new_name)
assert (
old_name in old_names
), f"Remapped parameter {old_name} not found in {old_names}"
"""


if __name__ == "__main__":
unittest.main()

0 comments on commit e51c8b9

Please sign in to comment.