From 0db50ce556a6755ef7eb5cecc0dded9bef10ca85 Mon Sep 17 00:00:00 2001
From: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
Date: Thu, 18 Nov 2021 13:40:59 -0800
Subject: [PATCH] [fix] [MEVO]: make mevo work with eval and optim_state checkpointing (#851)

* [fix]: fix eval for shared weight FSDP
* fixing optim state saving
* add changelog
* reformat with newer local isort
* update test
* avoid computing reference state unless we are testing training
* added optim_state test
* make mypy happy
* move tests; maybe we need to run CUDA memory related tests first in the lists

Co-authored-by: Min Xu
---
 CHANGELOG.md                                |  1 +
 benchmarks/datasets/wikitext2_data.py       |  1 -
 .../experimental_async_approaches.py        |  4 +-
 fairscale/experimental/nn/mevo.py           |  5 +-
 .../nn/data_parallel/fsdp_optim_utils.py    | 38 ++++++++---
 .../fully_sharded_data_parallel.py          | 36 +++++++++--
 fairscale/nn/misc/flatten_params_wrapper.py |  4 +-
 pyproject.toml                              |  2 +-
 tests/ci_test_list_2.txt                    |  2 -
 tests/ci_test_list_3.txt                    |  2 +
 .../test_fsdp_shared_weights_mevo.py        | 63 ++++++++++++++-----
 11 files changed, 120 insertions(+), 38 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d15f1467a..b723f2875 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
     and the file path for storing params on SSD. Note: This is an experimental feature. [#855]
 
 ### Changed
+- MEVO: fixed eval and checkpointing code paths [#851]
 - Cleanup: Moving forward we would be testing all of our code with Python 3.9.7, CUDA 11.2 and the following three versions of PyTorch [#847]:
   - the most recent stable version
   - the most recent LTS version
diff --git a/benchmarks/datasets/wikitext2_data.py b/benchmarks/datasets/wikitext2_data.py
index c8b17d45a..cd921abbc 100644
--- a/benchmarks/datasets/wikitext2_data.py
+++ b/benchmarks/datasets/wikitext2_data.py
@@ -10,7 +10,6 @@
 
 import torch
 from torch.utils.data import DataLoader
-
 import torchtext
 from torchtext.data.utils import get_tokenizer
 from torchtext.utils import download_from_url, extract_archive
diff --git a/benchmarks/experimental/experimental_async_approaches.py b/benchmarks/experimental/experimental_async_approaches.py
index 2496bcc84..a84ff50b9 100644
--- a/benchmarks/experimental/experimental_async_approaches.py
+++ b/benchmarks/experimental/experimental_async_approaches.py
@@ -18,6 +18,8 @@
 import torch.nn as nn
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
+import torchtext
+from torchtext.data.utils import get_tokenizer
 
 from fairscale.experimental.nn.ampnet_pipe import pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
@@ -25,8 +27,6 @@
 from fairscale.nn.pipe import LazyModule
 from fairscale.optim import GradScaler
 from fairscale.utils.testing import dist_init, get_worker_map
-import torchtext
-from torchtext.data.utils import get_tokenizer
 
 try:
     from fairscale.optim import Adam  # type: ignore
diff --git a/fairscale/experimental/nn/mevo.py b/fairscale/experimental/nn/mevo.py
index a79fcc711..6412ff034 100644
--- a/fairscale/experimental/nn/mevo.py
+++ b/fairscale/experimental/nn/mevo.py
@@ -378,7 +378,7 @@ def forward():
     def __init__(self, linked_param: torch.Tensor):
         super().__init__()
         assert isinstance(linked_param, nn.Parameter)
-        self.trigger = nn.Parameter(torch.rand(1, dtype=linked_param.dtype))
+        self.trigger = nn.Parameter(torch.rand(1, dtype=linked_param.dtype, device=linked_param.device))
        self.trigger._linked_param = linked_param
 
     def forward(self) -> torch.Tensor:  # type: ignore
@@ -437,7 +437,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         # print("DEBUG cur, peak", cur_mem, mem)
         assert isinstance(input, torch.Tensor)
         assert isinstance(target, torch.Tensor)
-        assert input.requires_grad
+        if torch.is_grad_enabled():
+            assert input.requires_grad
         input, target = _reshape_inputs(input, target)
 
         tokens, d_model = input.shape
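As context for the mevo.py change above, here is a small illustrative sketch (not part of the patch) of the eval behavior it enables: under torch.no_grad(), the activation fed to the kernel no longer requires grad, so the previously unconditional `assert input.requires_grad` would fail even though inference is valid. The helper name below is made up for the example.

import torch


def forward_check(inp: torch.Tensor) -> None:
    # Mirrors the fixed check: only insist on a grad-tracking input when
    # autograd is actually recording.
    if torch.is_grad_enabled():
        assert inp.requires_grad
    # ... the tiled softmax / NLL computation would follow here ...


w = torch.randn(8, 8, requires_grad=True)
x = torch.randn(4, 8)

forward_check(x @ w)      # training path: the activation requires grad
with torch.no_grad():
    forward_check(x @ w)  # eval path: requires_grad is False and the assert is skipped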
diff --git a/fairscale/nn/data_parallel/fsdp_optim_utils.py b/fairscale/nn/data_parallel/fsdp_optim_utils.py
index 99a9a6290..d5d08647c 100644
--- a/fairscale/nn/data_parallel/fsdp_optim_utils.py
+++ b/fairscale/nn/data_parallel/fsdp_optim_utils.py
@@ -4,12 +4,15 @@
 # LICENSE file in the root directory of this source tree.
 """These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states."""
 import copy
-from typing import Any, Dict, Iterator, List, Tuple, cast
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple, cast
 
 import torch
 
 from fairscale.nn.misc import FlattenParamsWrapper
 
+if TYPE_CHECKING:
+    from fairscale.nn.data_parallel import FullyShardedDataParallel
+
 # These return keys are used by fairseq. To change, add @sshleifer as a reviewer.
 UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}
 
@@ -84,10 +87,11 @@ def _extract_non_tensor_state(combined_state: Dict[int, Dict[str, List]], param_
 
 def _unflatten_optim_state(
     combined_state: Dict[int, Dict],
-    instance_list: List[torch.nn.Module],
+    instance_list: List["FullyShardedDataParallel"],
     world_pad_info: List[List[List[int]]],
     singleton_state: Dict[int, Dict],
 ) -> Tuple[Dict[int, Dict], Dict[int, int]]:
+    """Convert optimizer state for flattened parameters into the original, unflattened ones."""
     # local ids are the keys in the current state (combined_state), (usually fewer)
     # global ids will be the keys in the unflattened state
     next_global_id = 0  # gets incremented
@@ -100,7 +104,13 @@ def _unflatten_optim_state(
 
     # Local corresponds to flattened, global corresponds to unflattened.
     # Casting needed only for mypy.
-    num_global_params = [cast(int, m.num_params_managed) for m in instance_list]
+    num_global_params: List[int] = []
+    for m in instance_list:
+        if m.flatten_parameters:
+            num_flatten = cast(int, m.num_params_managed)
+            num_global_params.append(num_flatten)
+        else:
+            num_global_params.append(len(m.non_shared_params()))
     global_to_local_id = {}
     for local_id, num_unflat in enumerate(num_global_params):
         for _ in range(num_unflat):
@@ -129,18 +139,26 @@ def _unflatten_optim_state(
             assert isinstance(v, list), f"got {k}: {v} for {local_id}"
             v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])]
             flat_buffer = torch.cat(v_unpad)
-            # Casting needed only for mypy.
-            param_views: Iterator = cast(FlattenParamsWrapper, instance_list[local_id]).get_param_views([flat_buffer])
-            for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
-                assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
-                unflat_state[global_id][k] = param_view
-            unflat_state[global_id].update(singleton_state[local_id])
+            if instance_list[local_id].flatten_parameters:
+                # Unflatten. Casting needed only for mypy.
+                param_views: Iterator = cast(FlattenParamsWrapper, instance_list[local_id]).get_param_views(
+                    [flat_buffer]
+                )
+                for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
+                    assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
+                    unflat_state[global_id][k] = param_view
+            else:
+                # Copy non-flattened state directly.
+                assert len(local_to_global[local_id]) == 1, "Only support a single non-flatten parameter"
+                global_id = local_to_global[local_id][0]
+                unflat_state[global_id][k] = flat_buffer
+            unflat_state[global_id].update(singleton_state[local_id])
 
     return unflat_state, global_to_local_id
 
 
 def build_unflat_state_dict(
-    instance_list: List[torch.nn.Module],
+    instance_list: List["FullyShardedDataParallel"],
     world_pad_info: List[List[List[int]]],
     state: Dict[int, Dict[str, List[torch.Tensor]]],
     singleton_state: Dict[int, Dict[str, List[torch.Tensor]]],
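To make the new flatten/non-flatten branch above concrete, a standalone sketch follows (not from the patch; the shapes and state name are invented). With flatten_parameters=True, a gathered flat optimizer-state buffer is split back into one view per original parameter, analogous to what get_param_views() returns; in the non-flattened shared-weight case there is exactly one non-shared parameter, so the buffer is used directly.

import torch

param_shapes = [(4, 3), (5,)]                        # original (unflattened) parameter shapes
flat_buffer = torch.arange(17, dtype=torch.float32)  # e.g. a gathered momentum buffer

# Flattened path: split the buffer into per-parameter views.
views = []
offset = 0
for shape in param_shapes:
    numel = int(torch.Size(shape).numel())
    views.append(flat_buffer[offset : offset + numel].view(shape))
    offset += numel
assert [tuple(v.shape) for v in views] == param_shapes

# Non-flattened path: a single non-shared parameter, so the whole buffer maps to it.
single_param_state = flat_buffer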
diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 9b7742fcc..a2c57a42c 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -353,6 +353,7 @@ def __init__(
             params.append(param)
 
         self._has_params = len(params) > 0
+        self._has_shared_params = False
 
         # TODO(anj): Should we conditionally do this only if we have params?
         # TODO(anj): Figure out if we can allocate the buffer during sharding.
@@ -492,6 +493,14 @@ def append_shared_param(self, p: Parameter) -> None:
             len(list(filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params))) > 0
         ), "Must have at least 1 non-shared param."
         self.params.append(p)
+        self._has_shared_params = True
+
+    def non_shared_params(self) -> List[nn.Parameter]:
+        """Return the list of non-shared parameters."""
+        if self._has_shared_params:
+            return list(filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params))
+        else:
+            return self.params
 
     def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
         """
@@ -1050,9 +1059,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge
                 non_shared_params = self.params
                 # filter out shared params for all but the owner FSDP module.
                 if len(full_tensors) < len(non_shared_params):
-                    non_shared_params = list(
-                        filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params)
-                    )
+                    non_shared_params = self.non_shared_params()
                 assert len(full_tensors) == len(
                     non_shared_params
                 ), f"{len(full_tensors)} vs. {len(non_shared_params)}"
@@ -1809,6 +1816,18 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
 
             self.has_full_params = False
 
+            if self._has_shared_params:
+                # The self.has_full_params flag can be out of sync if a shared param is
+                # sharded by another FSDP instance. An example is the eval case where
+                # this instance has reshard_after_forward=False but the sharing instance
+                # has reshard_after_forward=True. Then, on the second forward, the other
+                # instance can shard the shared param, but this instance would mistakenly
+                # think the full param is already gathered, based on the has_full_params
+                # flag.
+                #
+                # Therefore, we update the flag accordingly here.
+                self.has_full_params = not any(p._full_param_padded.storage().size() == 0 for p in self.params)
+
             # Early exit if we already have full params and don't need full precision.
             if self.has_full_params and not force_full_precision:
                 for p in self.params:
@@ -2148,7 +2167,14 @@ def _gather_optim_state(
         for k, v in sd_state.items():
             gathered_state[k] = {}
             singleton_state[k] = {}
-            desired_buffer_size = self._fsdp_instances[k].flat_param._full_param_padded.size()  # type: ignore
+            # For shared params, we are not flattening. We have only 1 non-shared
+            # param that has the optimizer state. So we handle it with the correct
+            # parameter list.
+            non_shared_params = cast(FullyShardedDataParallel, self._fsdp_instances[k]).non_shared_params()
+            assert (
+                len(non_shared_params) == 1
+            ), f"Only flatten param or a single non-shared param is supported: len={len(non_shared_params)}"
+            desired_buffer_size = non_shared_params[0]._full_param_padded.size()
             buffer = None  # for sharded tensors
             singleton_buffer = None  # for singleton tensors
             for buffer_name, t in v.items():
@@ -2214,7 +2240,7 @@ def gather_full_optim_state_dict(self, optim: torch.optim.Optimizer, **ignored:
         return new_state_dict
 
     @property
-    def _fsdp_instances(self) -> List[nn.Module]:
+    def _fsdp_instances(self) -> List["FullyShardedDataParallel"]:
         """Returns all fsdp modules in self.modules() including self."""
         return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
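A brief sketch (illustrative only, not part of the patch) of the parameter selection that non_shared_params() and the _gather_optim_state change rely on: a parameter that another FSDP instance shares into this one carries an _is_shared marker, and optimizer-state gathering keys off the single remaining non-shared parameter instead of flat_param. The _is_shared tag is set by hand here purely for the example.

import torch
from torch import nn

local_weight = nn.Parameter(torch.randn(4, 4))   # owned (and sharded) by this FSDP instance
shared_weight = nn.Parameter(torch.randn(4, 4))  # shared in from another FSDP instance
shared_weight._is_shared = True                  # hand-set for illustration; see the filter above

params = [local_weight, shared_weight]
non_shared = [p for p in params if not (hasattr(p, "_is_shared") and p._is_shared)]

# Exactly one non-shared param remains, matching the assert in _gather_optim_state.
assert len(non_shared) == 1 and non_shared[0] is local_weight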
diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py
index 38efaec09..4dbfdf86d 100644
--- a/fairscale/nn/misc/flatten_params_wrapper.py
+++ b/fairscale/nn/misc/flatten_params_wrapper.py
@@ -238,7 +238,9 @@ def flat_param(self) -> nn.Parameter:
         """We used to support only a single flat_param. This allows us to
         be backward compatible.
         """
-        assert len(self.flat_params) == 1, "Incorrect access to flat_param"
+        assert (
+            len(self.flat_params) == 1
+        ), f"Incorrect access to flat_param: len(self.flat_params)={len(self.flat_params)}"
         return self.flat_params[0]
 
     def _init_flatten_params(
diff --git a/pyproject.toml b/pyproject.toml
index 2dca6251e..608a2f081 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,4 +27,4 @@ use_parentheses = true
 skip_glob = ["build/*", "stubs/*"]
 # Don't split "import" and "from".
 force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchvision"]
+known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchtext", "torchvision"]
diff --git a/tests/ci_test_list_2.txt b/tests/ci_test_list_2.txt
index 72d5e1763..e0f8e0b29 100644
--- a/tests/ci_test_list_2.txt
+++ b/tests/ci_test_list_2.txt
@@ -11,8 +11,6 @@ tests/utils/test_containers.py
 tests/utils/test_parallel.py
 tests/utils/test_state_dict.py
 tests/utils/test_version.py
-tests/nn/checkpoint/test_checkpoint_activations.py
-tests/nn/checkpoint/test_checkpoint_activations_norm.py
 tests/nn/misc/test_grad_bucket.py
 tests/nn/misc/test_param_bucket.py
 tests/nn/wrap/test_wrap.py
diff --git a/tests/ci_test_list_3.txt b/tests/ci_test_list_3.txt
index 65d0f6a2b..1986a77c2 100644
--- a/tests/ci_test_list_3.txt
+++ b/tests/ci_test_list_3.txt
@@ -1,3 +1,5 @@
+tests/nn/checkpoint/test_checkpoint_activations.py
+tests/nn/checkpoint/test_checkpoint_activations_norm.py
 tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py
 tests/nn/misc/test_grad_bucket.py
 tests/nn/misc/test_param_bucket.py
diff --git a/tests/nn/data_parallel/test_fsdp_shared_weights_mevo.py b/tests/nn/data_parallel/test_fsdp_shared_weights_mevo.py
index 1b2a91168..fdffb34da 100644
--- a/tests/nn/data_parallel/test_fsdp_shared_weights_mevo.py
+++ b/tests/nn/data_parallel/test_fsdp_shared_weights_mevo.py
@@ -50,7 +50,7 @@ def __init__(self, with_fsdp=False, wrap_middle="none"):
         self.ln2 = nn.LayerNorm(D_MODEL).cuda().half()
 
         if with_fsdp:
-            # Shared layers much be un-flatten.
+            # Shared layers must be un-flattened.
             self.l0 = FSDP(self.l0, flatten_parameters=False, mixed_precision=False, compute_dtype=torch.float16)
             self.l1 = FSDP(self.l1, flatten_parameters=False, mixed_precision=False, compute_dtype=torch.float16)
             self.l1.append_shared_param(self.l0.module.weight)
@@ -89,38 +89,46 @@ def temp_files():
 
 @skip_if_single_gpu
 @pytest.mark.parametrize("wrap_middle", ["none", "flat", "nonflat"])
-def test_shared_weight_mevo(temp_files, wrap_middle):
+@pytest.mark.parametrize("test_fn", ["train", "eval", "optim_state"])
+def test_shared_weight_mevo(temp_files, wrap_middle, test_fn):
     """Test FSDP with a model with shared weights."""
+    if test_fn == "optim_state":
+        if wrap_middle != "flat":
+            pytest.skip("only support optim_state when root and middle part is flat")
+
     world_size = 2
 
     # Get ref.
     model = Model()
     sd_before = deepcopy(model.state_dict())
     in_data = (torch.rand(BS, SEQ) * (VOCAB - 1)).cuda().long()
-    _train(model, in_data, world_size)
-    sd_after = deepcopy(model.state_dict())
-    # Before and after state should not be equal.
-    assert not objects_are_equal(sd_before, sd_after)
+    if test_fn == "train":
+        _train(model, in_data, world_size)
+        sd_after = deepcopy(model.state_dict())
+        # Before and after state should not be equal.
+        assert not objects_are_equal(sd_before, sd_after)
 
     # Save data
     torch.save(sd_before, temp_files[2])
-    torch.save(sd_after, temp_files[3])
+    if test_fn == "train":
+        torch.save(sd_after, temp_files[3])
     torch.save(in_data, temp_files[4])
 
     # Run FSDP
     mp.spawn(
         _dist_worker,
-        (world_size, temp_files, wrap_middle),
+        (world_size, temp_files, wrap_middle, test_fn),
         nprocs=world_size,
     )
 
 
-def _dist_worker(rank, world_size, files, wrap_middle):
+def _dist_worker(rank, world_size, files, wrap_middle, test_fn):
     # Get data from files.
     file1, file2, sd_before, sd_after, in_data = files
     sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
-    sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
+    if test_fn == "train":
+        sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
     in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))
 
     result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
@@ -130,19 +138,46 @@ def _dist_worker(rank, world_size, files, wrap_middle):
         # To debug: first make with_fsdp=False (no inner wrapping) work, then enable inner wrapping
         # and make that work.
         Model(with_fsdp=True, wrap_middle=wrap_middle),
-        flatten_parameters=False,
+        flatten_parameters=test_fn == "optim_state",
         mixed_precision=False,
         compute_dtype=torch.float16,
     )
     fsdp_model.load_state_dict(sd_before)
 
-    _train(fsdp_model, in_data)
-
-    objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)
+    if test_fn == "train":
+        _train(fsdp_model, in_data)
+        objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)
+    elif test_fn == "eval":
+        _eval(fsdp_model, in_data)
+    elif test_fn == "optim_state":
+        optim = SGD(fsdp_model.parameters(), lr=0.1)
+        for _ in range(3):
+            out = fsdp_model(in_data)
+            out.backward()
+            optim.step()
+        sd = fsdp_model.gather_full_optim_state_dict(optim)
+        if rank == 0:
+            # There should be 8 momentum buffers in the state.
+            assert len(sd["state"].keys()) == 8
+        else:
+            assert sd is None, "only rank 0 should have the optim state"
+    else:
+        assert 0, f"invalid test_fn {test_fn}"
 
     teardown()
 
 
+def _eval(model, in_data):
+    # Run in eval mode.
+    model.eval()
+    for _ in range(5):
+        out = model(in_data)
+    # Also run under torch.no_grad().
+    for _ in range(5):
+        with torch.no_grad():
+            out = model(in_data)
+
+
 def _train(model, in_data, steps_per_iter=1):
     optim = SGD(model.parameters(), lr=0.1)
     for _ in range(3):