Make Context to be Device-agnostic Step by Step (1/N) (pytorch#136519)
----

- make init device-agnostic and move it to AcceleratorHooksInterface
- refactor the Context code related to device initialization

Pull Request resolved: pytorch#136519
Approved by: https://github.com/ezyang, https://github.com/EikanWang, https://github.com/guangyey
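In practice, this commit collapses the per-backend lazy-initialization entry points on at::Context into a single device-agnostic one. A minimal caller-side sketch (illustrative only; the old calls survive as BC shims, as the Context.h diff below shows):

#include <ATen/Context.h>         // at::globalContext()
#include <c10/core/DeviceType.h>

void ensure_backend_ready() {
  // Before: one entry point per backend.
  //   at::globalContext().lazyInitCUDA();
  //   at::globalContext().lazyInitXPU();
  // After: one device-agnostic entry point; kCPU is a no-op.
  at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
}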
FFFrog authored and pytorchmergebot committed Oct 13, 2024
1 parent 563e9f9 commit 4a8e493
Showing 31 changed files with 108 additions and 88 deletions.
77 changes: 33 additions & 44 deletions aten/src/ATen/Context.h
@@ -39,8 +39,8 @@ class TORCH_API Context {

const Generator& defaultGenerator(Device device) {
c10::DeviceType device_type = device.type();
initCUDAIfNeeded(device_type);
initHIPIfNeeded(device_type);
lazyInitDevice(device_type);

if (device_type == at::kCPU) {
return at::detail::getDefaultCPUGenerator();
} else if (device_type == at::kCUDA) {
@@ -58,6 +58,7 @@ class TORCH_API Context {
AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
}
}

const AcceleratorHooksInterface& getAcceleratorHooksInterface(
std::optional<c10::DeviceType> opt_device_type = std::nullopt) {
c10::DeviceType device_type = opt_device_type.has_value()
@@ -80,16 +81,17 @@
c10::DeviceTypeName(device_type), " device type not an accelerator.");
}
}

Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
initCUDAIfNeeded(device_type);
initHIPIfNeeded(device_type);
initXPUIfNeeded(device_type);
lazyInitDevice(device_type);

if (device_type == at::kCPU) {
return c10::DeviceType::CPU;
} else {
return getAcceleratorHooksInterface(device_type).getDeviceFromPtr(data);
}
}

bool isPinnedPtr(
const void* data,
std::optional<c10::DeviceType> device_type = std::nullopt) {
@@ -102,10 +104,20 @@ }
}
return getAcceleratorHooksInterface(opt_device_type).isPinnedPtr(data);
}

Allocator* getPinnedMemoryAllocator(
std::optional<c10::DeviceType> device_type = std::nullopt) {
return getAcceleratorHooksInterface(device_type).getPinnedMemoryAllocator();
}

void lazyInitDevice(c10::DeviceType device_type) {
if (device_type != at::kCPU) {
c10::call_once(init_[static_cast<int8_t>(device_type)], [&] {
getAcceleratorHooksInterface(device_type).init();
});
}
}

static bool hasOpenMP();
static bool hasMKL();
static bool hasLAPACK();
@@ -158,27 +170,6 @@ class TORCH_API Context {
static bool hasMAIA() {
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
}
// defined in header so that getNonVariableType has ability to inline
// call_once check. getNonVariableType is called fairly frequently
void lazyInitCUDA() {
c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); });
}
void lazyInitHIP() {
c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); });
}
void lazyInitXPU() {
c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); });
}
void lazyInitMTIA() {
c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); });
}
void lazyInitPrivateUse1() {
c10::call_once(thp_init, [&] {
if (isPrivateUse1HooksRegistered()) {
at::detail::getPrivateUse1Hooks().initPrivateUse1();
}
});
}
static const at::cuda::NVRTC& getNVRTC() {
return detail::getCUDAHooks().nvrtc();
}
@@ -353,28 +344,26 @@ class TORCH_API Context {
bool allowFP16ReductionCPU() const;
void setAllowFP16ReductionCPU(bool);

private:
void initCUDAIfNeeded(c10::DeviceType p) {
if (p == c10::DeviceType::CUDA) {
lazyInitCUDA();
}
// Preserved for BC
void lazyInitCUDA() {
lazyInitDevice(at::kCUDA);
}
void initHIPIfNeeded(c10::DeviceType p) {
if (p == c10::DeviceType::HIP) {
lazyInitHIP();
}
void lazyInitHIP() {
lazyInitDevice(at::kHIP);
}
void initXPUIfNeeded(c10::DeviceType p) {
if (p == c10::DeviceType::XPU) {
lazyInitXPU();
}
void lazyInitXPU() {
lazyInitDevice(at::kXPU);
}
void lazyInitMTIA() {
lazyInitDevice(at::kMTIA);
}
void lazyInitPrivateUse1() {
lazyInitDevice(at::kPrivateUse1);
}

private:
static bool checkCuBLASConfigDeterministic();
c10::once_flag thc_init;
c10::once_flag thh_init;
c10::once_flag thx_init;
c10::once_flag th_mtia_init;
c10::once_flag thp_init;
std::array<c10::once_flag, at::COMPILE_TIME_MAX_DEVICE_TYPES> init_;
bool enabled_cudnn = true;
bool deterministic_cudnn = false;
bool deterministic_mkldnn = false;
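The replacement of the named per-backend once-flags with a single std::array keyed by device type is the core of the Context.h change above. Below is a standalone sketch of that once-per-device-type guard, using std::call_once/std::once_flag as stand-ins for their c10 counterparts and a plain function in place of the hooks lookup; all names here are illustrative.

#include <array>
#include <cstdio>
#include <mutex>

// Stand-in for at::COMPILE_TIME_MAX_DEVICE_TYPES.
constexpr int kMaxDeviceTypes = 16;

std::array<std::once_flag, kMaxDeviceTypes> init_flags;

// Stand-in for getAcceleratorHooksInterface(device_type).init().
void init_backend(int device_type) {
  std::printf("initializing backend %d\n", device_type);
}

void lazy_init_device(int device_type) {
  // Thread-safe: the backend's init() runs at most once per device type.
  std::call_once(init_flags[device_type], [&] { init_backend(device_type); });
}

int main() {
  lazy_init_device(1);
  lazy_init_device(1);  // no-op: this device type is already initialized
  return 0;
}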
4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/EmptyTensor.cpp
@@ -10,7 +10,7 @@ TensorBase empty_cuda(
ScalarType dtype,
std::optional<Device> device_opt,
std::optional<c10::MemoryFormat> memory_format_opt) {
at::globalContext().lazyInitCUDA();
at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
const auto device = device_or_default(device_opt);
TORCH_INTERNAL_ASSERT(device.is_cuda());
const DeviceGuard device_guard(device);
@@ -50,7 +50,7 @@ TensorBase empty_strided_cuda(
IntArrayRef stride,
ScalarType dtype,
std::optional<Device> device_opt) {
at::globalContext().lazyInitCUDA();
at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
const auto device = device_or_default(device_opt);
TORCH_INTERNAL_ASSERT(device.is_cuda());
const DeviceGuard device_guard(device);
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/PeerToPeerAccess.cpp
@@ -34,7 +34,7 @@ void init_p2p_access_cache(int64_t num_devices) {
} // namespace detail

bool get_p2p_access(int dev, int dev_to_access) {
at::globalContext().lazyInitCUDA();
at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);

TORCH_CHECK(dev >= 0 || dev < num_devices_,
dev, " is not a device");
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -84,7 +84,7 @@ struct _Initializer {
// NB: deleter is dynamic, because we need it to live in a separate
// compilation unit (alt is to have another method in hooks, but
// let's not if we don't need to!)
void CUDAHooks::initCUDA() const {
void CUDAHooks::init() const {
C10_LOG_API_USAGE_ONCE("aten.init.cuda");
// Force the update to enable unit testing. This code get executed before unit tests
// have a chance to enable vitals.
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -19,7 +19,7 @@ TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
// The real implementation of CUDAHooksInterface
struct CUDAHooks : public at::CUDAHooksInterface {
CUDAHooks(at::CUDAHooksArgs) {}
void initCUDA() const override;
void init() const override;
Device getDeviceFromPtr(void* data) const override;
bool isPinnedPtr(const void* data) const override;
const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;
4 changes: 4 additions & 0 deletions aten/src/ATen/detail/AcceleratorHooksInterface.h
@@ -19,6 +19,10 @@ struct TORCH_API AcceleratorHooksInterface {
// Whether the device at device_index is fully initialized or not.
virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0;

virtual void init() const {
TORCH_CHECK(false, "Backend doesn`t support init()");
}

virtual DeviceIndex deviceCount() const {
return 0;
}
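With init() declared on AcceleratorHooksInterface, each backend's hooks class overrides one common virtual instead of a backend-specific initFoo(). A rough sketch of such an override follows; MyBackendHooks is a hypothetical name, and the sketch assumes hasPrimaryContext() is the interface's only pure-virtual member, as in the hunk above. The in-tree equivalents are the CUDAHooks/MPSHooks/XPUHooks changes further down.

#include <ATen/detail/AcceleratorHooksInterface.h>

struct MyBackendHooks : public at::AcceleratorHooksInterface {
  void init() const override {
    // One-time runtime/allocator setup. Context::lazyInitDevice() guarantees
    // this runs at most once per device type.
  }
  bool hasPrimaryContext(c10::DeviceIndex /*device_index*/) const override {
    return false;  // report whether the device already has a live context
  }
};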
2 changes: 1 addition & 1 deletion aten/src/ATen/detail/CUDAHooksInterface.h
@@ -65,7 +65,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
~CUDAHooksInterface() override = default;

// Initialize THCState and, transitively, the CUDA state
virtual void initCUDA() const {
void init() const override {
TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
}

5 changes: 2 additions & 3 deletions aten/src/ATen/detail/HIPHooksInterface.h
@@ -26,9 +26,8 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface {
// squelch -Werror=non-virtual-dtor
~HIPHooksInterface() override = default;

// Initialize the HIP library state
virtual void initHIP() const {
AT_ERROR("Cannot initialize HIP without ATen_hip library.");
void init() const override {
TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library.");
}

virtual std::unique_ptr<c10::GeneratorImpl> initHIPGenerator(Context*) const {
15 changes: 13 additions & 2 deletions aten/src/ATen/detail/IPUHooksInterface.h
@@ -1,14 +1,25 @@
#pragma once

#include <ATen/core/Generator.h>
#include <ATen/detail/AcceleratorHooksInterface.h>

#include <c10/core/Allocator.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>

namespace at {

struct TORCH_API IPUHooksInterface {
virtual ~IPUHooksInterface() = default;
struct TORCH_API IPUHooksInterface: AcceleratorHooksInterface {
~IPUHooksInterface() override = default;

void init() const override {
TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library.");
}

bool hasPrimaryContext(DeviceIndex device_index) const override {
TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library.");
return false;
}

virtual const Generator& getDefaultIPUGenerator(
DeviceIndex device_index [[maybe_unused]] = -1) const {
15 changes: 13 additions & 2 deletions aten/src/ATen/detail/MAIAHooksInterface.h
@@ -3,13 +3,24 @@
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>

#include <ATen/detail/AcceleratorHooksInterface.h>

// NB: Class must live in `at` due to limitations of Registry.h.
namespace at {

struct TORCH_API MAIAHooksInterface {
struct TORCH_API MAIAHooksInterface : AcceleratorHooksInterface {
// This should never actually be implemented, but it is used to
// squelch -Werror=non-virtual-dtor
virtual ~MAIAHooksInterface() = default;
~MAIAHooksInterface() override = default;

void init() const override {
TORCH_CHECK(false, "Cannot initialize MAIA without ATen_maia library.");
}

bool hasPrimaryContext(DeviceIndex device_index) const override {
TORCH_CHECK(false, "Cannot initialize MAIA without ATen_maia library.");
return false;
}

virtual std::string showConfig() const {
TORCH_CHECK(false, "Cannot query detailed MAIA version information.");
2 changes: 1 addition & 1 deletion aten/src/ATen/detail/MPSHooksInterface.h
@@ -22,7 +22,7 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
~MPSHooksInterface() override = default;

// Initialize the MPS library state
virtual void initMPS() const {
void init() const override {
FAIL_MPSHOOKS_FUNC(__func__);
}
virtual bool hasMPS() const {
7 changes: 6 additions & 1 deletion aten/src/ATen/detail/MTIAHooksInterface.h
@@ -31,7 +31,7 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {

~MTIAHooksInterface() override = default;

virtual void initMTIA() const {
void init() const override {
// Avoid logging here, since MTIA needs init devices first then it will know
// how many devices are available. Make it as no-op if mtia extension is not
// dynamically loaded.
@@ -109,6 +109,11 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
FAIL_MTIAHOOKS_FUNC(__func__);
return nullptr;
}

// Preserved for BC
virtual void initMTIA() const {
return;
}
};

struct TORCH_API MTIAHooksArgs {};
2 changes: 1 addition & 1 deletion aten/src/ATen/detail/PrivateUse1HooksInterface.h
@@ -40,7 +40,7 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`.");
}

virtual void initPrivateUse1() const {}
void init() const override {}
virtual void resizePrivateUse1Bytes(
const c10::Storage& storage,
size_t newsize) const {
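For out-of-tree backends, PrivateUse1 is the relevant hook point: once a hooks object is registered, Context::lazyInitDevice(at::kPrivateUse1) dispatches to its init(). The following is a rough sketch, not part of this diff; it assumes at::RegisterPrivateUse1HooksInterface keeps its usual raw-pointer signature (check the header for your PyTorch version), and MyPrivateUse1Hooks is a hypothetical name.

#include <ATen/Context.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>

struct MyPrivateUse1Hooks : public at::PrivateUse1HooksInterface {
  void init() const override {
    // Set up the custom device runtime / caching allocator here.
  }
};

void register_my_backend() {
  static MyPrivateUse1Hooks hooks;  // must outlive every use of the backend
  at::RegisterPrivateUse1HooksInterface(&hooks);
  at::globalContext().lazyInitDevice(at::kPrivateUse1);  // now reaches init()
}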
6 changes: 2 additions & 4 deletions aten/src/ATen/detail/XPUHooksInterface.h
@@ -14,10 +14,8 @@ namespace at {
struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{
~XPUHooksInterface() override = default;

virtual void initXPU() const {
TORCH_CHECK(
false,
"Cannot initialize XPU without ATen_xpu library.");
void init() const override {
TORCH_CHECK(false, "Cannot initialize XPU without ATen_xpu library.");
}

virtual bool hasXPU() const {
2 changes: 1 addition & 1 deletion aten/src/ATen/mps/MPSHooks.h
@@ -12,7 +12,7 @@ namespace at::mps {
// The real implementation of MPSHooksInterface
struct MPSHooks : public at::MPSHooksInterface {
MPSHooks(at::MPSHooksArgs) {}
void initMPS() const override;
void init() const override;

// MPSDevice interface
bool hasMPS() const override;
2 changes: 1 addition & 1 deletion aten/src/ATen/mps/MPSHooks.mm
@@ -10,7 +10,7 @@

namespace at::mps {

void MPSHooks::initMPS() const {
void MPSHooks::init() const {
C10_LOG_API_USAGE_ONCE("aten.init.mps");
// TODO: initialize MPS devices and streams here
}
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Resize.cpp
@@ -30,7 +30,7 @@ void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes) {
c10::cuda::CUDAGuard guard(device.index());
at::DataPtr data = allocator->allocate(size_bytes);
if (storage->data_ptr()) {
at::globalContext().lazyInitCUDA();
at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);

C10_CUDA_CHECK(
cudaMemcpyAsync(
8 changes: 6 additions & 2 deletions aten/src/ATen/test/cuda_cub_test.cu
@@ -138,7 +138,9 @@ __managed__ int input[] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };

TEST(InclusiveScanSplit, CubTest) {
if (!at::cuda::is_available()) return;
at::globalContext().lazyInitCUDA(); // This is required to use PyTorch's caching allocator.
at::globalContext().lazyInitDevice(
c10::DeviceType::CUDA); // This is required to use PyTorch's caching
// allocator.

int *output1;
cudaMallocManaged(&output1, sizeof(int) * 10);
@@ -162,7 +164,9 @@ TEST(InclusiveScanSplit, CubTest) {

TEST(ExclusiveScanSplit, CubTest) {
if (!at::cuda::is_available()) return;
at::globalContext().lazyInitCUDA(); // This is required to use PyTorch's caching allocator.
at::globalContext().lazyInitDevice(
c10::DeviceType::CUDA); // This is required to use PyTorch's caching
// allocator.

int *output2;
cudaMallocManaged(&output2, sizeof(int) * 10);
2 changes: 1 addition & 1 deletion aten/src/ATen/xpu/detail/XPUHooks.cpp
@@ -9,7 +9,7 @@

namespace at::xpu::detail {

void XPUHooks::initXPU() const {
void XPUHooks::init() const {
C10_LOG_API_USAGE_ONCE("aten.init.xpu");
const auto device_count = c10::xpu::device_count_ensure_non_zero();
c10::xpu::XPUCachingAllocator::init(device_count);
2 changes: 1 addition & 1 deletion aten/src/ATen/xpu/detail/XPUHooks.h
@@ -7,7 +7,7 @@ namespace at::xpu::detail {
// The real implementation of XPUHooksInterface
struct XPUHooks : public at::XPUHooksInterface {
XPUHooks(at::XPUHooksArgs) {}
void initXPU() const override;
void init() const override;
bool hasXPU() const override;
std::string showConfig() const override;
int32_t getGlobalIdxFromDevice(const at::Device& device) const override;
2 changes: 1 addition & 1 deletion test/cpp/c10d/ProcessGroupGlooAsyncTest.cpp
@@ -65,7 +65,7 @@ class AsyncInputIsOutputTest : public AsyncTest {
numTensors_(numTensors),
numDevices_(cudaNumDevices()) {
// Allocate inputs on available devices in a round robin fashion.
::at::globalContext().lazyInitCUDA();
::at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
inputs_.resize(numTensors_);
for (const auto i : c10::irange(numTensors_)) {
inputs_[i] = at::empty(
(Diffs for the remaining changed files are not shown.)