[fix] [MEVO]: make mevo work with eval and optim_state checkpointing (#851)

* [fix]: fix eval for shared weight FSDP

* fixing optim state saving

* add changelog

* reformat with newer local isort

* update test

* avoid computing reference state unless we are testing training

* added optim_state test

* make mypy happy

* move tests; maybe we need to put CUDA memory related tests in the first of the test lists

Co-authored-by: Min Xu <min.xu.public@gmail.com>
min-xu-ai and flying-x committed Nov 18, 2021
1 parent fd831c4 commit 0db50ce
Showing 11 changed files with 120 additions and 38 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
and the file path for storing params on SSD. Note: This is an experimental feature. [#855]

### Changed
- MEVO: fixed eval and checkpointing code paths [#851]
- Cleanup: Moving forward we would be testing all of our code with Python 3.9.7, CUDA 11.2 and the following three versions of PyTorch [#847]:
- the most recent stable version
- the most recent LTS version
1 change: 0 additions & 1 deletion benchmarks/datasets/wikitext2_data.py
@@ -10,7 +10,6 @@

import torch
from torch.utils.data import DataLoader

import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.utils import download_from_url, extract_archive
4 changes: 2 additions & 2 deletions benchmarks/experimental/experimental_async_approaches.py
@@ -18,15 +18,15 @@
import torch.nn as nn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
import torchtext
from torchtext.data.utils import get_tokenizer

from fairscale.experimental.nn.ampnet_pipe import pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule
from fairscale.optim import GradScaler
from fairscale.utils.testing import dist_init, get_worker_map
import torchtext
from torchtext.data.utils import get_tokenizer

try:
from fairscale.optim import Adam # type: ignore
5 changes: 3 additions & 2 deletions fairscale/experimental/nn/mevo.py
@@ -378,7 +378,7 @@ def forward():
def __init__(self, linked_param: torch.Tensor):
super().__init__()
assert isinstance(linked_param, nn.Parameter)
self.trigger = nn.Parameter(torch.rand(1, dtype=linked_param.dtype))
self.trigger = nn.Parameter(torch.rand(1, dtype=linked_param.dtype, device=linked_param.device))
self.trigger._linked_param = linked_param

def forward(self) -> torch.Tensor: # type: ignore
@@ -437,7 +437,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: #
print("DEBUG cur, peak", cur_mem, mem)
assert isinstance(input, torch.Tensor)
assert isinstance(target, torch.Tensor)
assert input.requires_grad
if torch.is_grad_enabled():
assert input.requires_grad
input, target = _reshape_inputs(input, target)

tokens, d_model = input.shape
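A minimal, self-contained sketch of the two mevo.py fixes above, using only plain PyTorch (the MEVO class itself is not reproduced; the Trigger class here is a toy stand-in):

```python
import torch
import torch.nn as nn


class Trigger(nn.Module):
    """Toy stand-in for MEVO's trigger module: its parameter is created on the
    linked parameter's device and dtype, so a CUDA/fp16 linked_param no longer
    gets paired with a CPU/fp32 trigger."""

    def __init__(self, linked_param: nn.Parameter):
        super().__init__()
        self.trigger = nn.Parameter(
            torch.rand(1, dtype=linked_param.dtype, device=linked_param.device)
        )


def checked_forward(x: torch.Tensor) -> torch.Tensor:
    # Only require grad tracking when autograd is actually recording; under
    # torch.no_grad() (the eval path) the old unconditional assert would fail.
    if torch.is_grad_enabled():
        assert x.requires_grad, "training path expects a grad-tracking input"
    return x * 2.0


t = Trigger(nn.Parameter(torch.rand(4, dtype=torch.float32)))
x = torch.rand(3, requires_grad=True)
checked_forward(x)                  # training-style call
with torch.no_grad():
    checked_forward(x.detach())     # eval-style call no longer trips the assert
```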
38 changes: 28 additions & 10 deletions fairscale/nn/data_parallel/fsdp_optim_utils.py
@@ -4,12 +4,15 @@
# LICENSE file in the root directory of this source tree.
"""These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states."""
import copy
from typing import Any, Dict, Iterator, List, Tuple, cast
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple, cast

import torch

from fairscale.nn.misc import FlattenParamsWrapper

if TYPE_CHECKING:
from fairscale.nn.data_parallel import FullyShardedDataParallel

# These return keys are used by fairseq. To change, add @sshleifer as a reviewer.
UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}

@@ -84,10 +87,11 @@ def _extract_non_tensor_state(combined_state: Dict[int, Dict[str, List]], param_

def _unflatten_optim_state(
combined_state: Dict[int, Dict],
instance_list: List[torch.nn.Module],
instance_list: List["FullyShardedDataParallel"],
world_pad_info: List[List[List[int]]],
singleton_state: Dict[int, Dict],
) -> Tuple[Dict[int, Dict], Dict[int, int]]:
"""Convert optimizer state for flattened parameters into original, unflatten ones."""
# local ids are the keys in the current state (combined_state), (usually fewer)
# global ids will be the keys in the unflattened state
next_global_id = 0 # gets incremented
@@ -100,7 +104,13 @@ def _unflatten_optim_state(

# Local corresponds to flattened, global corresponds to unflattened.
# Casting needed only for mypy.
num_global_params = [cast(int, m.num_params_managed) for m in instance_list]
num_global_params: List[int] = []
for m in instance_list:
if m.flatten_parameters:
num_flatten = cast(int, m.num_params_managed)
num_global_params.append(num_flatten)
else:
num_global_params.append(len(m.non_shared_params()))
global_to_local_id = {}
for local_id, num_unflat in enumerate(num_global_params):
for _ in range(num_unflat):
@@ -129,18 +139,26 @@ def _unflatten_optim_state(
assert isinstance(v, list), f"got {k}: {v} for {local_id}"
v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])]
flat_buffer = torch.cat(v_unpad)
# Casting needed only for mypy.
param_views: Iterator = cast(FlattenParamsWrapper, instance_list[local_id]).get_param_views([flat_buffer])
for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
unflat_state[global_id][k] = param_view
unflat_state[global_id].update(singleton_state[local_id])
if instance_list[local_id].flatten_parameters:
# Unflatten. Casting needed only for mypy.
param_views: Iterator = cast(FlattenParamsWrapper, instance_list[local_id]).get_param_views(
[flat_buffer]
)
for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
unflat_state[global_id][k] = param_view
else:
# Copy non-flatten state directly.
assert len(local_to_global[local_id]) == 1, "Only support a single non-flatten parameter"
global_id = local_to_global[local_id][0]
unflat_state[global_id][k] = flat_buffer
unflat_state[global_id].update(singleton_state[local_id])

return unflat_state, global_to_local_id


def build_unflat_state_dict(
instance_list: List[torch.nn.Module],
instance_list: List["FullyShardedDataParallel"],
world_pad_info: List[List[List[int]]],
state: Dict[int, Dict[str, List[torch.Tensor]]],
singleton_state: Dict[int, Dict[str, List[torch.Tensor]]],
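A rough sketch of the flatten vs. non-flatten branching that the hunks above introduce, with a toy splitter in place of FlattenParamsWrapper.get_param_views (assumed API, not reproduced):

```python
from typing import List, Sequence, Tuple

import torch


def unflatten_state_value(
    flat_buffer: torch.Tensor, shapes: Sequence[Tuple[int, ...]], flattened: bool
) -> List[torch.Tensor]:
    if flattened:
        # Flattened wrapper: one flat optimizer buffer is split back into
        # views matching the original (unflattened) parameter shapes.
        sizes = [int(torch.Size(s).numel()) for s in shapes]
        return [chunk.view(s) for chunk, s in zip(flat_buffer.split(sizes), shapes)]
    # Non-flattened instance (e.g. a shared-weight MEVO layer): exactly one
    # non-shared parameter owns the state, so the buffer passes through as-is.
    assert len(shapes) == 1, "only a single non-flatten parameter is supported"
    return [flat_buffer]


views = unflatten_state_value(torch.arange(6.0), [(2, 3)], flattened=True)
print([v.shape for v in views])  # [torch.Size([2, 3])]
print(unflatten_state_value(torch.arange(6.0), [(6,)], flattened=False)[0].shape)
```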
36 changes: 31 additions & 5 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -353,6 +353,7 @@ def __init__(
params.append(param)

self._has_params = len(params) > 0
self._has_shared_params = False

# TODO(anj): Should we conditionally do this only if we have params?
# TODO(anj): Figure out if we can allocate the buffer during sharding.
@@ -492,6 +493,14 @@ def append_shared_param(self, p: Parameter) -> None:
len(list(filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params))) > 0
), "Must have at least 1 non-shared param."
self.params.append(p)
self._has_shared_params = True

def non_shared_params(self) -> List[nn.Parameter]:
"""Return the list of non-shared parameters."""
if self._has_shared_params:
return list(filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params))
else:
return self.params

def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
"""
@@ -1050,9 +1059,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge
non_shared_params = self.params
# filter out shared params for all but the owner FSDP module.
if len(full_tensors) < len(non_shared_params):
non_shared_params = list(
filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params)
)
non_shared_params = self.non_shared_params()
assert len(full_tensors) == len(
non_shared_params
), f"{len(full_tensors)} vs. {len(non_shared_params)}"
@@ -1809,6 +1816,18 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:

self.has_full_params = False

if self._has_shared_params:
# self.has_full_params flag can be out of sync if a shared param is
# sharded by another FSDP instance. An example is the eval case
# with reshard_after_forward=False while the sharing instance has
# reshard_after_forward=True. Then, on the second forward, the
# other instance can shard the shared param, but this instance
# can mistakenly think the full param is already gathered from the
# has_full_params flag.
#
# Therefore, we update the flag accordingly here.
self.has_full_params = not any(p._full_param_padded.storage().size() == 0 for p in self.params)

# Early exit if we already have full params and don't need full precision.
if self.has_full_params and not force_full_precision:
for p in self.params:
@@ -2148,7 +2167,14 @@ def _gather_optim_state(
for k, v in sd_state.items():
gathered_state[k] = {}
singleton_state[k] = {}
desired_buffer_size = self._fsdp_instances[k].flat_param._full_param_padded.size() # type: ignore
# For shared params, we are not flattening. We have only 1 non-shared
# param that has the optimizer state. So we handle it with the correct
# parameter list.
non_shared_params = cast(FullyShardedDataParallel, self._fsdp_instances[k]).non_shared_params()
assert (
len(non_shared_params) == 1
), f"Only flatten param or a single non-shared param is supported: len={len(non_shared_params)}"
desired_buffer_size = non_shared_params[0]._full_param_padded.size()
buffer = None # for sharded tensors
singleton_buffer = None # for singleton tensors
for buffer_name, t in v.items():
@@ -2214,7 +2240,7 @@ def gather_full_optim_state_dict(self, optim: torch.optim.Optimizer, **ignored:
return new_state_dict

@property
def _fsdp_instances(self) -> List[nn.Module]:
def _fsdp_instances(self) -> List["FullyShardedDataParallel"]:
"""Returns all fsdp modules in self.modules() including self."""
return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]

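A small sketch of the has_full_params resync idea above, assuming only that a parameter carries a _full_param_padded buffer whose storage is freed (size 0) when it is resharded; the helper and setup here are hypothetical, not the fairscale API:

```python
import torch
import torch.nn as nn


def recompute_has_full_params(params) -> bool:
    # A param whose full (padded) buffer has zero-sized storage was resharded
    # by another FSDP instance that shares it, so any cached has_full_params
    # flag is stale and must be recomputed from the buffers themselves.
    return not any(p._full_param_padded.storage().size() == 0 for p in params)


a, b = nn.Parameter(torch.zeros(2)), nn.Parameter(torch.zeros(2))
a._full_param_padded = torch.empty(4)
b._full_param_padded = torch.empty(4)
print(recompute_has_full_params([a, b]))  # True: both full buffers are allocated

b._full_param_padded.set_()               # simulate the sharing owner freeing it
print(recompute_has_full_params([a, b]))  # False: the flag must be cleared
```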
4 changes: 3 additions & 1 deletion fairscale/nn/misc/flatten_params_wrapper.py
@@ -238,7 +238,9 @@ def flat_param(self) -> nn.Parameter:
"""We used to support only a single flat_param. This allows us to
be backward compatible.
"""
assert len(self.flat_params) == 1, "Incorrect access to flat_param"
assert (
len(self.flat_params) == 1
), f"Incorrect access to flat_param: len(self.flat_params)={len(self.flat_params)}"
return self.flat_params[0]

def _init_flatten_params(
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -27,4 +27,4 @@ use_parentheses = true
skip_glob = ["build/*", "stubs/*"]
# Don't split "import" and "from".
force_sort_within_sections = true
known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchvision"]
known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchtext", "torchvision"]
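With torchtext registered as third-party and force_sort_within_sections enabled, isort arranges these imports roughly as below (illustrative only, mirroring the benchmark diffs above):

```python
# Third-party section: "import x" and "from x import ..." are interleaved
# alphabetically instead of being split into two blocks.
import torch
from torch.utils.data import DataLoader
import torchtext
from torchtext.data.utils import get_tokenizer

# First-party (fairscale) imports then follow in their own section.
from fairscale.optim import GradScaler
```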
2 changes: 0 additions & 2 deletions tests/ci_test_list_2.txt
@@ -11,8 +11,6 @@ tests/utils/test_containers.py
tests/utils/test_parallel.py
tests/utils/test_state_dict.py
tests/utils/test_version.py
tests/nn/checkpoint/test_checkpoint_activations.py
tests/nn/checkpoint/test_checkpoint_activations_norm.py
tests/nn/misc/test_grad_bucket.py
tests/nn/misc/test_param_bucket.py
tests/nn/wrap/test_wrap.py
2 changes: 2 additions & 0 deletions tests/ci_test_list_3.txt
@@ -1,3 +1,5 @@
tests/nn/checkpoint/test_checkpoint_activations.py
tests/nn/checkpoint/test_checkpoint_activations_norm.py
tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py
tests/nn/misc/test_grad_bucket.py
tests/nn/misc/test_param_bucket.py
63 changes: 49 additions & 14 deletions tests/nn/data_parallel/test_fsdp_shared_weights_mevo.py
@@ -50,7 +50,7 @@ def __init__(self, with_fsdp=False, wrap_middle="none"):
self.ln2 = nn.LayerNorm(D_MODEL).cuda().half()

if with_fsdp:
# Shared layers much be un-flatten.
# Shared layers must be un-flatten.
self.l0 = FSDP(self.l0, flatten_parameters=False, mixed_precision=False, compute_dtype=torch.float16)
self.l1 = FSDP(self.l1, flatten_parameters=False, mixed_precision=False, compute_dtype=torch.float16)
self.l1.append_shared_param(self.l0.module.weight)
@@ -89,38 +89,46 @@ def temp_files():

@skip_if_single_gpu
@pytest.mark.parametrize("wrap_middle", ["none", "flat", "nonflat"])
def test_shared_weight_mevo(temp_files, wrap_middle):
@pytest.mark.parametrize("test_fn", ["train", "eval", "optim_state"])
def test_shared_weight_mevo(temp_files, wrap_middle, test_fn):
"""Test FSDP with a model with shared weights."""
if test_fn == "optim_state":
if wrap_middle != "flat":
pytest.skip("only support optim_state when root and middle part is flat")

world_size = 2

# Get ref.
model = Model()
sd_before = deepcopy(model.state_dict())
in_data = (torch.rand(BS, SEQ) * (VOCAB - 1)).cuda().long()
_train(model, in_data, world_size)
sd_after = deepcopy(model.state_dict())
# Before and after state should not be equal.
assert not objects_are_equal(sd_before, sd_after)
if test_fn == "train":
_train(model, in_data, world_size)
sd_after = deepcopy(model.state_dict())
# Before and after state should not be equal.
assert not objects_are_equal(sd_before, sd_after)

# Save data
torch.save(sd_before, temp_files[2])
torch.save(sd_after, temp_files[3])
if test_fn == "train":
torch.save(sd_after, temp_files[3])
torch.save(in_data, temp_files[4])

# Run FSDP
mp.spawn(
_dist_worker,
(world_size, temp_files, wrap_middle),
(world_size, temp_files, wrap_middle, test_fn),
nprocs=world_size,
)


def _dist_worker(rank, world_size, files, wrap_middle):
def _dist_worker(rank, world_size, files, wrap_middle, test_fn):

# Get data from files.
file1, file2, sd_before, sd_after, in_data = files
sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
if test_fn == "train":
sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
@@ -130,19 +138,46 @@ def _dist_worker(rank, world_size, files, wrap_middle):
# To debug: first make with_fsdp=False (no inner wrapping) work, then enable inner wrapping
# and make that work.
Model(with_fsdp=True, wrap_middle=wrap_middle),
flatten_parameters=False,
flatten_parameters=test_fn == "optim_state",
mixed_precision=False,
compute_dtype=torch.float16,
)
fsdp_model.load_state_dict(sd_before)

_train(fsdp_model, in_data)

objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)
if test_fn == "train":
_train(fsdp_model, in_data)
objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)
elif test_fn == "eval":
_eval(fsdp_model, in_data)
elif test_fn == "optim_state":
optim = SGD(fsdp_model.parameters(), lr=0.1)
for _ in range(3):
out = fsdp_model(in_data)
out.backward()
optim.step()
sd = fsdp_model.gather_full_optim_state_dict(optim)
if rank == 0:
# There should be 8 momentum buffers in the state.
assert len(sd["state"].keys()) == 8
else:
assert sd is None, "only rank 0 should have the optim state"
else:
assert 0, f"invalid test_fn {test_fn}"

teardown()


def _eval(model, in_data):
# run in eval mode
model.eval()
for _ in range(5):
out = model(in_data)
# adding torch.no_grad()
for _ in range(5):
with torch.no_grad():
out = model(in_data)


def _train(model, in_data, steps_per_iter=1):
optim = SGD(model.parameters(), lr=0.1)
for _ in range(3):
