Allow sharded grad scaler to cpu offload with FSDP (#831)
* first commit

* sharded scaler hitting nan assertions

* adding test for sharded grad scaler without cpu offload

* ddp grad scaler and fsdp sharded grad scaler test failing

* removing test_output

* fix no cpu offload test

* changing optimizer from OSS to SGD

* all tests passing, code cleanup pending

* code cleanup

* fix pyproject.toml

* removing .isort.cfg

* running isort linter

* resolving isort issues

* resolving black linter issue

* resolving mypy issues

* fix import statement

* fix mypy error

* modifying import statement

* adding pytorch version requirement

* fixing pytest skip test decorator

* apply version guard for ShardedGradScaler

* removing test_fsdp_grad_scaler

* increasing num_epochs for ShardedGradScaler so that updates are not skipped

* adding support for torch 1.8

* minor edit

* [skip ci] more torch 1.8 changes

* parametrizing the tests

* cleanup code with linters

* [skip ci] update doc string

* [skip ci] addressing some more comments
Anupam Bhatnagar committed Nov 15, 2021
1 parent 7d7edf6 commit ba5785f
Showing 10 changed files with 414 additions and 107 deletions.
4 changes: 2 additions & 2 deletions fairscale/experimental/nn/distributed_pipeline/pipeline.py
@@ -21,8 +21,8 @@


def check_pytorch_version() -> None:
- if torch_version() < (1, 9, 0):
-     raise Exception("DistributedPipeline requires PyTorch version 1.9 or higher")
+ if torch_version() < (1, 8, 0):
+     raise Exception("DistributedPipeline requires PyTorch version 1.8 or higher")


MOVING_DENIED = TypeError(
Expand Down
383 changes: 362 additions & 21 deletions fairscale/optim/grad_scaler.py

Large diffs are not rendered by default.
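
Since the rewritten fairscale/optim/grad_scaler.py is not rendered here, below is a minimal sketch of how the tests in this commit drive the new ShardedGradScaler with FSDP. The model, data, and loss are illustrative placeholders, not part of this diff; only the scale/step/update pattern mirrors _train_for_several_steps in tests/nn/data_parallel/test_fsdp.py further down.

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.optim.grad_scaler import ShardedGradScaler  # requires torch >= 1.8

def train_step(model: FSDP, batch, targets, optim, scaler: ShardedGradScaler, autocast: bool):
    # One training step with loss scaling, as exercised by the FSDP tests in this PR.
    optim.zero_grad()
    with torch.cuda.amp.autocast(enabled=autocast):
        loss = torch.nn.functional.mse_loss(model(batch), targets)
    scaler.scale(loss).backward()  # scale the loss before running backward
    scaler.step(optim)             # unscales gradients; skips optim.step() on inf/nan
    scaler.update()                # adjusts the scale factor for the next step
    return loss.detach()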

1 change: 0 additions & 1 deletion tests/ci_test_list_1.txt
@@ -3,7 +3,6 @@ tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
tests/nn/data_parallel/test_fsdp_freezing_weights.py
tests/nn/data_parallel/test_fsdp_regnet.py
tests/nn/data_parallel/test_fsdp_uneven.py
- tests/nn/data_parallel/test_fsdp_grad_scaler.py
tests/nn/data_parallel/test_fsdp_grad_acc.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
Expand Down
39 changes: 24 additions & 15 deletions tests/nn/data_parallel/test_fsdp.py
@@ -13,6 +13,7 @@
from unittest import mock

from parameterized import parameterized
import pytest
import torch
from torch import nn
import torch.distributed
@@ -29,6 +30,9 @@
spawn_for_all_world_sizes,
)

if torch_version() >= (1, 8, 0):
from fairscale.optim.grad_scaler import ShardedGradScaler

# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# All helper functions called by spawn must be either @classmethod, @staticmethod

@@ -49,14 +53,17 @@ def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None
model_device = next(model.parameters()).device
# use SGD with momentum instead of Adam, since Adam is scale invariant
# and this makes it bad for tests
- optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
+
+ optim = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=0.9)
+ scaler = ShardedGradScaler()
for _ in range(num_steps):
optim.zero_grad()
with torch.cuda.amp.autocast(enabled=autocast):
# Inputs always cuda regardless of move_grads_cpu, or model.device
input = model.module.get_input(torch.device("cuda"))
output = model(*input)
loss = model.module.get_loss(input, output).to(model_device)
+ loss = scaler.scale(loss)
assert loss.dtype == torch.float32
model.module.run_backward(loss)
if norm_type is not None:
@@ -65,10 +72,10 @@ def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None
model.clip_grad_norm_(clip_norm, norm_type)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type)
params = [p for p in model.parameters()]
print(f"params.device {params[0].device} param.grad.device {params[0].grad.device}")

- optim.step()
+ scaler.step(optim)
+ scaler.update()
if hasattr(model, "assert_idle"):
model.assert_idle()
if isinstance(model, FullyShardedDataParallel):
model.assert_state(TrainingState.IDLE)
return loss.detach()
@@ -308,21 +315,21 @@ def test_transformer_parameterized(self, config):
# Test every combination of these options:
spawn_and_init(functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config))

- def test_cpu_offload_and_cpu_grads(self):
-     # We don't test the False condition because that requires the optimizer to internally do
-     # the device transfer and PyTorch optimizers don't support this.
-     config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": True}
+ # testing moving params to cpu while using full and mixed precision
+ @parameterized.expand([(True,), (False,)], name_func=rename_test)
+ def test_cpu_offload_and_cpu_grads(self, mixed_precision):
+     config = {"mixed_precision": mixed_precision, "cpu_offload": True}
test_fn = functools.partial(
self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.01
)
spawn_and_init(test_fn)

- def test_cpu_offload_and_cpu_grads_no_mixed_precision(self):
-     # We don't test the False condition because that requires the optimizer to internally do
-     # the device transfer and PyTorch optimizers don't support this.
-     config = {"mixed_precision": False, "cpu_offload": True, "move_grads_to_cpu": True}
+ # testing full and mixed precision on the gpu
+ @parameterized.expand([(True,), (False,)], name_func=rename_test)
+ def test_no_cpu_offload_with_sharded_grad_scaler(self, mixed_precision):
+     config = {"mixed_precision": mixed_precision, "move_params_to_cpu": False}
test_fn = functools.partial(
-     self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.01
+     self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=True, lr=0.01
)
spawn_and_init(test_fn)

@@ -485,10 +492,10 @@ def _one_step(self, model, group):
optim.step()


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestHooks(DistributedTest):
# Feel free to modify these tests as the implementation changes.
# They aspire to make sure that backward hooks are registered and used

@parameterized.expand([[True], [False]])
def test_output_backward_hooks(self, cuda_first):
fn = functools.partial(self._test_output_backward_hooks, cuda_first=cuda_first)
@@ -541,6 +548,7 @@ def _test_register_functions_called(self, rank, group, cuda_first=False):
assert model._register_pre_backward_hooks.called


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestNoGrad(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_transformer_parameterized(self, config):
@@ -568,6 +576,7 @@ def _test_transformer(self, rank, group, config):
assert objects_are_equal(ref_output, no_grad_output, raise_exception=True)


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestModuleProperties(DistributedTest):
@parameterized.expand([[{"flatten_parameters": False}], [{"flatten_parameters": True}]], name_func=rename_test)
def test_named_parameters(self, config):
4 changes: 4 additions & 0 deletions tests/nn/data_parallel/test_fsdp_apply.py
@@ -7,8 +7,11 @@
import unittest

from parameterized import parameterized
import pytest
import torch.nn as nn

from fairscale.utils import torch_version

from .test_fsdp import (
CONFIG_OPTIONS,
DistributedTest,
@@ -19,6 +22,7 @@
)


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestApply(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_transformer_weight_init(self, config):
63 changes: 0 additions & 63 deletions tests/nn/data_parallel/test_fsdp_grad_scaler.py

This file was deleted.

8 changes: 5 additions & 3 deletions tests/nn/data_parallel/test_fsdp_regnet.py
@@ -35,7 +35,6 @@

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
- from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils import torch_version
from fairscale.utils.testing import (
dist_init,
@@ -47,6 +46,9 @@
torch_cuda_version,
)

+ if torch_version() >= (1, 8, 0):
+     from fairscale.optim.grad_scaler import ShardedGradScaler

# Const test params.
# Reduce iterations to 1 for debugging.
# Change world_size to 8 on beefy machines for better test coverage.
@@ -352,8 +354,8 @@ def dump(d):
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
@pytest.mark.parametrize("sync_bn", ["none", "pytorch"])
def test_regnet(temp_files, ddp_ref, precision, flatten, sync_bn):
- if torch_version() < (1, 6, 0):
-     pytest.skip("older pytorch doesn't support reduce_scatter")
+ if torch_version() < (1, 8, 0):
+     pytest.skip("pytorch version >= 1.8.0 required")

state_before, inputs, conv_bias, linear_bias, state_after = ddp_ref

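
The same version-guard idiom recurs across the test files touched in this commit: import ShardedGradScaler only on torch >= 1.8 and skip the affected tests otherwise. A minimal sketch of the pattern (the test name and body are illustrative, not taken from this diff):

import pytest

from fairscale.utils import torch_version

if torch_version() >= (1, 8, 0):
    # ShardedGradScaler is only available for torch >= 1.8, so guard the import.
    from fairscale.optim.grad_scaler import ShardedGradScaler


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
def test_sharded_grad_scaler_importable():
    # Illustrative body: the guard above makes the name safe to use here.
    assert ShardedGradScaler is not None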
9 changes: 8 additions & 1 deletion tests/nn/data_parallel/test_fsdp_state_dict.py
@@ -7,10 +7,12 @@
import unittest

from parameterized import parameterized
import pytest
import torch
from torch import nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils import torch_version
from fairscale.utils.testing import dist_init, objects_are_equal, skip_if_cuda, teardown, temp_files_ctx

from .test_fsdp import (
@@ -23,6 +25,7 @@
)


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestLocalStateDict(DistributedTest):
@parameterized.expand([[True, True], [False, False]], name_func=rename_test)
def test_load_local_state_dict(self, flatten_params, mixed_precision):
@@ -50,7 +53,9 @@ def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=23):
state_1_module_weight = model.module.state_dict()[weight_key]
torch.testing.assert_allclose(state_1_weight, state_1_module_weight)
torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight)
- self._train_for_several_steps(model, 1, model.mixed_precision)
+ # increasing number of epochs from 1 to 6 for ShardedGradScaler to work properly.
+ # test fails for num_epochs < 6 since the updates are skipped due to gradient being inf.
+ self._train_for_several_steps(model, 6, model.mixed_precision)

state_2 = model.local_state_dict()
state_after_training = {k: v.cpu().clone() for k, v in state_2.items()}
@@ -69,6 +74,7 @@ def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=23):
raise AssertionError(f"params {unchanged} not changed after training")


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestSaveLoadStateDict(DistributedTest):
@parameterized.expand([[False], [True]], name_func=rename_test)
def test_calling_state_dict_twice_mixed_precision(self, mixed_precision):
@@ -178,6 +184,7 @@ def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, l
), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestStateDictDeviceDtype(DistributedTest):
@parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
def test_state_dict_device(self, mixed_precision, cpu_offload):
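
For context on the num_epochs change above: ShardedGradScaler follows torch.cuda.amp.GradScaler semantics, starting from a large initial scale and backing it off whenever the unscaled gradients contain inf/nan, in which case that optimizer step is skipped. With a single step, no parameters change and the state-dict comparison fails; after a few steps the scale settles and real updates land. A hedged sketch (default values assumed from torch's GradScaler):

from fairscale.optim.grad_scaler import ShardedGradScaler

# A fresh scaler starts from a large scale (2**16 by default in torch's GradScaler,
# which ShardedGradScaler extends). While gradients overflow at that scale,
# scaler.step() skips the optimizer update and scaler.update() halves the scale,
# which is why the state-dict test above now trains for 6 steps instead of 1.
scaler = ShardedGradScaler()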
4 changes: 4 additions & 0 deletions tests/nn/data_parallel/test_fsdp_summon_full_params.py
@@ -8,8 +8,11 @@
import unittest

from parameterized import parameterized
import pytest
import torch

from fairscale.utils.version import torch_version

from .test_fsdp import CONFIG_OPTIONS, DistributedTest, rename_test, spawn_and_init


@@ -19,6 +22,7 @@ def get_cuda_mem():
return torch.cuda.memory_allocated()


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestMemory(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_memory(self, config):
6 changes: 5 additions & 1 deletion tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
@@ -21,10 +21,12 @@

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
- from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils import torch_version
from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx

+ if torch_version() >= (1, 8, 0):
+     from fairscale.optim.grad_scaler import ShardedGradScaler

"""
Check that ShardedDDP gets the same results as DDP in a variety of scenarii
"""
@@ -249,6 +251,8 @@ def test_ddp_parity(
manual_reduction,
multiple_fw,
):
+ if torch_version() < (1, 8, 0):
+     pytest.skip("pytorch version >= 1.8.0 required")
if manual_reduction and change_train_graph:
pytest.skip("Skipping changing model and grad accumulation combination, makes little sense")

