Skip to content

Commit

Permalink
gpu: jit: gemm: enable hf8
Browse files Browse the repository at this point in the history
  • Loading branch information
kealan-barbieri authored and karturov committed Apr 10, 2024
1 parent ad94382 commit c3972ef
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 77 deletions.
3 changes: 2 additions & 1 deletion src/gpu/compute/kernel_arg_list.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2023 Intel Corporation
* Copyright 2019-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,6 +44,7 @@ enum class kernel_arg_kind_t {
enum class scalar_type_t {
undef,
_char,
_hfloat8,
_bfloat8,
_bfloat16,
_float,
Expand Down
7 changes: 6 additions & 1 deletion src/gpu/jit/codegen/ngen_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ namespace impl {
namespace gpu {
namespace jit {

// Placeholder nGEN data type for hf8, a type that is not implemented in HW.
// Returns the raw value 0x0D cast to ngen::DataType; to_ngen() maps
// type_kind_t::hf8 to this value, so callers should compare against
// ngen_hf8() to detect hf8 operands.
// NOTE(review): presumably 0x0D does not collide with any existing
// ngen::DataType enumerator -- confirm against the nGEN headers.
constexpr ngen::DataType ngen_hf8() {
return static_cast<ngen::DataType>(0x0D);
}

template <typename T>
T to_cpp(const ngen::Immediate &imm) {
auto u64 = uint64_t(imm);
Expand Down Expand Up @@ -58,7 +63,6 @@ inline ngen::DataType to_ngen(const type_t &type) {
CASE(bf16, bf);
CASE(f16, hf);
CASE(bf8, bf8);
CASE(hf8, hf8);
CASE(tf32, tf32);
CASE(f32, f);
CASE(f64, df);
Expand All @@ -72,6 +76,7 @@ inline ngen::DataType to_ngen(const type_t &type) {
CASE(u8, ub);

if (type == type_t::byte_ptr()) return ngen::DataType::uq;
if (type == type_kind_t::hf8) return ngen_hf8();

#undef CASE
ir_error_not_expected();
Expand Down
4 changes: 2 additions & 2 deletions src/gpu/jit/codegen/reorder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
bool dst_q = ngen_is_qw(dst_type);
bool dst_f = (dst_type == ngen::DataType::f);
bool dst_bf8 = (dst_type == ngen::DataType::bf8);
bool dst_hf8 = (dst_type == ngen::DataType::hf8);
bool dst_hf8 = (dst_type == ngen_hf8());
bool dst_hf = (dst_type == ngen::DataType::hf);
bool dst_bf = (dst_type == ngen::DataType::bf);
bool dst_df = (dst_type == ngen::DataType::df);
Expand All @@ -276,7 +276,7 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
bool src_hf = (src_type == ngen::DataType::hf);
bool src_bf = (src_type == ngen::DataType::bf);
bool src_bf8 = (src_type == ngen::DataType::bf8);
bool src_hf8 = (src_type == ngen::DataType::hf8);
bool src_hf8 = (src_type == ngen_hf8());
bool src_df = (src_type == ngen::DataType::df);
bool src_xf = src_bf || src_f || src_hf || src_df;
bool f_to_xf = (src_f && (dst_bf || dst_hf));
Expand Down
15 changes: 9 additions & 6 deletions src/gpu/jit/gemm/gen_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,24 @@ struct gen_gemm_t : public gpu_gemm_t {
ok = ok && d->b_type() == bf16
&& utils::one_of(d->c_type(), bf16, f32)
&& utils::one_of(d->acc_type, bf16, f32);
} else if (!wei_decomp) {
ok = ok && utils::one_of(d->a_type(), f32, f16, f8_e5m2)
} else if (!wei_decomp_) {
ok = ok
&& utils::one_of(
d->a_type(), f32, f16, f8_e5m2, f8_e4m3)
&& d->b_type() == d->a_type()
&& utils::one_of(d->acc_type, d->a_type(), f32)
&& IMPLICATION(utils::one_of(f8_e5m2, d->a_type(),
d->b_type(), d->c_type()),
&& IMPLICATION(
utils::one_of(f8_e5m2, f8_e4m3, d->a_type(),
d->b_type(), d->c_type()),
arch_ >= arch_t::xe_hpc);
}

ok = ok && !has_blocks() && batch_dims() <= 2
&& !utils::one_of(DNNL_RUNTIME_DIM_VAL, d->m(), d->n(),
d->k(), d->lda(), d->ldb(), d->ldc(), d->batch())
&& IMPLICATION(with_bias(),
utils::one_of(
d->bias_type(), f32, bf16, f16, f8_e5m2)
utils::one_of(d->bias_type(), f32, bf16, f16,
f8_e5m2, f8_e4m3)
&& (d->bias_desc.ndims <= 3)
&& utils::one_of(bias_cmask(), 0, 1, 2, 3))
&& compute_engine->mayiuse_ngen_kernels()
Expand Down
3 changes: 1 addition & 2 deletions src/gpu/jit/gemm/gen_gemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ status_t gen_gemm_kernel_desc_t::transfer_post_ops(const post_ops_t &post_ops,
problem_.binaryTrans = {};

if (problem_.Ta == Type::f16) problem_.Ts = Type::f32;
if (problem_.Ta == Type::bf8 || problem_.Tb == Type::bf8)
problem_.Ts = Type::f32;
if (problem_.Ta.isF8() || problem_.Tb.isF8()) problem_.Ts = Type::f32;

for (int i = 0; i < po_count; i++) {
const auto &entry = post_ops.entry_[i];
Expand Down
2 changes: 2 additions & 0 deletions src/gpu/jit/gemm/gen_gemm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ struct gen_gemm_kernel_desc_t {
case Type::s64: return compute::scalar_type_t::_long;
case Type::u64: return compute::scalar_type_t::_ulong;
case Type::bf8: return compute::scalar_type_t::_bfloat8;
case Type::hf8: return compute::scalar_type_t::_hfloat8;
case Type::bf16: return compute::scalar_type_t::_bfloat16;
case Type::f16: return compute::scalar_type_t::_half;
case Type::f32: return compute::scalar_type_t::_float;
Expand Down Expand Up @@ -86,6 +87,7 @@ struct gen_gemm_kernel_desc_t {
case data_type::f16: return Type::f16;
case data_type::bf16: return Type::bf16;
case data_type::f8_e5m2: return Type::bf8;
case data_type::f8_e4m3: return Type::hf8;
case data_type::s32: return Type::s32;
case data_type::u8: return Type::u8;
case data_type::s8: return Type::s8;
Expand Down
Loading

0 comments on commit c3972ef

Please sign in to comment.