Skip to content

Commit

Permalink
Revert "Add device agnostic API for accelerator hooks (pytorch#137480)"
Browse files Browse the repository at this point in the history
This reverts commit 858c91c.

Reverted pytorch#137480 on behalf of https://github.com/albanD due to break all builds on trunk ([comment](pytorch#137480 (comment)))
  • Loading branch information
pytorchmergebot committed Oct 13, 2024
1 parent 08576b2 commit 563e9f9
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 18 deletions.
12 changes: 4 additions & 8 deletions aten/src/ATen/detail/AcceleratorHooksInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@ struct TORCH_API AcceleratorHooksInterface {

virtual DeviceIndex getCurrentDevice() const {
TORCH_CHECK(false, "Backend doesn't support getCurrentDevice()");
return -1;
}

virtual DeviceIndex exchangeDevice(DeviceIndex device) const {
TORCH_CHECK(false, "Backend doesn't support exchangeDevice()");
return -1;
}

virtual DeviceIndex maybeExchangeDevice(DeviceIndex device) const {
TORCH_CHECK(false, "Backend doesn't support maybeExchangeDevice()");
return -1;
}

virtual bool isPinnedPtr(const void* data) const {
Expand All @@ -45,14 +48,7 @@ struct TORCH_API AcceleratorHooksInterface {

virtual Allocator* getPinnedMemoryAllocator() const {
TORCH_CHECK(false, "Backend doesn't support getPinnedMemoryAllocator()");
}

virtual std::string showConfig() const {
TORCH_CHECK(false, "Backend doesn't support showConfig()");
}

virtual void deviceSynchronize(DeviceIndex device) const {
TORCH_CHECK(false, "Backend doesn't support deviceSynchronize()");
return nullptr;
}

virtual Device getDeviceFromPtr(void* data) const {
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/detail/CUDAHooksInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
TORCH_CHECK(false, "Cannot query CUDART version without ATen_cuda library. ", CUDA_HELP);
}

std::string showConfig() const override {
virtual std::string showConfig() const {
TORCH_CHECK(false, "Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP);
}

Expand Down Expand Up @@ -192,7 +192,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
}
#endif

void deviceSynchronize(DeviceIndex /*device_index*/) const override {
virtual void deviceSynchronize(DeviceIndex /*device_index*/) const {
TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP);
}
};
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/detail/MAIAHooksInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct TORCH_API MAIAHooksInterface {
// squelch -Werror=non-virtual-dtor
virtual ~MAIAHooksInterface() = default;

std::string showConfig() const override {
virtual std::string showConfig() const {
TORCH_CHECK(false, "Cannot query detailed MAIA version information.");
}
};
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/detail/MPSHooksInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
virtual Allocator* getMPSDeviceAllocator() const {
FAIL_MPSHOOKS_FUNC(__func__);
}
void deviceSynchronize(C10_UNUSED DeviceIndex device_index = -1) const override {
virtual void deviceSynchronize() const {
FAIL_MPSHOOKS_FUNC(__func__);
}
virtual void commitStream() const {
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/detail/MTIAHooksInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
return 0;
}

void deviceSynchronize(c10::DeviceIndex device_index) const override {
virtual void deviceSynchronize(c10::DeviceIndex device_index) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}

std::string showConfig() const override {
virtual std::string showConfig() const {
FAIL_MTIAHOOKS_FUNC(__func__);
}

Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/detail/XPUHooksInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{
return false;
}

std::string showConfig() const override {
virtual std::string showConfig() const {
TORCH_CHECK(
false,
"Cannot query detailed XPU version without ATen_xpu library.");
Expand Down Expand Up @@ -54,7 +54,7 @@ struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{
TORCH_CHECK(false, "Cannot get device of pointer on XPU without ATen_xpu library.");
}

void deviceSynchronize(DeviceIndex /*device_index*/) const override {
virtual void deviceSynchronize(DeviceIndex /*device_index*/) const {
TORCH_CHECK(false, "Cannot synchronize XPU device without ATen_xpu library.");
}

Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/mps/MPSHooks.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct MPSHooks : public at::MPSHooksInterface {
const Generator& getDefaultMPSGenerator() const override;

// MPSStream interface
void deviceSynchronize(DeviceIndex device_index = -1) const override;
void deviceSynchronize() const override;
void commitStream() const override;
void* getCommandBuffer() const override;
void* getDispatchQueue() const override;
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/mps/MPSHooks.mm
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
return at::mps::detail::getDefaultMPSGenerator();
}

void MPSHooks::deviceSynchronize([[maybe_unused]] DeviceIndex device_index) const {
void MPSHooks::deviceSynchronize() const {
at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
}

Expand Down

0 comments on commit 563e9f9

Please sign in to comment.