[Distributed] Fix extra context on device 0 (pytorch#135273)
This PR contains multiple fixes for issue pytorch#135279:

## First part:
Moves the GPU guard (`cudaSetDevice`) before the `currentStreamCaptureStatusMayInitCtx` call.
As its name suggests, that call May Init Ctx -- and without the guard in place first, the context it initializes lands on device 0.

## Second part:
Even with the above fix, additional contexts are still observed during Work object destruction, e.g.
```
work = dist.all_reduce(tensor, async_op=True)
time.sleep(5)  <-- no additional context yet
del work  <-- additional context shows up
```
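
For reference, a stand-alone repro along the same lines (illustrative only, not part of this PR; assumes a host with at least 2 GPUs, a `torchrun --nproc_per_node=2` launch, and a PyTorch build whose `init_process_group` accepts `device_id=`):
```
# repro_extra_ctx.py -- illustrative sketch, not part of this PR.
# Launch with: torchrun --nproc_per_node=2 repro_extra_ctx.py
import os
import time

import torch
import torch.distributed as dist


def main():
    rank = int(os.environ["RANK"])
    device = torch.device(f"cuda:{rank}")
    dist.init_process_group("nccl", device_id=device)

    x = torch.ones(1, device=device)
    work = dist.all_reduce(x, async_op=True)
    work.wait()
    time.sleep(5)  # no additional context yet
    del work       # before this PR, ranks != 0 created an extra context on device 0 here

    dist.barrier()
    if rank == 0:
        free, total = torch.cuda.mem_get_info(0)
        print(f"device 0 memory in use: {(total - free) / 1e9:.2f} GB")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```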
### Debug process
Chased it down to the destruction of a `Future` object -- a member variable of `Work`.
Then further down to the following member of `Future`:
```
std::vector<c10::Event> events_;
```
When the `events_` are destroyed, we end up here:
https://github.com/pytorch/pytorch/blob/1f3a79379012b408e0375e81fe9205dcba5e34ba/c10/cuda/impl/CUDAGuardImpl.h#L106-L121

When there is no "preset" CUDA context (**which is the case for the Python garbage collector**), line 112, `c10::cuda::GetDevice(&orig_device)`, sets `orig_device` to 0. Then, at line 120, `c10::cuda::SetDevice(orig_device)` "officially" creates a context on device 0 --
**that's where ranks 1, 2, ... can create an extra context on device 0!**
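
One way to confirm which processes hold a context on device 0 while reproducing this is to query NVML directly (illustrative sketch; requires the `pynvml` package, which is also what the new test uses when it is available):
```
# check_gpu0_contexts.py -- illustrative sketch, not part of this PR.
# Lists the PIDs that currently hold a compute context on device 0.
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
for p in procs:
    print(f"pid={p.pid} usedGpuMemory={p.usedGpuMemory}")
# Before this fix, every rank's PID shows up here; afterwards, only rank 0's.
pynvml.nvmlShutdown()
```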
### Solution
This PR adds an explicit destructor to `Future`. In this destructor, each event is destroyed under a device guard for its own device.

## Test
Added `test_extra_cuda_context`, implemented via
- `pynvml` (if available), or
- a memory-consumption check.

`python test/distributed/test_c10d_nccl.py -k test_extra_cuda_context`

Pull Request resolved: pytorch#135273
Approved by: https://github.com/fduwjj, https://github.com/wconstab, https://github.com/eqy
ghstack dependencies: pytorch#137161
kwen2501 authored and pytorchmergebot committed Oct 10, 2024
1 parent 9690cac commit cdd8fa9
Showing 5 changed files with 149 additions and 18 deletions.
13 changes: 13 additions & 0 deletions aten/src/ATen/core/ivalue_inl.h
@@ -863,6 +863,19 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
Future& operator=(const Future&) = delete;
Future& operator=(Future&&) = delete;

// Destructor
// Explicitly destroy events under device guard, otherwise it can lead to
// extra context being created on device 0. Reason: python garbage collector
// calls this destructor, but python GC does not have a device context, so a
// "default" one (usually on device 0) could be created when we go down the
// line of event destroy.
~Future() override {
while (!events_.empty()) {
c10::OptionalDeviceGuard deviceGuard(events_.back().device());
events_.pop_back();
}
}

struct TORCH_API FutureError final : public std::exception {
explicit FutureError(std::string&& error_msg_)
: error_msg(std::move(error_msg_)) {}
102 changes: 100 additions & 2 deletions test/distributed/test_c10d_nccl.py
@@ -49,6 +49,7 @@
requires_nccl_version,
skip_if_lt_x_gpu,
skip_if_rocm_multiprocess,
sm_lower_than_70,
TEST_SKIPS,
with_dist_debug_levels,
with_nccl_blocking_wait,
@@ -431,9 +432,13 @@ def test_nan_rank_filter(self):
@skip_if_lt_x_gpu(2)
def test_nan_check(self):
# Not expecting an error, NaN check should not make legit code fail
device = torch.device("cuda:%d" % self.rank)
# Test needs sm_70, see #135273, #137161
if sm_lower_than_70(device):
return

os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
store = c10d.FileStore(self.file_name, self.world_size)
device = torch.device("cuda:%d" % self.rank)
c10d.init_process_group(
backend="nccl", store=store, rank=self.rank, world_size=self.world_size
)
@@ -446,6 +451,95 @@ def test_nan_check(self):
# reset env
os.environ["TORCH_NCCL_NAN_CHECK"] = "0"

def _helper_test_extra_cuda_context_by_nvml(self):
"""
A helper for `test_extra_cuda_context`, if pynvml is available.
pynvml provides python bindings for NVIDIA NVML functionalities.
Here we are interested in: nvmlDeviceGetComputeRunningProcesses
"""
import pynvml

pynvml.nvmlInit()

device = torch.device("cuda:%d" % self.rank)
x = torch.empty((1,), device=device)
work = c10d.all_reduce(x, async_op=True)

# Wait for non-0 ranks to garbage collect Work -- this is the latest
# point where extra CUDA context can be created
if self.rank == 0:
time.sleep(5)
del work
handle = pynvml.nvmlDeviceGetHandleByIndex(self.rank)
processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
nprocs = len(processes)

# A barrier for non-0 ranks
c10d.all_reduce(x)
torch.cuda.synchronize(device)
c10d.destroy_process_group()
self.assertEqual(
nprocs,
1,
f"Found {nprocs} processes creating contexts on {device}, expecting 1 only",
)

def _helper_test_extra_cuda_context_by_memory(self):
"""
A helper for `test_extra_cuda_context`, if pynvml is NOT available.
If an extra context is created, it will show up in device 0's memory usage.
"""
device = torch.device("cuda:%d" % self.rank)
x = torch.empty((1,), device=device)
# Rank 0 takes a snapshot before collective -- this snapshot should have
# included rank 0's own context.
if self.rank == 0:
free, total = torch.cuda.mem_get_info(device)
used_before = float(total - free)

work = c10d.all_reduce(x, async_op=True)

# Wait for non-0 ranks to garbage collect Work -- this is the latest
# point where extra CUDA context can be created
if self.rank == 0:
time.sleep(5)
free, total = torch.cuda.mem_get_info(device)
used_after = float(total - free)
del work

# A barrier for non-0 ranks
c10d.all_reduce(x)
torch.cuda.synchronize(device)
c10d.destroy_process_group()
if self.rank == 0:
# If non-0 rank creates a context on device 0, this assert would
# fail because one context takes about 1 GB -- much more than the
# tensor size created in this test.
self.assertTrue(
used_after < used_before * 1.5,
f"{device} used {used_after} bytes after collective, "
f"50% more than the status before ({used_before} bytes). "
f"Extra CUDA context may have been created.",
)

@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_extra_cuda_context(self):
# Check if non-0 ranks would create extra CUDA context on device 0
store = c10d.FileStore(self.file_name, self.world_size)
device = torch.device("cuda:%d" % self.rank)
c10d.init_process_group(
backend="nccl",
store=store,
rank=self.rank,
world_size=self.world_size,
device_id=device,
)
try:
self._helper_test_extra_cuda_context_by_nvml()
except ModuleNotFoundError:
self._helper_test_extra_cuda_context_by_memory()

@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_destruct_before_terminate_pg(self):
@@ -2811,12 +2905,16 @@ def test_all_reduce_coalesced_nccl_float8_errors(self):
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_all_reduce_coalesced_manager_nccl(self):
device = torch.device("cuda:%d" % self.rank)
# Test needs sm_70, see #135273, #137161
if sm_lower_than_70(device):
return

store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(
backend="nccl", store=store, rank=self.rank, world_size=self.world_size
)
process_group = c10d.distributed_c10d._get_default_group()
device = torch.device("cuda:%d" % self.rank)
tensors = [
torch.full((60 + i,), self.rank + 1 + i, device=device, dtype=torch.float)
for i in range(5)
36 changes: 20 additions & 16 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -2736,6 +2736,11 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
avoidRecordStreams |= avoidRecordStreams_;
nanCheck &= enableNanCheck_;

auto device = getDevice(inputs[0]);
// Guard must be created before `currentStreamCaptureStatusMayInitCtx`;
// otherwise, extra CUDA context could be created on device 0.
at::cuda::OptionalCUDAGuard gpuGuard(device);

c10::cuda::CaptureStatus capture_status =
c10::cuda::currentStreamCaptureStatusMayInitCtx();
errorIfCapturingNonCapturableNCCL(capture_status);
@@ -2746,7 +2751,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
}
op_id_++;

auto device = getDevice(inputs[0]);
const auto key = getKeyFromDevice(device);
auto ncclComm = getNCCLComm(key, device, opType);

@@ -2788,8 +2792,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
std::make_shared<std::vector<at::Tensor>>(inputs);
}

at::cuda::OptionalCUDAGuard gpuGuard(device);

if (nanCheck) {
for (const auto& input : inputs) {
checkForNan(input, ncclStream);
@@ -2914,6 +2916,19 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
bool avoidRecordStreams) {
// Environment setting by the user may add onto collective call's option
avoidRecordStreams |= avoidRecordStreams_;

// Currently, the API permits one scenario where inputs.size() and
// outputs.size() are > 0.
// 1. If the call was a _coalesced call, all inputs must be on the same
// device.
// The group of nccl calls applies the collective separately to each input,
// but the group as a whole should be efficient, and might even execute as
// a single fused kernel.
auto device = getDevice(inputs[0]);
// Guard must be created before `currentStreamCaptureStatusMayInitCtx`;
// otherwise, extra CUDA context could be created on device 0.
at::cuda::OptionalCUDAGuard gpuGuard(device);

c10::cuda::CaptureStatus capture_status =
c10::cuda::currentStreamCaptureStatusMayInitCtx();
errorIfCapturingNonCapturableNCCL(capture_status);
@@ -2928,14 +2943,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
// op_id_ once per indvidual operation within the group
op_id_++;

// Currently, the API permits one scenario where inputs.size() and
// outputs.size() are > 0.
// 1. If the call was a _coalesced call, all inputs must be on the same
// device.
// The group of nccl calls applies the collective separately to each input,
// but the group as a whole should be efficient, and might even execute as
// a single fused kernel.
auto device = getDevice(inputs[0]);
const auto key = getKeyFromDevice(device);
auto ncclComm = getNCCLComm(key, device, opType);

@@ -2978,8 +2985,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
std::make_shared<std::vector<at::Tensor>>(inputs);
}

at::cuda::OptionalCUDAGuard gpuGuard(device);

// Start event should only be recorded before the ncclGroupStart() (which
// happens inside AutoNcclGroup guard below)
if (work->timingEnabled_) {
@@ -3134,6 +3139,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
}

auto device = getDevice(tensor);
at::cuda::OptionalCUDAGuard gpuGuard(device);

std::string key;
int p2pRank = 0, p2pTargetRank = 0;
bool isSendRecvSelf = false;
@@ -3255,9 +3262,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
/*isP2P=*/true);
}

// is gpuGuard needed for the if block below, or can i swap them
at::cuda::OptionalCUDAGuard gpuGuard(device);

// Only check for NaN for send ops, for recv ops `tensor` can be a random
// placeholder
if (enableNanCheck_ && opType == OpType::SEND) {
5 changes: 5 additions & 0 deletions torch/testing/_internal/common_distributed.py
@@ -353,6 +353,11 @@ def skip_if_win32():
)


def sm_lower_than_70(device: torch.device):
"""Returns True if the device's compute capability is lower than 70"""
return torch.cuda.get_device_capability(device) < (7, 0) and not torch.version.hip


@retry_on_connect_failures
def create_tcp_store(
addr="localhost",
11 changes: 11 additions & 0 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -75,6 +75,7 @@
with_dist_debug_levels,
verify_ddp_error_logged,
DistTestCases,
sm_lower_than_70,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@@ -3506,6 +3507,11 @@ def test_all_gather_full_group(self):
)
@skip_if_no_gpu
def test_all_gather_v_cuda(self):
device = torch.device("cuda:%d" % self.rank)
# Test needs sm_70, see #135273, #137161
if sm_lower_than_70(device):
return

self._barrier()
group, group_id, rank = self._init_global_test()
rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
@@ -8926,6 +8932,11 @@ def test_monitored_barrier_wait_all_ranks(self):
@with_dist_debug_levels(levels=["INFO"])
@skip_if_lt_x_gpu(2)
def test_ddp_build_debug_param_to_name_mapping(self):
device = torch.device("cuda:%d" % self.rank)
# Test needs sm_70, see #135273, #137161
if sm_lower_than_70(device):
return

model = TwoLinLayerNet()
net = torch.nn.parallel.DistributedDataParallel(
model.cuda(self.rank),
