diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h
index 2d30d3ba5cafe..573a71856ae11 100644
--- a/aten/src/ATen/core/ivalue_inl.h
+++ b/aten/src/ATen/core/ivalue_inl.h
@@ -863,6 +863,19 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
   Future& operator=(const Future&) = delete;
   Future& operator=(Future&&) = delete;
 
+  // Destructor
+  // Explicitly destroy events under a device guard, otherwise destroying them
+  // can create an extra context on device 0. Reason: the Python garbage
+  // collector calls this destructor without a device context set, so a
+  // "default" context (usually on device 0) would be created when the events
+  // are destroyed.
+  ~Future() override {
+    while (!events_.empty()) {
+      c10::OptionalDeviceGuard deviceGuard(events_.back().device());
+      events_.pop_back();
+    }
+  }
+
   struct TORCH_API FutureError final : public std::exception {
     explicit FutureError(std::string&& error_msg_)
         : error_msg(std::move(error_msg_)) {}
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 92bb1671ca2b2..48675f20a1265 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -49,6 +49,7 @@
     requires_nccl_version,
     skip_if_lt_x_gpu,
     skip_if_rocm_multiprocess,
+    sm_lower_than_70,
     TEST_SKIPS,
     with_dist_debug_levels,
     with_nccl_blocking_wait,
@@ -431,9 +432,13 @@ def test_nan_rank_filter(self):
     @skip_if_lt_x_gpu(2)
     def test_nan_check(self):
         # Not expecting an error, NaN check should not make legit code fail
+        device = torch.device("cuda:%d" % self.rank)
+        # Test needs sm_70, see #135273, #137161
+        if sm_lower_than_70(device):
+            return
+
         os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
         store = c10d.FileStore(self.file_name, self.world_size)
-        device = torch.device("cuda:%d" % self.rank)
         c10d.init_process_group(
             backend="nccl", store=store, rank=self.rank, world_size=self.world_size
         )
@@ -446,6 +451,95 @@ def test_nan_check(self):
         # reset env
         os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
 
+    def _helper_test_extra_cuda_context_by_nvml(self):
+        """
+        A helper for `test_extra_cuda_context` when pynvml is available.
+        pynvml provides Python bindings for NVIDIA NVML functionality.
+        Here we are interested in `nvmlDeviceGetComputeRunningProcesses`.
+        """
+        import pynvml
+
+        pynvml.nvmlInit()
+
+        device = torch.device("cuda:%d" % self.rank)
+        x = torch.empty((1,), device=device)
+        work = c10d.all_reduce(x, async_op=True)
+
+        # Wait for non-0 ranks to garbage collect Work -- this is the latest
+        # point where an extra CUDA context can be created
+        if self.rank == 0:
+            time.sleep(5)
+        del work
+        handle = pynvml.nvmlDeviceGetHandleByIndex(self.rank)
+        processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+        nprocs = len(processes)
+
+        # A barrier for non-0 ranks
+        c10d.all_reduce(x)
+        torch.cuda.synchronize(device)
+        c10d.destroy_process_group()
+        self.assertEqual(
+            nprocs,
+            1,
+            f"Found {nprocs} processes with contexts on {device}, expecting only 1",
+        )
+
+    def _helper_test_extra_cuda_context_by_memory(self):
+        """
+        A helper for `test_extra_cuda_context` when pynvml is NOT available.
+        If an extra context is created, it will show up in device 0's memory usage.
+        """
+        device = torch.device("cuda:%d" % self.rank)
+        x = torch.empty((1,), device=device)
+        # Rank 0 takes a snapshot before collective -- this snapshot should have
+        # included rank 0's own context.
+        if self.rank == 0:
+            free, total = torch.cuda.mem_get_info(device)
+            used_before = float(total - free)
+
+        work = c10d.all_reduce(x, async_op=True)
+
+        # Wait for non-0 ranks to garbage collect Work -- this is the latest
+        # point where an extra CUDA context can be created
+        if self.rank == 0:
+            time.sleep(5)
+            free, total = torch.cuda.mem_get_info(device)
+            used_after = float(total - free)
+        del work
+
+        # A barrier for non-0 ranks
+        c10d.all_reduce(x)
+        torch.cuda.synchronize(device)
+        c10d.destroy_process_group()
+        if self.rank == 0:
+            # If a non-0 rank creates a context on device 0, this assert would
+            # fail, because one context takes about 1 GB -- much more than the
+            # tensor created in this test.
+            self.assertTrue(
+                used_after < used_before * 1.5,
+                f"{device} used {used_after} bytes after collective, "
+                f"more than 1.5x the {used_before} bytes used before. "
+                f"An extra CUDA context may have been created.",
+            )
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_extra_cuda_context(self):
+        # Check that non-0 ranks do not create an extra CUDA context on device 0
+        store = c10d.FileStore(self.file_name, self.world_size)
+        device = torch.device("cuda:%d" % self.rank)
+        c10d.init_process_group(
+            backend="nccl",
+            store=store,
+            rank=self.rank,
+            world_size=self.world_size,
+            device_id=device,
+        )
+        try:
+            self._helper_test_extra_cuda_context_by_nvml()
+        except ModuleNotFoundError:
+            self._helper_test_extra_cuda_context_by_memory()
+
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     def test_destruct_before_terminate_pg(self):
@@ -2811,12 +2905,16 @@ def test_all_reduce_coalesced_nccl_float8_errors(self):
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
     def test_all_reduce_coalesced_manager_nccl(self):
+        device = torch.device("cuda:%d" % self.rank)
+        # Test needs sm_70, see #135273, #137161
+        if sm_lower_than_70(device):
+            return
+
         store = c10d.FileStore(self.file_name, self.world_size)
         c10d.init_process_group(
             backend="nccl", store=store, rank=self.rank, world_size=self.world_size
         )
         process_group = c10d.distributed_c10d._get_default_group()
-        device = torch.device("cuda:%d" % self.rank)
         tensors = [
             torch.full((60 + i,), self.rank + 1 + i, device=device, dtype=torch.float)
             for i in range(5)
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 400975c7cd758..83bae62684fb7 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -2736,6 +2736,11 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   avoidRecordStreams |= avoidRecordStreams_;
   nanCheck &= enableNanCheck_;
 
+  auto device = getDevice(inputs[0]);
+  // Guard must be created before `currentStreamCaptureStatusMayInitCtx`;
+  // otherwise, an extra CUDA context could be created on device 0.
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
+
   c10::cuda::CaptureStatus capture_status =
       c10::cuda::currentStreamCaptureStatusMayInitCtx();
   errorIfCapturingNonCapturableNCCL(capture_status);
@@ -2746,7 +2751,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   }
   op_id_++;
 
-  auto device = getDevice(inputs[0]);
   const auto key = getKeyFromDevice(device);
   auto ncclComm = getNCCLComm(key, device, opType);
 
@@ -2788,8 +2792,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
         std::make_shared<std::vector<at::Tensor>>(inputs);
   }
 
-  at::cuda::OptionalCUDAGuard gpuGuard(device);
-
   if (nanCheck) {
     for (const auto& input : inputs) {
       checkForNan(input, ncclStream);
@@ -2914,6 +2916,19 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
     bool avoidRecordStreams) {
   // Environment setting by the user may add onto collective call's option
   avoidRecordStreams |= avoidRecordStreams_;
+
+  // Currently, the API permits one scenario where inputs.size() and
+  // outputs.size() are > 0.
+  // 1. If the call was a _coalesced call, all inputs must be on the same
+  // device.
+  // The group of nccl calls applies the collective separately to each input,
+  // but the group as a whole should be efficient, and might even execute as
+  // a single fused kernel.
+  auto device = getDevice(inputs[0]);
+  // Guard must be created before `currentStreamCaptureStatusMayInitCtx`;
+  // otherwise, an extra CUDA context could be created on device 0.
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
+
   c10::cuda::CaptureStatus capture_status =
       c10::cuda::currentStreamCaptureStatusMayInitCtx();
   errorIfCapturingNonCapturableNCCL(capture_status);
@@ -2928,14 +2943,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
   // op_id_ once per indvidual operation within the group
   op_id_++;
 
-  // Currently, the API permits one scenario where inputs.size() and
-  // outputs.size() are > 0.
-  // 1. If the call was a _coalesced call, all inputs must be on the same
-  // device.
-  // The group of nccl calls applies the collective separately to each input,
-  // but the group as a whole should be efficient, and might even execute as
-  // a single fused kernel.
-  auto device = getDevice(inputs[0]);
   const auto key = getKeyFromDevice(device);
   auto ncclComm = getNCCLComm(key, device, opType);
 
@@ -2978,8 +2985,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
         std::make_shared<std::vector<at::Tensor>>(inputs);
   }
 
-  at::cuda::OptionalCUDAGuard gpuGuard(device);
-
   // Start event should only be recorded before the ncclGroupStart() (which
   // happens inside AutoNcclGroup guard below)
   if (work->timingEnabled_) {
@@ -3134,6 +3139,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
   }
 
   auto device = getDevice(tensor);
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
+
   std::string key;
   int p2pRank = 0, p2pTargetRank = 0;
   bool isSendRecvSelf = false;
@@ -3255,9 +3262,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
         /*isP2P=*/true);
   }
 
-  // is gpuGuard needed for the if block below, or can i swap them
-  at::cuda::OptionalCUDAGuard gpuGuard(device);
-
   // Only check for NaN for send ops, for recv ops `tensor` can be a random
   // placeholder
   if (enableNanCheck_ && opType == OpType::SEND) {
diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index fb2a5c034b3e7..01d76fff93ff5 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -353,6 +353,11 @@ def skip_if_win32():
     )
 
 
+def sm_lower_than_70(device: torch.device):
+    """Returns True if the device's compute capability is lower than 7.0 (sm_70); always False on ROCm."""
+    return torch.cuda.get_device_capability(device) < (7, 0) and not torch.version.hip
+
+
 @retry_on_connect_failures
 def create_tcp_store(
     addr="localhost",
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index bbbec92153df5..e5744df320b4e 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -75,6 +75,7 @@
     with_dist_debug_levels,
     verify_ddp_error_logged,
     DistTestCases,
+    sm_lower_than_70,
 )
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -3506,6 +3507,11 @@ def test_all_gather_full_group(self):
         )
         @skip_if_no_gpu
         def test_all_gather_v_cuda(self):
+            device = torch.device("cuda:%d" % self.rank)
+            # Test needs sm_70, see #135273, #137161
+            if sm_lower_than_70(device):
+                return
+
             self._barrier()
             group, group_id, rank = self._init_global_test()
             rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
@@ -8926,6 +8932,11 @@ def test_monitored_barrier_wait_all_ranks(self):
         @with_dist_debug_levels(levels=["INFO"])
         @skip_if_lt_x_gpu(2)
         def test_ddp_build_debug_param_to_name_mapping(self):
+            device = torch.device("cuda:%d" % self.rank)
+            # Test needs sm_70, see #135273, #137161
+            if sm_lower_than_70(device):
+                return
+
             model = TwoLinLayerNet()
             net = torch.nn.parallel.DistributedDataParallel(
                 model.cuda(self.rank),
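The NVML-based helper added in `_helper_test_extra_cuda_context_by_nvml` above boils down to asking NVML which processes currently hold a compute (CUDA) context on a given GPU. A minimal standalone sketch outside the patch, assuming pynvml is installed; the device index 0 and the expected count of 1 are illustrative assumptions for a one-process-per-GPU run, not taken from the patch:

    import pynvml

    pynvml.nvmlInit()
    # Processes that currently hold a CUDA compute context on device 0.
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
    pynvml.nvmlShutdown()
    # With one rank pinned per GPU, a second entry here would indicate that
    # another rank created an extra context on device 0.
    assert len(procs) == 1, f"{len(procs)} processes hold a context on device 0"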