feat(python): support find_unused_parameters on BaguaDDP (#409)
shjwudp authored Dec 2, 2021
1 parent 600860f commit f5b4343
Showing 3 changed files with 182 additions and 18 deletions.
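
For orientation before the diff: a minimal usage sketch of the new flag (not part of the commit). It mirrors the FindUnusedParametersModule added in the test file below and assumes the script is started by a distributed launcher that sets LOCAL_RANK plus the rendezvous environment variables; PartlyUsedModel is an illustrative name.

import os
import torch
import torch.distributed as c10d
from torch import nn

import bagua.torch_api as bagua
from bagua.torch_api.data_parallel.distributed import (
    DistributedDataParallel_V1_9_0 as DistributedDataParallel,
)


class PartlyUsedModel(nn.Module):
    # Same layout as the test module below: fc3 exists but is never called in
    # forward, so it never joins the autograd graph.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 10, bias=False)
        self.fc2 = nn.Linear(10, 4, bias=False)
        self.fc3 = nn.Linear(4, 4, bias=False)  # intentionally unused

    def forward(self, x):
        return torch.relu(self.fc2(torch.relu(self.fc1(x))))


local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)
bagua.init_process_group()

model = DistributedDataParallel(
    PartlyUsedModel().to(device),
    device_ids=[local_rank],
    process_group=c10d.distributed_c10d._get_default_group(),
    find_unused_parameters=True,  # previously rejected by an assert; see distributed.py below
)

model(torch.rand(4, 2, device=device)).sum().backward()
# As the new test asserts, only parameters that reached the autograd graph
# stay bucketed: ['fc1.weight', 'fc2.weight']; fc3.weight is left out.
print([name for name, _ in model.inner.bagua_build_params()])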
57 changes: 43 additions & 14 deletions bagua/torch_api/data_parallel/bagua_distributed.py
@@ -5,7 +5,7 @@
import pickle
import collections
import logging
from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, Dict
from torch.nn.modules import Module

import bagua
@@ -34,6 +34,7 @@ def __init__(
process_group: BaguaProcessGroup,
bagua_module_name: Optional[str] = None,
gradient_as_bucket_view: bool = True,
find_unused_parameters: bool = False,
) -> None:
self.module = module
if bagua_module_name is None:
@@ -47,6 +48,7 @@ def __init__(
self.bagua_algorithm = algorithm.reify(process_group)
self.process_group = process_group
self.gradient_as_bucket_view = gradient_as_bucket_view
self.find_unused_parameters = find_unused_parameters
self.parameters_to_ignore = (
[]
) #: the parameter names to ignore during communication
@@ -90,6 +92,7 @@ class BaguaDistributedDataParallelStates:
self._speed_metrics_switch_on = env.get_autotune_level() >= 1
self._speed_metrics = StatisticalAverage()
self.require_backward_grad_sync = True
self.autograd_graph_params: Dict[str, torch.nn.Parameter] = {}

ddp = self

@@ -133,16 +136,22 @@ def record_speed_metrics_event(self, _):
torch.cuda.current_stream().record_event(start_event)
ddp._last_event_pair = (start_event, ddp._speed_metrics_end_event)

bagua_states._bagua_framework_hooks.extend([
self.module.register_forward_pre_hook(num_iteration_step_hook),
self.module.register_forward_pre_hook(algorithm_reset_hook),
self.module.register_forward_pre_hook(algorithm_forward_pre_hook),
self.module.register_forward_pre_hook(record_speed_metrics_event),
self.module.register_forward_pre_hook(autotune_hook),
self.module.register_forward_pre_hook(
clear_post_backward_callback_queued_hook
),
])
def clear_autograd_graph_params(self, _):
ddp.autograd_graph_params.clear()

bagua_states._bagua_framework_hooks.extend(
[
self.module.register_forward_pre_hook(clear_autograd_graph_params),
self.module.register_forward_pre_hook(num_iteration_step_hook),
self.module.register_forward_pre_hook(algorithm_reset_hook),
self.module.register_forward_pre_hook(algorithm_forward_pre_hook),
self.module.register_forward_pre_hook(record_speed_metrics_event),
self.module.register_forward_pre_hook(autotune_hook),
self.module.register_forward_pre_hook(
clear_post_backward_callback_queued_hook
),
]
)

# autotune service
self._bagua_autotune_client = get_hyperparameters_service_client()
@@ -170,6 +179,9 @@ def bagua_build_params(self) -> List[Tuple[str, torch.nn.Parameter]]:
]
]

if self.find_unused_parameters and len(self.autograd_graph_params) != 0:
modules_and_parameters = filter(lambda it: it[1][0] in self.autograd_graph_params, modules_and_parameters)

# Deduplicate any parameters that might be shared across child modules.
memo = set()
# "p not in memo" is the deduplication check.
@@ -381,15 +393,19 @@ def _bagua_autotune_get_buckets(self):
)
return list(recommended_buckets)

def _bagua_init_algorithm(self):
self._bagua_cleanup_algorithm()
self._bagua_broadcast_parameters()
def rebuild_buckets(self):
self._bagua_tensors = self.bagua_algorithm.init_tensors(self)
self._bagua_tensor_map = dict(
[(tensor.bagua_tensor_name, tensor) for tensor in self._bagua_tensors]
)
self._bagua_autotune_register_tensors()
self._bagua_reset_algorithm_buckets()
self.params_in_use = set([name for name, _ in self.bagua_build_params()])

def _bagua_init_algorithm(self):
self._bagua_cleanup_algorithm()
self._bagua_broadcast_parameters()
self.rebuild_buckets()

def _bagua_cleanup_algorithm(self):
bagua_states = self.module._bagua_states
@@ -399,6 +415,11 @@ def _bagua_cleanup_algorithm(self):

self.bagua_buckets.clear()

def delay_allreduce(self):
for param_name, parameter in self.bagua_build_params():
self.bagua_algorithm.init_backward_hook(self)(param_name, parameter)
self.bagua_algorithm.init_post_backward_hook(self)()

def _bagua_reset_algorithm_buckets(self):
bagua_states = self.module._bagua_states
self._bagua_cleanup_algorithm()
@@ -414,6 +435,9 @@ def real_hook(*unused):
if not self.require_backward_grad_sync:
return

if self.find_unused_parameters:
self.autograd_graph_params[param_name] = parameter

self.bagua_algorithm.init_backward_hook(self)(param_name, parameter)

def real_post_backward_hook(*unused):
@@ -423,6 +447,11 @@ def real_post_backward_hook(*unused):
self._speed_metrics_end_event
)

if self.find_unused_parameters:
if set(self.autograd_graph_params.keys()) != self.params_in_use:
self.rebuild_buckets()
self.delay_allreduce()

if not self._is_post_backward_callback_queued:
torch.autograd.Variable._execution_engine.queue_callback(
real_post_backward_hook
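
A standalone illustration (plain PyTorch, not Bagua internals) of the bookkeeping this file adds: a per-parameter backward hook records each parameter that actually receives a gradient, and after backward the recorded set is compared with the full parameter set, which is when rebuild_buckets() and delay_allreduce() would run in BaguaDistributedDataParallel. Module and variable names here are illustrative.

import torch
from torch import nn

model = nn.ModuleDict({
    "used": nn.Linear(2, 2, bias=False),
    "unused": nn.Linear(2, 2, bias=False),  # never called in this example
})
params_in_use = {name for name, _ in model.named_parameters()}
autograd_graph_params = {}


def make_hook(name, param):
    def hook(grad):
        # Fires only for parameters that participate in the current backward
        # pass, playing the role of real_hook filling ddp.autograd_graph_params.
        autograd_graph_params[name] = param
        return grad
    return hook


for name, param in model.named_parameters():
    param.register_hook(make_hook(name, param))

model["used"](torch.rand(4, 2)).sum().backward()

if set(autograd_graph_params) != params_in_use:
    # Mirrors the check in real_post_backward_hook: when the sets differ,
    # rebuild_buckets() runs, followed by delay_allreduce().
    print("unused parameters:", sorted(params_in_use - set(autograd_graph_params)))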
10 changes: 6 additions & 4 deletions bagua/torch_api/data_parallel/distributed.py
@@ -101,7 +101,7 @@ def __init__(
bucket_cap_mb=25,
find_unused_parameters=False,
check_reduction=False,
gradient_as_bucket_view=False,
gradient_as_bucket_view=True,
# The following bagua parameters
optimizers: List[torch.optim.Optimizer] = [],
algorithm: "bagua.torch_api.algorithms.Algorithm" = GradientAllReduceAlgorithm(),
@@ -134,11 +134,13 @@ def __init__(
self.device = list(self.module.parameters())[0].device
assert broadcast_buffers is True, "Not yet supported"
self.broadcast_buffers = broadcast_buffers
assert find_unused_parameters is False, "Not yet supported"
self.find_unused_parameters = find_unused_parameters

self.inner = BaguaDistributedDataParallel(
self.module, optimizers, algorithm, to_bagua_process_group(process_group)
self.module, optimizers, algorithm,
process_group=to_bagua_process_group(process_group),
gradient_as_bucket_view=gradient_as_bucket_view,
find_unused_parameters=find_unused_parameters,
)

@property
@@ -221,7 +223,7 @@ def DistributedDataParallel(
bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
gradient_as_bucket_view: bool = True,
# The followings are parameters for Bagua
optimizers: List[torch.optim.Optimizer] = [],
algorithm: "bagua.torch_api.algorithms.Algorithm" = GradientAllReduceAlgorithm()
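
For completeness, a sketch of what the wrapper now forwards to the inner class, per the hunk at -134,11 +134,13 above. The argument shape follows that hunk; the import paths for GradientAllReduceAlgorithm and to_bagua_process_group are assumptions not shown in this diff, and net/torch_pg are placeholders for an existing module and torch process group.

import torch
from bagua.torch_api.algorithms.gradient_allreduce import GradientAllReduceAlgorithm  # assumed path
from bagua.torch_api.communication import to_bagua_process_group  # assumed path
from bagua.torch_api.data_parallel.bagua_distributed import BaguaDistributedDataParallel


def wrap_inner(net: torch.nn.Module, torch_pg) -> BaguaDistributedDataParallel:
    # Mirrors what DistributedDataParallel_V1_9_0.__init__ does after this commit:
    # both gradient_as_bucket_view and find_unused_parameters are passed through.
    return BaguaDistributedDataParallel(
        net,
        [],                            # optimizers (wrapper default)
        GradientAllReduceAlgorithm(),  # algorithm (wrapper default)
        process_group=to_bagua_process_group(torch_pg),
        gradient_as_bucket_view=True,   # wrapper default flipped to True in this commit
        find_unused_parameters=True,    # newly supported
    )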
133 changes: 133 additions & 0 deletions tests/torch_api/data_parallel/test_bagua_ddp.py
@@ -0,0 +1,133 @@
import os
import random
import sys
import unittest

import torch
import torch.distributed as c10d

if not c10d.is_available():
print("c10d not available, skipping tests", file=sys.stderr)
sys.exit(0)

import torch.nn.functional as F
from torch import nn
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
run_tests,
TEST_WITH_TSAN,
)
from . import test_c10d_common

import bagua.torch_api as bagua
from tests.internal.common_utils import find_free_port
from bagua.torch_api.data_parallel.distributed import DistributedDataParallel_V1_9_0 as DistributedDataParallel


@unittest.skipIf(
TEST_WITH_TSAN,
"TSAN is not fork-safe since we're forking in a multi-threaded environment",
)
class DistributedDataParallelTest(test_c10d_common.AbstractDistributedDataParallelTest, MultiProcessTestCase):

def setUp(self):
super(DistributedDataParallelTest, self).setUp()
# NCCL_BLOCKING_WAIT overrides NCCL_ASYNC_ERROR_HANDLING hence tests
# that use NCCL_BLOCKING_WAIT will test it as expected.
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
os.environ.update(
{
"MASTER_ADDR": "127.0.0.1",
"MASTER_PORT": str(find_free_port(8000, 8100)),
"BAGUA_SERVICE_PORT": str(find_free_port(9000, 9100)),
}
)
self._spawn_processes()

def _test_find_unused_parameters_kwarg(self, gradient_as_bucket_view=False):
"""
Note: this test can be sped up by only running it on a CPU module
once DistributedDataParallel supports them.
"""
torch.cuda.set_device(self.rank)
bagua.init_process_group()
process_group = c10d.distributed_c10d._get_default_group()

class FindUnusedParametersModule(nn.Module):
def __init__(self):
super(FindUnusedParametersModule, self).__init__()
self.fc1 = nn.Linear(2, 10, bias=False)
self.fc2 = nn.Linear(10, 4, bias=False)
self.fc3 = nn.Linear(4, 4, bias=False)
self.relu = nn.ReLU()

def forward(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
# Return the fc3 module so that the caller can invoke it
# outside of the forward function. While this is bad practice,
# we can use it to trigger a reducer error.
return (F.softmax(x, dim=1), self.fc3)

device_id = self.rank
batch_size = 4
criterion = nn.CrossEntropyLoss()
input = torch.rand([batch_size, 2], dtype=torch.float).to(device_id)
target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(
device_id
)

ddp_model = None

def test_find_unused_parameters(
find_unused_parameters
):
model = DistributedDataParallel(
FindUnusedParametersModule().float().to(device_id),
device_ids=[device_id],
process_group=process_group,
find_unused_parameters=find_unused_parameters,
)
nonlocal ddp_model
ddp_model = model

output, fc3 = model(input)
# output = fc3(output)
loss = criterion(output, target)
loss.backward()

test_find_unused_parameters(find_unused_parameters=True)
bagua_build_params = [name for name,
_ in ddp_model.inner.bagua_build_params()]
self.assertEqual(set(bagua_build_params),
set(['fc1.weight', 'fc2.weight']))

test_find_unused_parameters(find_unused_parameters=False)
bagua_build_params = [name for name,
_ in ddp_model.inner.bagua_build_params()]
self.assertEqual(set(bagua_build_params), set(
['fc1.weight', 'fc2.weight', 'fc3.weight']))

@skip_if_lt_x_gpu(2)
def test_find_unused_parameters_kwarg_debug_detail(self):
os.environ.update(
{
"WORLD_SIZE": str(self.world_size),
"LOCAL_WORLD_SIZE": str(self.world_size),
"RANK": str(self.rank),
"LOCAL_RANK": str(self.rank),
}
)

self._test_find_unused_parameters_kwarg()


if __name__ == "__main__":
assert (
not torch.cuda._initialized
), "test_distributed must not have initialized CUDA context on main process"

run_tests()
