From 563e9f99c3de8a24fc740927dc12a0eec7895d8b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sun, 13 Oct 2024 12:12:37 +0000 Subject: [PATCH] Revert "Add device agnostic API for accelerator hooks (#137480)" This reverts commit 858c91c3d8d9a71c66d0357e51a4bd805f95599f. Reverted https://github.com/pytorch/pytorch/pull/137480 on behalf of https://github.com/albanD due to break all builds on trunk ([comment](https://github.com/pytorch/pytorch/pull/137480#issuecomment-2408954802)) --- aten/src/ATen/detail/AcceleratorHooksInterface.h | 12 ++++-------- aten/src/ATen/detail/CUDAHooksInterface.h | 4 ++-- aten/src/ATen/detail/MAIAHooksInterface.h | 2 +- aten/src/ATen/detail/MPSHooksInterface.h | 2 +- aten/src/ATen/detail/MTIAHooksInterface.h | 4 ++-- aten/src/ATen/detail/XPUHooksInterface.h | 4 ++-- aten/src/ATen/mps/MPSHooks.h | 2 +- aten/src/ATen/mps/MPSHooks.mm | 2 +- 8 files changed, 14 insertions(+), 18 deletions(-) diff --git a/aten/src/ATen/detail/AcceleratorHooksInterface.h b/aten/src/ATen/detail/AcceleratorHooksInterface.h index a8ab5dd814c64..0f97e03090405 100644 --- a/aten/src/ATen/detail/AcceleratorHooksInterface.h +++ b/aten/src/ATen/detail/AcceleratorHooksInterface.h @@ -29,14 +29,17 @@ struct TORCH_API AcceleratorHooksInterface { virtual DeviceIndex getCurrentDevice() const { TORCH_CHECK(false, "Backend doesn't support getCurrentDevice()"); + return -1; } virtual DeviceIndex exchangeDevice(DeviceIndex device) const { TORCH_CHECK(false, "Backend doesn't support exchangeDevice()"); + return -1; } virtual DeviceIndex maybeExchangeDevice(DeviceIndex device) const { TORCH_CHECK(false, "Backend doesn't support maybeExchangeDevice()"); + return -1; } virtual bool isPinnedPtr(const void* data) const { @@ -45,14 +48,7 @@ struct TORCH_API AcceleratorHooksInterface { virtual Allocator* getPinnedMemoryAllocator() const { TORCH_CHECK(false, "Backend doesn't support getPinnedMemoryAllocator()"); - } - - virtual std::string showConfig() const { - TORCH_CHECK(false, "Backend doesn't support showConfig()"); - } - - virtual void deviceSynchronize(DeviceIndex device) const { - TORCH_CHECK(false, "Backend doesn't support deviceSynchronize()"); + return nullptr; } virtual Device getDeviceFromPtr(void* data) const { diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index 58be625f715e2..fe29a2d702b70 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -157,7 +157,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { TORCH_CHECK(false, "Cannot query CUDART version without ATen_cuda library. ", CUDA_HELP); } - std::string showConfig() const override { + virtual std::string showConfig() const { TORCH_CHECK(false, "Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP); } @@ -192,7 +192,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { } #endif - void deviceSynchronize(DeviceIndex /*device_index*/) const override { + virtual void deviceSynchronize(DeviceIndex /*device_index*/) const { TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP); } }; diff --git a/aten/src/ATen/detail/MAIAHooksInterface.h b/aten/src/ATen/detail/MAIAHooksInterface.h index 38ca58914373e..ad4ef146eccd9 100644 --- a/aten/src/ATen/detail/MAIAHooksInterface.h +++ b/aten/src/ATen/detail/MAIAHooksInterface.h @@ -11,7 +11,7 @@ struct TORCH_API MAIAHooksInterface { // squelch -Werror=non-virtual-dtor virtual ~MAIAHooksInterface() = default; - std::string showConfig() const override { + virtual std::string showConfig() const { TORCH_CHECK(false, "Cannot query detailed MAIA version information."); } }; diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h index 480263a5bbd08..180ff68588edd 100644 --- a/aten/src/ATen/detail/MPSHooksInterface.h +++ b/aten/src/ATen/detail/MPSHooksInterface.h @@ -37,7 +37,7 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface { virtual Allocator* getMPSDeviceAllocator() const { FAIL_MPSHOOKS_FUNC(__func__); } - void deviceSynchronize(C10_UNUSED DeviceIndex device_index = -1) const override { + virtual void deviceSynchronize() const { FAIL_MPSHOOKS_FUNC(__func__); } virtual void commitStream() const { diff --git a/aten/src/ATen/detail/MTIAHooksInterface.h b/aten/src/ATen/detail/MTIAHooksInterface.h index 6ad53f7b63682..1480436fb4f1d 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.h +++ b/aten/src/ATen/detail/MTIAHooksInterface.h @@ -46,11 +46,11 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { return 0; } - void deviceSynchronize(c10::DeviceIndex device_index) const override { + virtual void deviceSynchronize(c10::DeviceIndex device_index) const { FAIL_MTIAHOOKS_FUNC(__func__); } - std::string showConfig() const override { + virtual std::string showConfig() const { FAIL_MTIAHOOKS_FUNC(__func__); } diff --git a/aten/src/ATen/detail/XPUHooksInterface.h b/aten/src/ATen/detail/XPUHooksInterface.h index 5bf7de75992ea..9d349102d38bd 100644 --- a/aten/src/ATen/detail/XPUHooksInterface.h +++ b/aten/src/ATen/detail/XPUHooksInterface.h @@ -24,7 +24,7 @@ struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{ return false; } - std::string showConfig() const override { + virtual std::string showConfig() const { TORCH_CHECK( false, "Cannot query detailed XPU version without ATen_xpu library."); @@ -54,7 +54,7 @@ struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{ TORCH_CHECK(false, "Cannot get device of pointer on XPU without ATen_xpu library."); } - void deviceSynchronize(DeviceIndex /*device_index*/) const override { + virtual void deviceSynchronize(DeviceIndex /*device_index*/) const { TORCH_CHECK(false, "Cannot synchronize XPU device without ATen_xpu library."); } diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h index d58c463d6f303..4858c0609f56b 100644 --- a/aten/src/ATen/mps/MPSHooks.h +++ b/aten/src/ATen/mps/MPSHooks.h @@ -22,7 +22,7 @@ struct MPSHooks : public at::MPSHooksInterface { const Generator& getDefaultMPSGenerator() const override; // MPSStream interface - void deviceSynchronize(DeviceIndex device_index = -1) const override; + void deviceSynchronize() const override; void commitStream() const override; void* getCommandBuffer() const override; void* getDispatchQueue() const override; diff --git a/aten/src/ATen/mps/MPSHooks.mm b/aten/src/ATen/mps/MPSHooks.mm index c3e9ed8970521..5855e16aca8c9 100644 --- a/aten/src/ATen/mps/MPSHooks.mm +++ b/aten/src/ATen/mps/MPSHooks.mm @@ -63,7 +63,7 @@ return at::mps::detail::getDefaultMPSGenerator(); } -void MPSHooks::deviceSynchronize([[maybe_unused]] DeviceIndex device_index) const { +void MPSHooks::deviceSynchronize() const { at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT); }