From 4a8e49389c33934234dc89616fd17a58e760e2e7 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Sun, 13 Oct 2024 13:12:42 +0800 Subject: [PATCH] Make Context to be Device-agnostic Step by Step (1/N) (#136519) ---- - make init to be device-agnostic and move it to AcceleratorHooksInterface - refactoring context related to device initialization Pull Request resolved: https://github.com/pytorch/pytorch/pull/136519 Approved by: https://github.com/ezyang, https://github.com/EikanWang, https://github.com/guangyey --- aten/src/ATen/Context.h | 77 ++++++++----------- aten/src/ATen/cuda/EmptyTensor.cpp | 4 +- aten/src/ATen/cuda/PeerToPeerAccess.cpp | 2 +- aten/src/ATen/cuda/detail/CUDAHooks.cpp | 2 +- aten/src/ATen/cuda/detail/CUDAHooks.h | 2 +- .../ATen/detail/AcceleratorHooksInterface.h | 4 + aten/src/ATen/detail/CUDAHooksInterface.h | 2 +- aten/src/ATen/detail/HIPHooksInterface.h | 5 +- aten/src/ATen/detail/IPUHooksInterface.h | 15 +++- aten/src/ATen/detail/MAIAHooksInterface.h | 15 +++- aten/src/ATen/detail/MPSHooksInterface.h | 2 +- aten/src/ATen/detail/MTIAHooksInterface.h | 7 +- .../ATen/detail/PrivateUse1HooksInterface.h | 2 +- aten/src/ATen/detail/XPUHooksInterface.h | 6 +- aten/src/ATen/mps/MPSHooks.h | 2 +- aten/src/ATen/mps/MPSHooks.mm | 2 +- aten/src/ATen/native/cuda/Resize.cpp | 2 +- aten/src/ATen/test/cuda_cub_test.cu | 8 +- aten/src/ATen/xpu/detail/XPUHooks.cpp | 2 +- aten/src/ATen/xpu/detail/XPUHooks.h | 2 +- test/cpp/c10d/ProcessGroupGlooAsyncTest.cpp | 2 +- test/cpp/c10d/ProcessGroupNCCLTest.cpp | 2 +- test/cpp_extensions/mtia_extension.cpp | 2 +- torch/csrc/autograd/VariableTypeManual.cpp | 4 +- torch/csrc/cuda/Module.cpp | 4 +- torch/csrc/cuda/memory_snapshot.cpp | 4 +- .../distributed/c10d/ProcessGroupNCCL.cpp | 5 +- .../jit/mobile/register_ops_common_utils.h | 2 +- torch/csrc/mtia/Module.cpp | 2 +- torch/csrc/xpu/Module.cpp | 2 +- torchgen/dest/register_dispatch_key.py | 4 +- 31 files changed, 108 insertions(+), 88 deletions(-) diff --git 
a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 8b1bd5689ad06..a23da72c91874 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -39,8 +39,8 @@ class TORCH_API Context { const Generator& defaultGenerator(Device device) { c10::DeviceType device_type = device.type(); - initCUDAIfNeeded(device_type); - initHIPIfNeeded(device_type); + lazyInitDevice(device_type); + if (device_type == at::kCPU) { return at::detail::getDefaultCPUGenerator(); } else if (device_type == at::kCUDA) { @@ -58,6 +58,7 @@ class TORCH_API Context { AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled."); } } + const AcceleratorHooksInterface& getAcceleratorHooksInterface( std::optional opt_device_type = std::nullopt) { c10::DeviceType device_type = opt_device_type.has_value() @@ -80,16 +81,17 @@ class TORCH_API Context { c10::DeviceTypeName(device_type), " device type not an accelerator."); } } + Device getDeviceFromPtr(void* data, c10::DeviceType device_type) { - initCUDAIfNeeded(device_type); - initHIPIfNeeded(device_type); - initXPUIfNeeded(device_type); + lazyInitDevice(device_type); + if (device_type == at::kCPU) { return c10::DeviceType::CPU; } else { return getAcceleratorHooksInterface(device_type).getDeviceFromPtr(data); } } + bool isPinnedPtr( const void* data, std::optional device_type = std::nullopt) { @@ -102,10 +104,20 @@ class TORCH_API Context { } return getAcceleratorHooksInterface(opt_device_type).isPinnedPtr(data); } + Allocator* getPinnedMemoryAllocator( std::optional device_type = std::nullopt) { return getAcceleratorHooksInterface(device_type).getPinnedMemoryAllocator(); } + + void lazyInitDevice(c10::DeviceType device_type) { + if (device_type != at::kCPU) { + c10::call_once(init_[static_cast(device_type)], [&] { + getAcceleratorHooksInterface(device_type).init(); + }); + } + } + static bool hasOpenMP(); static bool hasMKL(); static bool hasLAPACK(); @@ -158,27 +170,6 @@ class TORCH_API Context { static bool hasMAIA() { return 
c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA); } - // defined in header so that getNonVariableType has ability to inline - // call_once check. getNonVariableType is called fairly frequently - void lazyInitCUDA() { - c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); }); - } - void lazyInitHIP() { - c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); }); - } - void lazyInitXPU() { - c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); }); - } - void lazyInitMTIA() { - c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); }); - } - void lazyInitPrivateUse1() { - c10::call_once(thp_init, [&] { - if (isPrivateUse1HooksRegistered()) { - at::detail::getPrivateUse1Hooks().initPrivateUse1(); - } - }); - } static const at::cuda::NVRTC& getNVRTC() { return detail::getCUDAHooks().nvrtc(); } @@ -353,28 +344,26 @@ class TORCH_API Context { bool allowFP16ReductionCPU() const; void setAllowFP16ReductionCPU(bool); - private: - void initCUDAIfNeeded(c10::DeviceType p) { - if (p == c10::DeviceType::CUDA) { - lazyInitCUDA(); - } + // Preserved for BC + void lazyInitCUDA() { + lazyInitDevice(at::kCUDA); } - void initHIPIfNeeded(c10::DeviceType p) { - if (p == c10::DeviceType::HIP) { - lazyInitHIP(); - } + void lazyInitHIP() { + lazyInitDevice(at::kHIP); } - void initXPUIfNeeded(c10::DeviceType p) { - if (p == c10::DeviceType::XPU) { - lazyInitXPU(); - } + void lazyInitXPU() { + lazyInitDevice(at::kXPU); + } + void lazyInitMTIA() { + lazyInitDevice(at::kMTIA); } + void lazyInitPrivateUse1() { + lazyInitDevice(at::kPrivateUse1); + } + + private: static bool checkCuBLASConfigDeterministic(); - c10::once_flag thc_init; - c10::once_flag thh_init; - c10::once_flag thx_init; - c10::once_flag th_mtia_init; - c10::once_flag thp_init; + std::array init_; bool enabled_cudnn = true; bool deterministic_cudnn = false; bool deterministic_mkldnn = false; diff --git a/aten/src/ATen/cuda/EmptyTensor.cpp b/aten/src/ATen/cuda/EmptyTensor.cpp 
index ad4f854a05ccc..108b7be47de17 100644 --- a/aten/src/ATen/cuda/EmptyTensor.cpp +++ b/aten/src/ATen/cuda/EmptyTensor.cpp @@ -10,7 +10,7 @@ TensorBase empty_cuda( ScalarType dtype, std::optional device_opt, std::optional memory_format_opt) { - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); const auto device = device_or_default(device_opt); TORCH_INTERNAL_ASSERT(device.is_cuda()); const DeviceGuard device_guard(device); @@ -50,7 +50,7 @@ TensorBase empty_strided_cuda( IntArrayRef stride, ScalarType dtype, std::optional device_opt) { - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); const auto device = device_or_default(device_opt); TORCH_INTERNAL_ASSERT(device.is_cuda()); const DeviceGuard device_guard(device); diff --git a/aten/src/ATen/cuda/PeerToPeerAccess.cpp b/aten/src/ATen/cuda/PeerToPeerAccess.cpp index e9ce2d9d3a604..e56d2f3ee229d 100644 --- a/aten/src/ATen/cuda/PeerToPeerAccess.cpp +++ b/aten/src/ATen/cuda/PeerToPeerAccess.cpp @@ -34,7 +34,7 @@ void init_p2p_access_cache(int64_t num_devices) { } // namespace detail bool get_p2p_access(int dev, int dev_to_access) { - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); TORCH_CHECK(dev >= 0 || dev < num_devices_, dev, " is not a device"); diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index 0bdd865d88d25..24c54bab294ed 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -84,7 +84,7 @@ struct _Initializer { // NB: deleter is dynamic, because we need it to live in a separate // compilation unit (alt is to have another method in hooks, but // let's not if we don't need to!) -void CUDAHooks::initCUDA() const { +void CUDAHooks::init() const { C10_LOG_API_USAGE_ONCE("aten.init.cuda"); // Force the update to enable unit testing. 
This code get executed before unit tests // have a chance to enable vitals. diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index 11401701e44c0..9187a80317077 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -19,7 +19,7 @@ TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)()); // The real implementation of CUDAHooksInterface struct CUDAHooks : public at::CUDAHooksInterface { CUDAHooks(at::CUDAHooksArgs) {} - void initCUDA() const override; + void init() const override; Device getDeviceFromPtr(void* data) const override; bool isPinnedPtr(const void* data) const override; const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override; diff --git a/aten/src/ATen/detail/AcceleratorHooksInterface.h b/aten/src/ATen/detail/AcceleratorHooksInterface.h index 0f97e03090405..4eab4d24f71b3 100644 --- a/aten/src/ATen/detail/AcceleratorHooksInterface.h +++ b/aten/src/ATen/detail/AcceleratorHooksInterface.h @@ -19,6 +19,10 @@ struct TORCH_API AcceleratorHooksInterface { // Whether the device at device_index is fully initialized or not. virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0; + virtual void init() const { + TORCH_CHECK(false, "Backend doesn`t support init()"); + } + virtual DeviceIndex deviceCount() const { return 0; } diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index fe29a2d702b70..fdba8b830af4b 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -65,7 +65,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { ~CUDAHooksInterface() override = default; // Initialize THCState and, transitively, the CUDA state - virtual void initCUDA() const { + void init() const override { TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. 
", CUDA_HELP); } diff --git a/aten/src/ATen/detail/HIPHooksInterface.h b/aten/src/ATen/detail/HIPHooksInterface.h index b3194668d9512..fe46bfc5b854f 100644 --- a/aten/src/ATen/detail/HIPHooksInterface.h +++ b/aten/src/ATen/detail/HIPHooksInterface.h @@ -26,9 +26,8 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface { // squelch -Werror=non-virtual-dtor ~HIPHooksInterface() override = default; - // Initialize the HIP library state - virtual void initHIP() const { - AT_ERROR("Cannot initialize HIP without ATen_hip library."); + void init() const override { + TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library."); } virtual std::unique_ptr initHIPGenerator(Context*) const { diff --git a/aten/src/ATen/detail/IPUHooksInterface.h b/aten/src/ATen/detail/IPUHooksInterface.h index 8f24df4fdd2de..20dbb703d571f 100644 --- a/aten/src/ATen/detail/IPUHooksInterface.h +++ b/aten/src/ATen/detail/IPUHooksInterface.h @@ -1,14 +1,25 @@ #pragma once #include +#include + #include #include #include namespace at { -struct TORCH_API IPUHooksInterface { - virtual ~IPUHooksInterface() = default; +struct TORCH_API IPUHooksInterface: AcceleratorHooksInterface { + ~IPUHooksInterface() override = default; + + void init() const override { + TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library."); + } + + bool hasPrimaryContext(DeviceIndex device_index) const override { + TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library."); + return false; + } virtual const Generator& getDefaultIPUGenerator( DeviceIndex device_index [[maybe_unused]] = -1) const { diff --git a/aten/src/ATen/detail/MAIAHooksInterface.h b/aten/src/ATen/detail/MAIAHooksInterface.h index ad4ef146eccd9..554cc93043fd3 100644 --- a/aten/src/ATen/detail/MAIAHooksInterface.h +++ b/aten/src/ATen/detail/MAIAHooksInterface.h @@ -3,13 +3,24 @@ #include #include +#include + // NB: Class must live in `at` due to limitations of Registry.h. 
namespace at { -struct TORCH_API MAIAHooksInterface { +struct TORCH_API MAIAHooksInterface : AcceleratorHooksInterface { // This should never actually be implemented, but it is used to // squelch -Werror=non-virtual-dtor - virtual ~MAIAHooksInterface() = default; + ~MAIAHooksInterface() override = default; + + void init() const override { + TORCH_CHECK(false, "Cannot initialize MAIA without ATen_maia library."); + } + + bool hasPrimaryContext(DeviceIndex device_index) const override { + TORCH_CHECK(false, "Cannot initialize MAIA without ATen_maia library."); + return false; + } virtual std::string showConfig() const { TORCH_CHECK(false, "Cannot query detailed MAIA version information."); diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h index 180ff68588edd..e3f8d3132bb8c 100644 --- a/aten/src/ATen/detail/MPSHooksInterface.h +++ b/aten/src/ATen/detail/MPSHooksInterface.h @@ -22,7 +22,7 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface { ~MPSHooksInterface() override = default; // Initialize the MPS library state - virtual void initMPS() const { + void init() const override { FAIL_MPSHOOKS_FUNC(__func__); } virtual bool hasMPS() const { diff --git a/aten/src/ATen/detail/MTIAHooksInterface.h b/aten/src/ATen/detail/MTIAHooksInterface.h index 1480436fb4f1d..ca9ad432a3e20 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.h +++ b/aten/src/ATen/detail/MTIAHooksInterface.h @@ -31,7 +31,7 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { ~MTIAHooksInterface() override = default; - virtual void initMTIA() const { + void init() const override { // Avoid logging here, since MTIA needs init devices first then it will know // how many devices are available. Make it as no-op if mtia extension is not // dynamically loaded. 
@@ -109,6 +109,11 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { FAIL_MTIAHOOKS_FUNC(__func__); return nullptr; } + + // Preserved for BC + virtual void initMTIA() const { + return; + } }; struct TORCH_API MTIAHooksArgs {}; diff --git a/aten/src/ATen/detail/PrivateUse1HooksInterface.h b/aten/src/ATen/detail/PrivateUse1HooksInterface.h index 3a567ae32f8e3..3820c960dfe57 100644 --- a/aten/src/ATen/detail/PrivateUse1HooksInterface.h +++ b/aten/src/ATen/detail/PrivateUse1HooksInterface.h @@ -40,7 +40,7 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface { "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`."); } - virtual void initPrivateUse1() const {} + void init() const override {} virtual void resizePrivateUse1Bytes( const c10::Storage& storage, size_t newsize) const { diff --git a/aten/src/ATen/detail/XPUHooksInterface.h b/aten/src/ATen/detail/XPUHooksInterface.h index 9d349102d38bd..2c6fa723e2dfa 100644 --- a/aten/src/ATen/detail/XPUHooksInterface.h +++ b/aten/src/ATen/detail/XPUHooksInterface.h @@ -14,10 +14,8 @@ namespace at { struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{ ~XPUHooksInterface() override = default; - virtual void initXPU() const { - TORCH_CHECK( - false, - "Cannot initialize XPU without ATen_xpu library."); + void init() const override { + TORCH_CHECK(false, "Cannot initialize XPU without ATen_xpu library."); } virtual bool hasXPU() const { diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h index 4858c0609f56b..20662be436910 100644 --- a/aten/src/ATen/mps/MPSHooks.h +++ b/aten/src/ATen/mps/MPSHooks.h @@ -12,7 +12,7 @@ namespace at::mps { // The real implementation of MPSHooksInterface struct MPSHooks : public at::MPSHooksInterface { MPSHooks(at::MPSHooksArgs) {} - void initMPS() const override; + void init() const override; // MPSDevice interface bool hasMPS() const override; diff --git a/aten/src/ATen/mps/MPSHooks.mm 
b/aten/src/ATen/mps/MPSHooks.mm index 5855e16aca8c9..983bb516a31b8 100644 --- a/aten/src/ATen/mps/MPSHooks.mm +++ b/aten/src/ATen/mps/MPSHooks.mm @@ -10,7 +10,7 @@ namespace at::mps { -void MPSHooks::initMPS() const { +void MPSHooks::init() const { C10_LOG_API_USAGE_ONCE("aten.init.mps"); // TODO: initialize MPS devices and streams here } diff --git a/aten/src/ATen/native/cuda/Resize.cpp b/aten/src/ATen/native/cuda/Resize.cpp index c11dd8dcc960e..e6f050603c641 100644 --- a/aten/src/ATen/native/cuda/Resize.cpp +++ b/aten/src/ATen/native/cuda/Resize.cpp @@ -30,7 +30,7 @@ void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes) { c10::cuda::CUDAGuard guard(device.index()); at::DataPtr data = allocator->allocate(size_bytes); if (storage->data_ptr()) { - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); C10_CUDA_CHECK( cudaMemcpyAsync( diff --git a/aten/src/ATen/test/cuda_cub_test.cu b/aten/src/ATen/test/cuda_cub_test.cu index 9041ef70cedb6..5e5e25d2a8c90 100644 --- a/aten/src/ATen/test/cuda_cub_test.cu +++ b/aten/src/ATen/test/cuda_cub_test.cu @@ -138,7 +138,9 @@ __managed__ int input[] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; TEST(InclusiveScanSplit, CubTest) { if (!at::cuda::is_available()) return; - at::globalContext().lazyInitCUDA(); // This is required to use PyTorch's caching allocator. + at::globalContext().lazyInitDevice( + c10::DeviceType::CUDA); // This is required to use PyTorch's caching + // allocator. int *output1; cudaMallocManaged(&output1, sizeof(int) * 10); @@ -162,7 +164,9 @@ TEST(InclusiveScanSplit, CubTest) { TEST(ExclusiveScanSplit, CubTest) { if (!at::cuda::is_available()) return; - at::globalContext().lazyInitCUDA(); // This is required to use PyTorch's caching allocator. + at::globalContext().lazyInitDevice( + c10::DeviceType::CUDA); // This is required to use PyTorch's caching + // allocator. 
int *output2; cudaMallocManaged(&output2, sizeof(int) * 10); diff --git a/aten/src/ATen/xpu/detail/XPUHooks.cpp b/aten/src/ATen/xpu/detail/XPUHooks.cpp index d9d0f06c0d804..05d4482fe979b 100644 --- a/aten/src/ATen/xpu/detail/XPUHooks.cpp +++ b/aten/src/ATen/xpu/detail/XPUHooks.cpp @@ -9,7 +9,7 @@ namespace at::xpu::detail { -void XPUHooks::initXPU() const { +void XPUHooks::init() const { C10_LOG_API_USAGE_ONCE("aten.init.xpu"); const auto device_count = c10::xpu::device_count_ensure_non_zero(); c10::xpu::XPUCachingAllocator::init(device_count); diff --git a/aten/src/ATen/xpu/detail/XPUHooks.h b/aten/src/ATen/xpu/detail/XPUHooks.h index 2f2b2b70e7a93..6c1c064bae80e 100644 --- a/aten/src/ATen/xpu/detail/XPUHooks.h +++ b/aten/src/ATen/xpu/detail/XPUHooks.h @@ -7,7 +7,7 @@ namespace at::xpu::detail { // The real implementation of XPUHooksInterface struct XPUHooks : public at::XPUHooksInterface { XPUHooks(at::XPUHooksArgs) {} - void initXPU() const override; + void init() const override; bool hasXPU() const override; std::string showConfig() const override; int32_t getGlobalIdxFromDevice(const at::Device& device) const override; diff --git a/test/cpp/c10d/ProcessGroupGlooAsyncTest.cpp b/test/cpp/c10d/ProcessGroupGlooAsyncTest.cpp index a5c775b0086ef..629a196280391 100644 --- a/test/cpp/c10d/ProcessGroupGlooAsyncTest.cpp +++ b/test/cpp/c10d/ProcessGroupGlooAsyncTest.cpp @@ -65,7 +65,7 @@ class AsyncInputIsOutputTest : public AsyncTest { numTensors_(numTensors), numDevices_(cudaNumDevices()) { // Allocate inputs on available devices in a round robin fashion. 
- ::at::globalContext().lazyInitCUDA(); + ::at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); inputs_.resize(numTensors_); for (const auto i : c10::irange(numTensors_)) { inputs_[i] = at::empty( diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp index ae6bf94cdd166..fa586e74825f7 100644 --- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp @@ -75,7 +75,7 @@ class NCCLTest : public NCCLTestBase { int inputDim = 3) : NCCLTestBase(path, pgTimeout), rank_(rank), worldSize_(worldSize) { // Each device has a single tensor to perf the NCCL op - ::at::globalContext().lazyInitCUDA(); + ::at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); tensors_.resize(numDevices_); inputs_.resize(numDevices_); outputs_.resize(numDevices_); diff --git a/test/cpp_extensions/mtia_extension.cpp b/test/cpp_extensions/mtia_extension.cpp index fdbfcaa26a27e..257ecf9cc91f8 100644 --- a/test/cpp_extensions/mtia_extension.cpp +++ b/test/cpp_extensions/mtia_extension.cpp @@ -139,7 +139,7 @@ struct MTIAGuardImpl final : public c10::impl::DeviceGuardImplInterface { struct MTIAHooks : public at::MTIAHooksInterface { explicit MTIAHooks(at::MTIAHooksArgs) {} - void initMTIA() const override {} + void init() const override {} bool hasMTIA() const override { return true; diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index cbda6552fe7a6..bc7165c2236f7 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -43,7 +43,7 @@ std::vector allCPUTypes() { } std::vector allCUDATypes() { - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); return allTypesForBackends({Backend::CUDA, Backend::SparseCUDA}); } @@ -52,7 +52,7 @@ std::vector allXPUTypes() { } std::vector allPrivateUser1Types() { - at::globalContext().lazyInitPrivateUse1(); + 
at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1); return allTypesForBackends( {Backend::PrivateUse1, Backend::SparsePrivateUse1}); } diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index e456d73bf8c87..ae1f20fac118d 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -890,7 +890,7 @@ PyObject* THCPModule_attachOutOfMemoryObserver( } Py_XDECREF(result); }; - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); c10::cuda::CUDACachingAllocator::attachOutOfMemoryObserver(std::move(obs)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS @@ -1425,7 +1425,7 @@ static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level poison_fork(); - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda")); if (!m) diff --git a/torch/csrc/cuda/memory_snapshot.cpp b/torch/csrc/cuda/memory_snapshot.cpp index 76ff111936edf..05da63b5bbbc9 100644 --- a/torch/csrc/cuda/memory_snapshot.cpp +++ b/torch/csrc/cuda/memory_snapshot.cpp @@ -138,7 +138,7 @@ void _record_memory_history( } else if (record_context) { when = c10::cuda::CUDACachingAllocator::RecordContext::STATE; } - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); _initRecordAnnotations(); c10::cuda::CUDACachingAllocator::recordHistory( enabled, recorder, trace_alloc_max_entries, when); @@ -189,7 +189,7 @@ void _record_memory_history( when = c10::cuda::CUDACachingAllocator::RecordContext::STATE; } } - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); _initRecordAnnotations(); c10::cuda::CUDACachingAllocator::recordHistory( enabled.has_value(), recorder, max_entries, when); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp 
b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 2d30ff2104b29..7e01330c02e10 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1032,9 +1032,10 @@ ProcessGroupNCCL::ProcessGroupNCCL( // SEGMENT_FREE action occurs. // We attach hooks only once at the first PG creation. // Attaching hooks fails if CUDACachingAllocator is not initialized, so - // lazyInitCUDA is called (and is a no-op if CUDA is already initialized). + // Init for CUDA is called (and is a no-op if CUDA is already + // initialized). if (useTensorRegisterAllocatorHook_ && !allocatorHooksAttached) { - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( &cacheAllocatorRegisterHook); c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( diff --git a/torch/csrc/jit/mobile/register_ops_common_utils.h b/torch/csrc/jit/mobile/register_ops_common_utils.h index 344b4dd25b858..4406cd5350f61 100644 --- a/torch/csrc/jit/mobile/register_ops_common_utils.h +++ b/torch/csrc/jit/mobile/register_ops_common_utils.h @@ -21,7 +21,7 @@ static C10_UNUSED at::Tensor to_dispatch( bool non_blocking, bool copy) { if (device && device->is_cuda()) { - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); } if (!device && !scalarType && !copy) { return self; diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp index 506ad0a0ee466..37624b3737d67 100644 --- a/torch/csrc/mtia/Module.cpp +++ b/torch/csrc/mtia/Module.cpp @@ -39,7 +39,7 @@ void initModule(PyObject* module) { m.def("_mtia_init", []() { TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level poison_fork(); - at::globalContext().lazyInitMTIA(); + at::globalContext().lazyInitDevice(c10::DeviceType::MTIA); }); m.def("_mtia_isBuilt", []() { diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index 
6e6c9a4564b65..f07101231ae05 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -363,7 +363,7 @@ static PyObject* THXPModule_initExtension(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level poison_fork(); - at::globalContext().lazyInitXPU(); + at::globalContext().lazyInitDevice(c10::DeviceType::XPU); auto m = THPObjectPtr(PyImport_ImportModule("torch.xpu")); if (!m) diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py index 091bec237238e..cb7dc00a60b85 100644 --- a/torchgen/dest/register_dispatch_key.py +++ b/torchgen/dest/register_dispatch_key.py @@ -515,9 +515,7 @@ def generate_defn(cpp_sig: CppSignature) -> str: # CUDA requires special handling if is_cuda_dispatch_key(self.backend_index.dispatch_key): - device_guard = ( - f"globalContext().lazyInitCUDA();\n{device_guard}" - ) + device_guard = f"globalContext().lazyInitDevice(c10::DeviceType::CUDA);\n{device_guard}" else: # kernel is operating on existing tensors