Skip to content

Commit

Permalink
cublaslt autotuning support for TunableOp (pytorch#133896)
Browse files Browse the repository at this point in the history
Adds support for cublaslt autotuning to TunableOp.

Todo:
- [x] Add and test `ScaledGemmTunableOp`
- [x] Benchmarking numbers

Pull Request resolved: pytorch#133896
Approved by: https://github.com/eqy, https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
  • Loading branch information
2 people authored and pytorchmergebot committed Oct 11, 2024
1 parent 1358969 commit 19bbbef
Show file tree
Hide file tree
Showing 7 changed files with 912 additions and 143 deletions.
117 changes: 14 additions & 103 deletions aten/src/ATen/cuda/CUDABlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
" but got ", \
X)

namespace {
namespace at::cuda::blas {

static cublasOperation_t _cublasOpFromChar(char op) {
cublasOperation_t _cublasOpFromChar(char op) {
switch (op) {
case 'n':
case 'N':
Expand All @@ -118,7 +118,7 @@ static cublasOperation_t _cublasOpFromChar(char op) {
"_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}

static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {
void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {
// Note: leading dimensions generally are checked that they are > 0
// and at least as big the result requires (even if the value won't
// be used).
Expand All @@ -132,7 +132,7 @@ static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {
*lda = std::max<int64_t>(m, 1);
}

static void _cublasAdjustLdLevel3(
void _cublasAdjustLdLevel3(
char transa,
char transb,
int64_t m,
Expand Down Expand Up @@ -179,7 +179,7 @@ uint32_t _getAlignment(uintptr_t address) {
}
#endif

static size_t _parseChosenWorkspaceSize() {
size_t _parseChosenWorkspaceSize() {
const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
#ifdef USE_ROCM
if (!val) {
Expand All @@ -202,15 +202,11 @@ static size_t _parseChosenWorkspaceSize() {
return workspace_size * 1024;
}

static size_t _getWorkspaceSize() {
size_t _getWorkspaceSize() {
static size_t workspace_size = _parseChosenWorkspaceSize();
return workspace_size;
}

} // anonymous namespace

namespace at::cuda::blas {

/* LEVEL 3 BLAS FUNCTIONS */

#define GEMM_CHECK_ARGVALUES(Dtype) \
Expand All @@ -234,91 +230,6 @@ namespace at::cuda::blas {
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, num_batches); \
} while (0)


namespace {
// Following the pattern of CuSparseDescriptor
// Defined here for now because this is the only place cublas_lt interface is
// used but can be moved to a header once cublas_lt interface is used in
// multiple places.
// Deleter functor for std::unique_ptr: releases a cuBLASLt opaque handle
// via the supplied destroy function. A null handle is a no-op; any error
// status from the destructor is surfaced through TORCH_CUDABLAS_CHECK.
template <typename T, cublasStatus_t (*destructor)(T*)>
struct CuBlasLtDeleter {
  void operator()(T* handle) {
    if (handle == nullptr) {
      return;
    }
    TORCH_CUDABLAS_CHECK(destructor(handle));
  }
};

// RAII owner of an opaque cuBLASLt descriptor, following the pattern of
// CuSparseDescriptor: the handle lives in a unique_ptr whose deleter calls
// the matching cublasLt*Destroy function.
template <typename T, cublasStatus_t (*destructor)(T*)>
class CuBlasLtDescriptor {
 public:
  // Raw-handle accessors for passing to cuBLASLt API calls.
  T* descriptor() {
    return descriptor_.get();
  }
  T* descriptor() const {
    return descriptor_.get();
  }

 protected:
  // Owned opaque handle; destroyed via `destructor` when the wrapper dies.
  std::unique_ptr<T, CuBlasLtDeleter<T, destructor>> descriptor_;
};

// Owning wrapper for cublasLtMatmulDesc_t, created with the requested
// compute and scale types; further attributes are attached via setAttribute.
class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
                                     cublasLtMatmulDescOpaque_t,
                                     &cublasLtMatmulDescDestroy> {
 public:
  CuBlasLtMatmulDescriptor(
      cublasComputeType_t compute_type,
      cudaDataType_t scale_type) {
    cublasLtMatmulDesc_t desc = nullptr;
    TORCH_CUDABLAS_CHECK(
        cublasLtMatmulDescCreate(&desc, compute_type, scale_type));
    descriptor_.reset(desc);
  }
  // Sets one matmul-descriptor attribute; the value is passed by address
  // together with its exact size, as the cuBLASLt API requires.
  template <typename T>
  inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(
        ::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

// Owning wrapper for cublasLtMatrixLayout_t describing one matrix operand.
// When `t` is true the layout is created with rows and cols swapped
// (transposed view); `ld` is the leading dimension.
class CuBlasLtMatrixLayout : public CuBlasLtDescriptor<
                                 cublasLtMatrixLayoutOpaque_t,
                                 &cublasLtMatrixLayoutDestroy> {
 public:
  CuBlasLtMatrixLayout(
      cudaDataType_t type,
      uint64_t rows,
      uint64_t cols,
      int64_t ld,
      bool t = false) {
    const uint64_t layout_rows = t ? cols : rows;
    const uint64_t layout_cols = t ? rows : cols;
    cublasLtMatrixLayout_t desc = nullptr;
    TORCH_CUDABLAS_CHECK(
        cublasLtMatrixLayoutCreate(&desc, type, layout_rows, layout_cols, ld));
    descriptor_.reset(desc);
  }
  // Sets one matrix-layout attribute; the value is passed by address
  // together with its exact size, as the cuBLASLt API requires.
  template <typename T>
  inline void setAttribute(cublasLtMatrixLayoutAttribute_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(
        ::cublasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

// Owning wrapper for cublasLtMatmulPreference_t; preference attributes
// (e.g. via setAttribute) are applied after default construction.
class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
                                     cublasLtMatmulPreferenceOpaque_t,
                                     &cublasLtMatmulPreferenceDestroy> {
 public:
  CuBlasLtMatmulPreference() {
    cublasLtMatmulPreference_t desc = nullptr;
    TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&desc));
    descriptor_.reset(desc);
  }
  // Sets one preference attribute; the value is passed by address
  // together with its exact size, as the cuBLASLt API requires.
  template <typename T>
  inline void setAttribute(cublasLtMatmulPreferenceAttributes_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(
        ::cublasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};
} // namespace


template <typename Dtype>
inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
cudaDataType_t abcType = CUDA_R_32F;
Expand Down Expand Up @@ -695,19 +606,19 @@ inline void bgemm_tunable(CUDABLAS_BGEMM_ARGTYPES(DType)) {
bool transb_ = ((transb != 'n') && (transb != 'N'));

if (transa_ && transb_) {
static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::T> bgemm{};
static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::T> bgemm{&params};
bgemm(&params);
}
else if (transa_ && !transb_) {
static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::N> bgemm{};
static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::N> bgemm{&params};
bgemm(&params);
}
else if (!transa_ && transb_) {
static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::T> bgemm{};
static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::T> bgemm{&params};
bgemm(&params);
}
else if (!transa_ && !transb_) {
static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::N> bgemm{};
static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::N> bgemm{&params};
bgemm(&params);
}
else {
Expand Down Expand Up @@ -1091,19 +1002,19 @@ inline void gemm_tunable(CUDABLAS_GEMM_ARGTYPES(DType)) {
bool transb_ = ((transb != 'n') && (transb != 'N'));

if (transa_ && transb_) {
static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::T> gemm{};
static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::T> gemm{&params};
gemm(&params);
}
else if (transa_ && !transb_) {
static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::N> gemm{};
static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::N> gemm{&params};
gemm(&params);
}
else if (!transa_ && transb_) {
static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::T> gemm{};
static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::T> gemm{&params};
gemm(&params);
}
else if (!transa_ && !transb_) {
static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::N> gemm{};
static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::N> gemm{&params};
gemm(&params);
}
else {
Expand Down
99 changes: 99 additions & 0 deletions aten/src/ATen/cuda/CUDABlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,105 @@

namespace at::cuda::blas {

// Internal cuBLAS helpers, declared here so other translation units (e.g.
// TunableOp) can share them; they were previously file-static in CUDABlas.cpp.

// Maps a transpose character ('n'/'N', 't'/'T', 'c'/'C') to the
// corresponding cublasOperation_t; any other value is an error.
cublasOperation_t _cublasOpFromChar(char op);
// Adjusts *lda for an m-by-n level-2 operand, raising it to at least
// max(m, 1) when needed.
void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda);
// Adjusts *lda/*ldb/*ldc for a level-3 gemm based on the transpose flags
// and problem shape (see definition in CUDABlas.cpp).
void _cublasAdjustLdLevel3(
    char transa,
    char transb,
    int64_t m,
    int64_t n,
    int64_t k,
    int64_t* lda,
    int64_t* ldb,
    int64_t* ldc);
// Alignment of `address` for cuBLASLt use.
// NOTE(review): definition not shown in this view — confirm the exact
// semantics (units/upper bound) against CUDABlas.cpp.
uint32_t _getAlignment(uintptr_t address);
// Parses the CUBLASLT_WORKSPACE_SIZE environment variable (in KiB) and
// returns the chosen cuBLASLt workspace size in bytes.
size_t _parseChosenWorkspaceSize();
// Result of _parseChosenWorkspaceSize(), computed once and cached in a
// function-local static.
size_t _getWorkspaceSize();

namespace {
// Following the pattern of CuSparseDescriptor
// Defined here for now because this is the only place cublas_lt interface is
// used but can be moved to a header once cublas_lt interface is used in
// multiple places.
// Deleter functor for std::unique_ptr: releases a cuBLASLt opaque handle
// via the supplied destroy function. A null handle is a no-op; any error
// status from the destructor is surfaced through TORCH_CUDABLAS_CHECK.
template <typename T, cublasStatus_t (*destructor)(T*)>
struct CuBlasLtDeleter {
  void operator()(T* handle) {
    if (handle == nullptr) {
      return;
    }
    TORCH_CUDABLAS_CHECK(destructor(handle));
  }
};

// RAII owner of an opaque cuBLASLt descriptor, following the pattern of
// CuSparseDescriptor: the handle lives in a unique_ptr whose deleter calls
// the matching cublasLt*Destroy function.
template <typename T, cublasStatus_t (*destructor)(T*)>
class CuBlasLtDescriptor {
 public:
  // Raw-handle accessors for passing to cuBLASLt API calls.
  T* descriptor() {
    return descriptor_.get();
  }
  T* descriptor() const {
    return descriptor_.get();
  }

 protected:
  // Owned opaque handle; destroyed via `destructor` when the wrapper dies.
  std::unique_ptr<T, CuBlasLtDeleter<T, destructor>> descriptor_;
};

// Owning wrapper for cublasLtMatmulDesc_t, created with the requested
// compute and scale types; further attributes are attached via setAttribute.
class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
                                     cublasLtMatmulDescOpaque_t,
                                     &cublasLtMatmulDescDestroy> {
 public:
  CuBlasLtMatmulDescriptor(
      cublasComputeType_t compute_type,
      cudaDataType_t scale_type) {
    cublasLtMatmulDesc_t desc = nullptr;
    TORCH_CUDABLAS_CHECK(
        cublasLtMatmulDescCreate(&desc, compute_type, scale_type));
    descriptor_.reset(desc);
  }
  // Sets one matmul-descriptor attribute; the value is passed by address
  // together with its exact size, as the cuBLASLt API requires.
  template <typename T>
  inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(
        ::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

// Owning wrapper for cublasLtMatrixLayout_t describing one matrix operand.
// When `t` is true the layout is created with rows and cols swapped
// (transposed view); `ld` is the leading dimension.
class CuBlasLtMatrixLayout : public CuBlasLtDescriptor<
                                 cublasLtMatrixLayoutOpaque_t,
                                 &cublasLtMatrixLayoutDestroy> {
 public:
  CuBlasLtMatrixLayout(
      cudaDataType_t type,
      uint64_t rows,
      uint64_t cols,
      int64_t ld,
      bool t = false) {
    const uint64_t layout_rows = t ? cols : rows;
    const uint64_t layout_cols = t ? rows : cols;
    cublasLtMatrixLayout_t desc = nullptr;
    TORCH_CUDABLAS_CHECK(
        cublasLtMatrixLayoutCreate(&desc, type, layout_rows, layout_cols, ld));
    descriptor_.reset(desc);
  }
  // Sets one matrix-layout attribute; the value is passed by address
  // together with its exact size, as the cuBLASLt API requires.
  template <typename T>
  inline void setAttribute(cublasLtMatrixLayoutAttribute_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(
        ::cublasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

// Owning wrapper for cublasLtMatmulPreference_t; preference attributes
// (e.g. via setAttribute) are applied after default construction.
class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
                                     cublasLtMatmulPreferenceOpaque_t,
                                     &cublasLtMatmulPreferenceDestroy> {
 public:
  CuBlasLtMatmulPreference() {
    cublasLtMatmulPreference_t desc = nullptr;
    TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&desc));
    descriptor_.reset(desc);
  }
  // Sets one preference attribute; the value is passed by address
  // together with its exact size, as the cuBLASLt API requires.
  template <typename T>
  inline void setAttribute(cublasLtMatmulPreferenceAttributes_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(
        ::cublasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

} // namespace

// RAII guard that sets the CuBLAS pointer mode and restores it to
// its previous value when the guard is destroyed
class PointerModeGuard {
Expand Down
Loading

0 comments on commit 19bbbef

Please sign in to comment.