From ef194cd2345055b142407bf75c58e1e2a2d0865e Mon Sep 17 00:00:00 2001
From: anj-s <32556631+anj-s@users.noreply.github.com>
Date: Wed, 17 Nov 2021 14:31:00 -0800
Subject: [PATCH] [feature] Add an OffloadConfig object to specify offloading
 params to disk. (#855)

* fixed lint issues

* remove unused print statements

* add changelog entry

* [skip ci] fix lint errors
---
 CHANGELOG.md                                 |   2 +
 benchmarks/datasets/wikitext2_data.py        |   1 +
 .../experimental_async_approaches.py         |   4 +-
 fairscale/nn/data_parallel/__init__.py       |   2 +-
 .../fully_sharded_data_parallel.py           |  33 ++++--
 pyproject.toml                               |   2 +-
 tests/nn/data_parallel/test_fsdp_offload.py  | 101 +++++++++---------
 7 files changed, 84 insertions(+), 61 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e979254e9..d15f1467a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - Sharded Grad Scaler works with cpu offload in mixed and full precision. [#831]
+- API for specifying SSD offload for params with FSDP. You can use an OffloadConfig to specify the type of offload
+  and the file path for storing params on SSD. Note: This is an experimental feature. [#855]

 ### Changed
 - Cleanup: Moving forward we would be testing all of our code with Python 3.9.7, CUDA 11.2 and the following three versions of PyTorch [#847]:
diff --git a/benchmarks/datasets/wikitext2_data.py b/benchmarks/datasets/wikitext2_data.py
index cd921abbc..c8b17d45a 100644
--- a/benchmarks/datasets/wikitext2_data.py
+++ b/benchmarks/datasets/wikitext2_data.py
@@ -10,6 +10,7 @@
 import torch
 from torch.utils.data import DataLoader
+
 import torchtext
 from torchtext.data.utils import get_tokenizer
 from torchtext.utils import download_from_url, extract_archive
diff --git a/benchmarks/experimental/experimental_async_approaches.py b/benchmarks/experimental/experimental_async_approaches.py
index a84ff50b9..2496bcc84 100644
--- a/benchmarks/experimental/experimental_async_approaches.py
+++ b/benchmarks/experimental/experimental_async_approaches.py
@@ -18,8 +18,6 @@
 import torch.nn as nn
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
-import torchtext
-from torchtext.data.utils import get_tokenizer

 from fairscale.experimental.nn.ampnet_pipe import pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
@@ -27,6 +25,8 @@ from fairscale.nn.pipe import LazyModule
 from fairscale.optim import GradScaler
 from fairscale.utils.testing import dist_init, get_worker_map
+import torchtext
+from torchtext.data.utils import get_tokenizer

 try:
     from fairscale.optim import Adam  # type: ignore
diff --git a/fairscale/nn/data_parallel/__init__.py b/fairscale/nn/data_parallel/__init__.py
index e7abd5b30..3d85a163a 100644
--- a/fairscale/nn/data_parallel/__init__.py
+++ b/fairscale/nn/data_parallel/__init__.py
@@ -5,7 +5,7 @@
 from typing import List

-from .fully_sharded_data_parallel import FullyShardedDataParallel, TrainingState, auto_wrap_bn
+from .fully_sharded_data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState, auto_wrap_bn
 from .sharded_ddp import ShardedDataParallel

 __all__: List[str] = []
diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index d60e7cd7a..9b7742fcc 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -5,12 +5,13 @@
 import contextlib
 import copy
+from dataclasses import dataclass
 from enum import Enum, auto
 import functools
 import logging
 from math import inf
 import os
-from random import randint
+import tempfile
 import time
 import traceback
 import typing
@@ -100,6 +101,19 @@ class TrainingState(Enum):
     SUMMON_FULL_PARAMS = auto()


+# Data classes containing FSDP parameter constructs
+
+# Offload config for specifying SSD options (initially at least)
+@dataclass
+class OffloadConfig:
+    """Class for specifying all arguments related to offloading parameters."""
+
+    # Offload type: currently only "ssd_offload" is supported.
+    offload_type: str = None
+    # Path to the directory for storing parameters offloaded to disk.
+    ssd_filepath_dir: str = None
+
+
 class FullyShardedDataParallel(nn.Module):
     """
     A wrapper for sharding Module parameters across data parallel workers. This
@@ -260,6 +274,10 @@ class FullyShardedDataParallel(nn.Module):
         cpu_offload (bool, Optional):
             if ``True``, offload params to CPU. Note: This arg will be deprecated in favor
             of *``move_params_to_cpu``* in an upcoming release.
+        offload_config (OffloadConfig):
+            The `OffloadConfig` object is used to specify the type of offload (e.g. SSD, CPU) and
+            other required knobs when offloading parameters from GPU. Currently the OffloadConfig
+            only supports specifying SSD offload as an option. Note: This is an experimental feature.
     """

     def __init__(
@@ -282,7 +300,7 @@ def __init__(
         force_input_to_fp32: bool = False,
         verbose: bool = False,
         cpu_offload: bool = False,
-        **kwargs: Dict[str, Any],
+        offload_config: OffloadConfig = None,
     ):
         init_start = time.time()
         super().__init__()
@@ -306,7 +324,7 @@ def __init__(
         self.force_input_to_fp32 = force_input_to_fp32
         self.verbose = verbose
         # Experimental feature for now. Use at your own risk.
-        self.ssd_offload = kwargs.get("ssd_offload", False)
+        self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False

         self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
@@ -339,12 +357,13 @@
         # TODO(anj): Should we conditionally do this only if we have params?
         # TODO(anj): Figure out if we can allocate the buffer during sharding.
         self.buffer_size = sum(p.numel() for p in params)
-        self.ssd_buffer_filename = ""
         if self.ssd_offload:
             assert import_ssd_offload, "We need to import ssd_offload.py to enable the `ssd_offload` feature."
-            # TODO(anj): Add support for temp file and directory as possible API params.
-            self.ssd_buffer_filename = f"{randint(1, int(10E6))}_rank{self.rank}"
-            self.ssd_buffer = ssd_offload.SsdBuffer(self.buffer_size, self.ssd_buffer_filename)
+            self.ssd_buffer_filepath_dir = (
+                offload_config.ssd_filepath_dir if offload_config.ssd_filepath_dir else tempfile.gettempdir()
+            )
+            self.ssd_buffer_filename = tempfile.mkstemp(dir=self.ssd_buffer_filepath_dir)
+            self.ssd_buffer = ssd_offload.SsdBuffer(self.buffer_size, self.ssd_buffer_filename[1])

             self.move_grads_to_cpu = True
             self.move_params_to_cpu = True
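For reference, a minimal sketch of how the new `offload_config` argument added above is intended to be used. This is illustrative only and not part of the patch: the toy module, the single-process gloo group and the directory are placeholders, and SSD offload stays experimental and needs a recent PyTorch (>= 1.11), as the tests below note.

import tempfile

import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig

# FSDP always needs an initialized default process group; a single-process
# gloo group is enough for a local smoke test.
if not dist.is_initialized():
    dist.init_process_group(backend="gloo", init_method="tcp://localhost:29500", rank=0, world_size=1)

# Placeholder model; any nn.Module can be wrapped.
module = nn.Linear(1024, 1024)

# Directory that will hold the SSD buffer file created via tempfile.mkstemp().
# If ssd_filepath_dir is left unset, the wrapper falls back to tempfile.gettempdir().
ssd_dir = tempfile.mkdtemp()

fsdp_module = FullyShardedDataParallel(
    module,
    offload_config=OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=ssd_dir),
)

Passing offload_config=None (the default) keeps the previous behavior; the old ssd_offload=True keyword is no longer accepted because **kwargs was removed from the constructor.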
diff --git a/pyproject.toml b/pyproject.toml
index 608a2f081..2dca6251e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,4 +27,4 @@ use_parentheses = true
 skip_glob = ["build/*", "stubs/*"]
 # Don't split "import" and "from".
 force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchtext", "torchvision"]
+known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchvision"]
diff --git a/tests/nn/data_parallel/test_fsdp_offload.py b/tests/nn/data_parallel/test_fsdp_offload.py
index e06f2b395..880074c60 100644
--- a/tests/nn/data_parallel/test_fsdp_offload.py
+++ b/tests/nn/data_parallel/test_fsdp_offload.py
@@ -4,10 +4,9 @@
 # LICENSE file in the root directory of this source tree.

 import functools
-import glob
 import itertools
-import os
 import sys
+import tempfile
 import time
 import unittest
@@ -18,11 +17,12 @@
 import torch.distributed

 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
-from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState
+from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
 from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, rmf, spawn_for_all_world_sizes
+from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes

 # Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
+print(f"torch version {torch_version()}")
 pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0), reason="requires torch version >= 1.11.0")
@@ -32,8 +32,6 @@ class DistributedTest(unittest.TestCase):
     def setUp(self):
-        if torch_version() < (1, 6, 0):
-            raise unittest.SkipTest("Need pytorch version >= 1.6 due to lack of reduce_scatter")
         if not torch.cuda.is_available():
             raise unittest.SkipTest("CUDA not available, skipping test")
         if sys.platform == "win32":
@@ -102,8 +100,12 @@ def _test_identical_outputs_eval(
             ref_state_dict[k] = ref_state_dict[k].cpu()

         # Confirm we get the same behavior using FullyShardedDataParallel.
+        if config.get("ssd_offload", False):
+            config["offload_config"] = OffloadConfig(offload_type="ssd_offload")
+
+        del config["ssd_offload"]
         model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
-        if not config.get("ssd_offload", False):
+        if not model.ssd_offload and not model.move_params_to_cpu:
             if use_cuda:
                 model = model.cuda()
             else:
@@ -149,16 +151,14 @@ def _test_memory_benchmark(self, rank, group, config):
         model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
         time_keeper.print_time("CPU_MODEL", 1.0)

-        config["ssd_offload"] = True
-        model = FullyShardedDataParallel(model, **config)
-        time_keeper.print_time("FSDP_MODEL", 1.0)
+        with tempfile.TemporaryDirectory() as current_tempdir:
+            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)

-        self._eval_for_several_steps(model, 1, autocast=False)
-        time_keeper.print_time("EVAL")
+            model = FullyShardedDataParallel(model, **config)
+            time_keeper.print_time("FSDP_MODEL", 1.0)

-        fileList = glob.glob(os.getcwd() + "/*_rank*")
-        for file in fileList:
-            rmf(file)
+            self._eval_for_several_steps(model, 1, autocast=False)
+            time_keeper.print_time("EVAL")


 class SimpleLinear(nn.Module):
@@ -221,29 +221,33 @@ def _test_named_params(self, rank, group, config):
         before_wrap_model = TransformerWithSharedParams(group)
         before_wrap_params = before_wrap_model.named_parameters()

-        config["ssd_offload"] = True
-        model = FullyShardedDataParallel(before_wrap_model, **config)
+        with tempfile.TemporaryDirectory() as current_tempdir:
+            if config["ssd_offload"]:
+                config["offload_config"] = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)
+            del config["ssd_offload"]

-        if not config["ssd_offload"]:
-            model = model.cuda()
+            model = FullyShardedDataParallel(before_wrap_model, **config)
+            print(f"model.ssd_offload {model.ssd_offload}")
+            if not model.ssd_offload and not model.move_params_to_cpu:
+                model = model.cuda()

-        self._eval_with_config(model, autocast=config["mixed_precision"])
+            self._eval_with_config(model, autocast=config["mixed_precision"])

-        # Get the named parameters after wrapping to compare.
-        after_wrap_params = model.named_parameters()
+            # Get the named parameters after wrapping to compare.
+            after_wrap_params = model.named_parameters()

-        if not config.get("flatten_parameters", False):
-            for before_nm, after_nm in zip(before_wrap_params, after_wrap_params):
-                assert before_nm[0] == after_nm[0]
-        else:
-            named_params_flat = [p for p in after_wrap_params][0][0]
-            assert "flat_param_0" in named_params_flat
+            if not config.get("flatten_parameters", False):
+                for before_nm, after_nm in zip(before_wrap_params, after_wrap_params):
+                    assert before_nm[0] == after_nm[0]
+            else:
+                named_params_flat = [p for p in after_wrap_params][0][0]
+                assert "flat_param_0" in named_params_flat

-        after_wrap_params = model.named_parameters()
+            after_wrap_params = model.named_parameters()

-        for before_nm, after_nm_original in zip(before_wrap_params, after_wrap_params):
-            assert before_nm[0] == after_nm_original[0]
-            torch.testing.assert_allclose(before_nm[1].shape, after_nm_original[1].shape)
+            for before_nm, after_nm_original in zip(before_wrap_params, after_wrap_params):
+                assert before_nm[0] == after_nm_original[0]
+                torch.testing.assert_allclose(before_nm[1].shape, after_nm_original[1].shape)


 class TestSsdLoading(DistributedTest):
@@ -252,7 +256,7 @@ def test_ssd_offloading_eval(self, config):
         test_fn = functools.partial(self._test_ssd_offload_eval, config=config)
         spawn_and_init(test_fn)

-    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
+    @parameterized.expand(CONFIG, name_func=rename_test)
     def test_transformer_parameterized(self, config):
         spawn_and_init(functools.partial(self._test_identical_outputs_eval, TransformerWithSharedParams, config))
@@ -264,26 +268,23 @@ def _test_ssd_offload_eval(self, rank, group, config):
         nested_wrapping = config["nested_wrapping"]
         del config["nested_wrapping"]

-        config["ssd_offload"] = True
-        if nested_wrapping:
-            model = FullyShardedDataParallel(NestedWrappedModule(group, wrap_everything=True, wrapper_config=config))
-        else:
-            model = FullyShardedDataParallel(model, **config)
-
-        if not config["ssd_offload"]:
-            model = model.cuda()
-        self._eval_with_config(model, autocast=config["mixed_precision"])
+        with tempfile.TemporaryDirectory() as current_tempdir:
+            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)
+            if nested_wrapping:
+                model = FullyShardedDataParallel(
+                    NestedWrappedModule(group, wrap_everything=True, wrapper_config=config)
+                )
+            else:
+                model = FullyShardedDataParallel(model, **config)

-        # With SSD offload only local_state_dict will work. We can support global
-        # state dict if we think it is necessary.
-        state_dict = model.local_state_dict()
-        model.load_local_state_dict(state_dict)
+            self._eval_with_config(model, autocast=config["mixed_precision"])

-        self._eval_with_config(model, config["mixed_precision"])
+            # With SSD offload only local_state_dict will work. We can support global
+            # state dict if we think it is necessary.
+            state_dict = model.local_state_dict()
+            model.load_local_state_dict(state_dict)

-        fileList = glob.glob(os.getcwd() + "/*_rank*")
-        for file in fileList:
-            rmf(file)
+            self._eval_with_config(model, config["mixed_precision"])


 class TransformerWithSharedParams(nn.Module):
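Based on the test changes above, a sketch of the checkpointing pattern they exercise under SSD offload. The helper name and the toy module are placeholders; it assumes a default process group is already initialized, as in the earlier sketch. As in the tests, only the local (sharded) state dict is expected to round-trip.

import tempfile

import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig


def wrap_with_ssd_offload(module: nn.Module, ssd_dir: str) -> FullyShardedDataParallel:
    # Hypothetical helper mirroring the test setup: offload the params of
    # `module` to files created under `ssd_dir`.
    return FullyShardedDataParallel(
        module,
        offload_config=OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=ssd_dir),
    )


# Keep the directory alive for the lifetime of the wrapped model, since the
# parameter data lives in files underneath it.
with tempfile.TemporaryDirectory() as ssd_dir:
    fsdp_model = wrap_with_ssd_offload(nn.Linear(16, 16), ssd_dir)

    # With SSD offload only the local (sharded) state dict is supported by this
    # patch; a full/global state_dict() is not.
    state = fsdp_model.local_state_dict()
    fsdp_model.load_local_state_dict(state)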