[fix][FSDP] fix weight init when using apply() (fixes #490 and #444) #543

Merged · 6 commits · Mar 20, 2021
133 changes: 95 additions & 38 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -9,7 +9,7 @@
import functools
from math import inf
import traceback
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple, Union

import torch
from torch.autograd import Variable
@@ -150,6 +150,11 @@ class FullyShardedDataParallel(nn.Module):
based on world_size, so the max shard size is roughly
``bucket_cap_mb / world_size``. Values <= 0 disable bucketing.
Default: 25.
compute_device (torch.device, Optional):
device for computation. If not given and module params are on a CUDA
device, the param's device will be used. If not given and module
params are on CPU, then the current CUDA device (as indicated by
``torch.cuda.current_device()``) will be used.
"""

def __init__(
@@ -165,6 +170,7 @@ def __init__(
buffer_dtype: Optional[torch.dtype] = None,
move_grads_to_cpu: Optional[bool] = None,
bucket_cap_mb: int = 25,
compute_device: Optional[torch.device] = None,
):
super().__init__()
self.process_group = process_group or dist.new_group()
@@ -179,14 +185,21 @@ def __init__(
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
self.compute_device = compute_device

if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
if self.cpu_offload and not self.mixed_precision:
raise ValueError("cpu_offload requires mixed_precision=True")

compute_device = torch.device("cuda") if self.cpu_offload else next(module.parameters()).device
validate_process_group(compute_device, self.process_group)
if self.compute_device is None:
# Try to infer CUDA device from module parameters.
self.compute_device = next(module.parameters()).device
if self.compute_device.type != "cuda":
# Fall back to current CUDA device.
self.compute_device = torch.device("cuda")

validate_process_group(self.compute_device, self.process_group)
enable_pytorch_sync_bn(module)

# Only handle params which are not already sharded. This enables
@@ -239,11 +252,68 @@ def __init__(
def module(self) -> nn.Module:
return self._fsdp_wrapped_module # note: may be a FlattenParamsWrapper instance

@torch.no_grad()
def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
"""Move all buffers to the specified device and dtype, recursively."""
cast_fn = functools.partial(cast_buffers_, device=device, dtype=dtype)
self.apply(cast_fn)
def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
"""
Applies ``fn`` recursively to every submodule (as returned by
``.children()``) as well as self. Typical use includes initializing the
parameters of a model.

Compared to ``torch.nn.Module.apply``, this version additionally gathers
the full parameters before applying ``fn``. It should not be called from
within another ``summon_full_params`` context.

Args:
fn (Callable[[nn.Module], None]): function to be applied to each submodule

Returns:
Module: self
"""
is_uninitialized = self._is_root is None
self.assert_state(TrainingState.IDLE)
with self.summon_full_params(recurse=False):
return_value = super().apply(fn)
# summon_full_params will call _lazy_init, which sets _is_root. However,
# apply() may be called directly on child instances to do weight
# init, so we should reset the _is_root flag in this case.
if is_uninitialized and self._is_root:
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
module._reset_lazy_init()
return return_value
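For context (not part of this diff), a minimal sketch of the weight-init pattern this method is intended to support; the init function and model below are assumptions, and an initialized process group plus a GPU are presumed available:

```python
import torch.nn as nn

from fairscale.nn import FullyShardedDataParallel as FSDP

def init_weights(m: nn.Module) -> None:
    # Illustrative init function: reset Linear layers in place.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

model = FSDP(nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda())

# apply() gathers the full (unsharded) parameters, runs init_weights on every
# submodule, then re-shards, so all ranks end up with consistent weights.
model.apply(init_weights)
```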

def _cast_buffers(
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, memo: Optional[Set] = None
) -> None:
"""Move all buffers to the given *device* and *dtype*.

If *device* or *dtype* are not given, then they will default to
``self.compute_device`` and ``self.buffer_dtype``, respectively. In the
case of nested FSDP instances, we will respect the child instance's
``compute_device`` and ``buffer_dtype`` configuration.

Args:
device (torch.device, Optional):
device to cast buffers to (defaults to compute_device)
dtype (torch.dtype, Optional):
dtype to cast buffers to (defaults to buffer_dtype)
memo (Set, Optional):
set of modules that have already been processed
"""
if memo is None:
memo = set()
for module in self.modules():
if module is not self and isinstance(module, FullyShardedDataParallel):
# Allow any child FSDP instances to handle their own buffers.
module._cast_buffers(device=device, dtype=dtype, memo=memo)
elif module not in memo:
memo.add(module)
for name, buf in module.named_buffers(recurse=False):
if buf is None:
continue
buf = buf.to(device=device or self.compute_device)
if torch.is_floating_point(buf):
buf = buf.to(dtype=dtype or self.buffer_dtype)
setattr(module, name, buf)
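As an aside (not in this diff), the per-child handling described in the docstring is what allows, for example, keeping batch-norm buffers in full precision inside a nested instance. A hedged sketch, assuming an initialized process group and a GPU:

```python
import torch
import torch.nn as nn

from fairscale.nn import FullyShardedDataParallel as FSDP

# The inner FSDP instance keeps its BatchNorm buffers in fp32, while the outer
# instance defaults buffer_dtype to fp16 under mixed precision; _cast_buffers
# defers to each child's own buffer_dtype and compute_device.
inner = FSDP(nn.BatchNorm1d(8).cuda(), mixed_precision=True, buffer_dtype=torch.float32)
model = FSDP(nn.Sequential(nn.Linear(8, 8).cuda(), inner), mixed_precision=True)
```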

@property
def params_with_grad(self) -> List[Parameter]:
@@ -386,7 +456,10 @@ def extra_repr(self) -> str:
f"flatten_parameters={self.flatten_parameters}, "
f"cpu_offload={self.cpu_offload}, "
f"compute_dtype={self.compute_dtype}, "
f"move_grads_to_cpu={self.move_grads_to_cpu}"
f"buffer_dtype={self.buffer_dtype}, "
f"move_grads_to_cpu={self.move_grads_to_cpu}, "
f"bucket_cap_mb={self.bucket_cap_mb}, "
f"compute_device={self.compute_device}"
)

def __getattr__(self, name: str) -> Any:
@@ -443,7 +516,7 @@ def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, torch.Tenso
self._lazy_init()
if self.mixed_precision:
# Buffers dtype stays consistent with parameters.
self._all_buffers_to(dtype=torch.float32)
self._cast_buffers(dtype=torch.float32)

if self._return_full_state_dict:
if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
Expand All @@ -463,8 +536,8 @@ def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, torch.Tenso
state_dict[k] = state_dict[k].cpu()

if self.mixed_precision:
# In case we are in mixed precision, restore buffers back to fp16.
self._all_buffers_to(dtype=self.buffer_dtype)
# In case we are in mixed precision, restore buffers back to buffer_dtype.
self._cast_buffers()
return state_dict
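To illustrate the behavior above (example not part of this diff; assumes an initialized process group and a GPU), buffers come back in fp32 from ``state_dict()`` even though they live in ``buffer_dtype`` during training:

```python
import torch
import torch.nn as nn

from fairscale.nn import FullyShardedDataParallel as FSDP

# BatchNorm1d registers running_mean / running_var buffers.
model = FSDP(nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8)).cuda(), mixed_precision=True)

# Buffers are temporarily cast to fp32 while state_dict() runs, so the
# checkpoint holds full-precision values; afterwards they are restored to
# buffer_dtype (fp16 by default under mixed precision).
sd = model.state_dict()
print({k: v.dtype for k, v in sd.items() if "running" in k})
```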

# TODO (Min): figuring out how to do typing for this overloaded function.
@@ -572,7 +645,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge
recurse (bool, Optional): recursively summon all params for nested
FSDP instances (default: True)
volatile (bool, Optional): if ``True``, modifications to params are
not guaranteed persist after the context manager exists;
not guaranteed to persist after the context manager exits;
enabling this can be slightly more efficient (default: False)
"""
if recurse:
@@ -625,6 +698,9 @@ def _reset_lazy_init(self) -> None:
self._queue_wait_for_post_backward_closure: Optional[Callable] = None
self._streams: Dict[str, torch.cuda.Stream] = {}
self._reducer: Optional[ReduceScatterBucketer] = None
for p in self.params:
if hasattr(p, "_fp32_shard"):
del p._fp32_shard # reset _init_param_attributes

def _lazy_init(self) -> None:
"""Initialization steps that should happen lazily, typically right
@@ -642,12 +718,11 @@ def _lazy_init(self) -> None:
self._set_is_root()
self._setup_streams()

if self.cpu_offload: # Buffers stay on GPU, and don't get sharded
self._all_buffers_to(device=torch.device("cuda"), dtype=self.buffer_dtype)
else:
self._all_buffers_to(dtype=self.buffer_dtype)

if self._is_root:
# Buffers stay on GPU, and don't get sharded. Since _cast_buffers
# applies recursively, we only call this from the root instance.
self._cast_buffers()

# Don't free the full params for the outer-most (root) instance,
# since those params will be needed immediately after for the
# backward pass.
@@ -684,10 +759,6 @@ def _init_param_attributes(self, p: Parameter) -> None:
if hasattr(p, "_fp32_shard"):
return

# Compute device defaults to CUDA when *cpu_offload* is enabled, or the
# param's current device otherwise (could be CPU).
compute_device = torch.device("cuda") if self.cpu_offload else p.device

# A single shard of the parameters in full precision.
p._fp32_shard = p.data

@@ -707,7 +778,7 @@ def _init_param_attributes(self, p: Parameter) -> None:
# the computation in the forward/backward pass. We resize the
# storage to size 0 at init (here) and re-materialize (by copying
# from _fp32_shard) as needed.
p._fp16_shard = torch.zeros_like(p._fp32_shard, device=compute_device, dtype=self.compute_dtype)
p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
free_storage_(p._fp16_shard)
else:
p._fp16_shard = None # use _fp32_shard
@@ -720,7 +791,7 @@ def _init_param_attributes(self, p: Parameter) -> None:
# relevant computation.
if p._is_sharded:
p._full_param_padded = torch.zeros(
p.data.numel() * self.world_size, device=compute_device, dtype=self.compute_dtype
p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
)
free_storage_(p._full_param_padded)

@@ -1290,20 +1361,6 @@ def fn(x: torch.Tensor) -> torch.Tensor:
return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)


def cast_buffers_(
module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> None:
"""Cast all of module.named_buffers to device and floating point buffers to dtype."""
# if buffers are already on the right device and/or dtype this is just python loop cost
assert dtype in {torch.float32, torch.float16} # assumes compute_dtype == float16
for key, buf in module.named_buffers(recurse=False):
if buf is not None:
buf = buf.to(device=device)
if torch.is_floating_point(buf):
buf = buf.to(dtype=dtype)
setattr(module, key, buf)


def free_storage_(data: torch.Tensor) -> None:
"""Free underlying storage of a Tensor."""
if data.storage().size() > 0:
1 change: 1 addition & 0 deletions tests/ci_test_list_2.txt
@@ -33,3 +33,4 @@ tests/nn/pipe/test_deferred_batch_norm.py
tests/nn/pipe/test_dependency.py
tests/nn/pipe/test_stream.py
tests/experimental/nn/test_multiprocess_pipe.py
tests/nn/data_parallel/test_fsdp_apply.py
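The contents of the new test file are not shown in this diff; purely as an illustration of the behavior being covered (every name and detail below is an assumption, not the actual test), an apply()-based init check could take roughly this shape:

```python
import torch
import torch.nn as nn

from fairscale.nn import FullyShardedDataParallel as FSDP

def _init_ones(m: nn.Module) -> None:
    # Hypothetical init function used only for this illustration.
    if isinstance(m, nn.Linear):
        nn.init.ones_(m.weight)
        nn.init.ones_(m.bias)

def check_apply_initializes_full_weights() -> None:
    # Assumes an initialized process group and a GPU.
    model = FSDP(nn.Linear(4, 4).cuda())
    model.apply(_init_ones)
    with model.summon_full_params():
        for p in model.parameters():
            assert torch.all(p == 1.0)
```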
86 changes: 43 additions & 43 deletions tests/nn/data_parallel/test_fsdp.py
@@ -77,6 +77,49 @@ def get_wrapped_model(group, cuda_first=False, config={}, **model_kwargs) -> Ful
model = FullyShardedDataParallel(TransformerWithSharedParams(group, **model_kwargs), group, **config).cuda()
return model

@classmethod
def _test_identical_outputs(
cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
):
if config.get("mixed_precision", False):
autocast = True
# Force the compute dtype to be torch.float32 so that we get
# identical results as PyTorch DDP when using autocast. Note that
# this will cause the all-gather to happen in FP32, which is slower
# than necessary in most cases.
config["compute_dtype"] = torch.float32
else:
autocast = False

# Establish reference behavior with PyTorch DDP (+ optionally autocast).
model = model_init_fn(group=group, wrapper_config=None).cuda()
if ref_ddp_fn is None:
model = nn.parallel.DistributedDataParallel(
model, device_ids=[rank], output_device=rank, process_group=group
)
else:
model = ref_ddp_fn(model, group)
ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
ref_state_dict = model.module.state_dict()
if config.get("cpu_offload", False):
for k in ref_state_dict.keys():
ref_state_dict[k] = ref_state_dict[k].cpu()

# Confirm we get the same behavior using FullyShardedDataParallel.
model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
if use_cuda:
model = model.cuda()
else:
assert next(model.parameters()).device == torch.device("cpu")
shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
shard_state_dict = model.state_dict()

try:
torch.testing.assert_allclose(ref_loss, shard_loss)
assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
except (AssertionError, RuntimeError) as e:
raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")


class TestMixedPrecision(DistributedTest):
def test_all_fp32(self):
@@ -313,49 +356,6 @@ def test_mixture_of_experts_grad_clip_breaks(self):
def _dummy_ddp_fn(self, model, group):
return DummyDDP(model)

@classmethod
def _test_identical_outputs(
cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
):
if config.get("mixed_precision", False):
autocast = True
# Force the compute dtype to be torch.float32 so that we get
# identical results as PyTorch DDP when using autocast. Note that
# this will cause the all-gather to happen in FP32, which is slower
# than necessary in most cases.
config["compute_dtype"] = torch.float32
else:
autocast = False

# Establish reference behavior with PyTorch DDP (+ optionally autocast).
model = model_init_fn(group=group, wrapper_config=None).cuda()
if ref_ddp_fn is None:
model = nn.parallel.DistributedDataParallel(
model, device_ids=[rank], output_device=rank, process_group=group
)
else:
model = ref_ddp_fn(model, group)
ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
ref_state_dict = model.module.state_dict()
if config.get("cpu_offload", False):
for k in ref_state_dict.keys():
ref_state_dict[k] = ref_state_dict[k].cpu()

# Confirm we get the same behavior using FullyShardedDataParallel.
model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
if use_cuda:
model = model.cuda()
else:
assert next(model.parameters()).device == torch.device("cpu")
shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
shard_state_dict = model.state_dict()

try:
torch.testing.assert_allclose(ref_loss, shard_loss)
assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
except (AssertionError, RuntimeError) as e:
raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")

@parameterized.expand([[1], [inf]], name_func=rename_test)
def test_clip_norm_transformer(self, norm_type):
config = {"mixed_precision": True}