Skip to content

Commit

Permalink
cpu: aarch64: add sbgemm (fp32 input and bf16 weights) inner product op
Browse files Browse the repository at this point in the history
  • Loading branch information
snadampal authored and vpirogov committed Mar 13, 2024
1 parent 8aacc8f commit 214fb9e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/cpu/aarch64/acl_inner_product.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,16 @@ struct acl_inner_product_fwd_t : public primitive_t {
const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef)
&& attr()->has_default_values(
primitive_attr_t::skip_mask_t::post_ops, f32);
const bool is_fp32_bf16_ok
= expect_data_types(f32, bf16, f32, f32, undef)
&& attr()->has_default_values(
primitive_attr_t::skip_mask_t::post_ops, f32);
const bool is_weights_md_format_ok
= utils::one_of(weights_format_kind_received,
format_kind::any, format_kind::blocked);
const bool ok = is_fwd() && !has_zero_dim_memory()
&& utils::one_of(true, is_fp16_ok, is_fp32_ok)
&& utils::one_of(
true, is_fp16_ok, is_fp32_ok, is_fp32_bf16_ok)
&& is_weights_md_format_ok
&& set_default_params(true) == status::success;

Expand Down
9 changes: 9 additions & 0 deletions src/cpu/cpu_inner_product_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
CPU_INSTANCE(ref_inner_product_fwd_t)
nullptr,
}},
/* With graph compilation, the weights can be reordered and pre-packed during the model load
* and compilation phase, so that redundant on-the-fly reorders are avoided.
* This primitive definition supports gemm fastmath mode for the compile scenario, where src is
* in fp32 and weights are in bf16.
*/
{{forward, f32, bf16, f32}, {
CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t)
nullptr,
}},
{{backward_data, f32, f32, f32}, REG_BWD_PK({
CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>) // bf32
CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core>)
Expand Down

0 comments on commit 214fb9e

Please sign in to comment.