diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp
index af4bc39a867..5957311a9ca 100644
--- a/src/cpu/aarch64/acl_inner_product.hpp
+++ b/src/cpu/aarch64/acl_inner_product.hpp
@@ -101,11 +101,16 @@ struct acl_inner_product_fwd_t : public primitive_t {
             const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef)
                     && attr()->has_default_values(
                             primitive_attr_t::skip_mask_t::post_ops, f32);
+            const bool is_fp32_bf16_ok
+                    = expect_data_types(f32, bf16, f32, f32, undef)
+                    && attr()->has_default_values(
+                            primitive_attr_t::skip_mask_t::post_ops, f32);
             const bool is_weights_md_format_ok
                     = utils::one_of(weights_format_kind_received,
                             format_kind::any, format_kind::blocked);
             const bool ok = is_fwd() && !has_zero_dim_memory()
-                    && utils::one_of(true, is_fp16_ok, is_fp32_ok)
+                    && utils::one_of(
+                            true, is_fp16_ok, is_fp32_ok, is_fp32_bf16_ok)
                     && is_weights_md_format_ok
                     && set_default_params(true) == status::success;
 
diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp
index fdd7b17769f..1f595473047 100644
--- a/src/cpu/cpu_inner_product_list.cpp
+++ b/src/cpu/cpu_inner_product_list.cpp
@@ -83,6 +83,15 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             CPU_INSTANCE(ref_inner_product_fwd_t)
             nullptr,
         }},
+        /* With graph compilation, we are able to reorder and pre-pack the weights during the model load
+         * and compilation phase itself so that redundant and on-the-fly reorders can be avoided.
+         * This primitive definition is to support gemm fastmath mode for the compile scenario where src is
+         * in fp32 and weights are in bf16.
+         */
+        {{forward, f32, bf16, f32}, {
+            CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t)
+            nullptr,
+        }},
         {{backward_data, f32, f32, f32}, REG_BWD_PK({
            CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t) // bf32
            CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t)
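
For reference, below is a minimal usage sketch (not part of the patch) of how the new f32-src / bf16-weights path could be exercised through the oneDNN v3.x C++ API on an AArch64 build with Arm Compute Library. The shapes and variable names are illustrative assumptions; format_tag::any on the weights is what lets the implementation report its preferred bf16 layout so the frontend can reorder and pre-pack the weights once at model-load time, as the comment in the patch describes.

// Minimal usage sketch (not part of the patch): an inner product with
// f32 src/bias/dst and bf16 weights, the combination the patch enables.
// Assumes the oneDNN v3.x C++ API and an AArch64 build with ACL; sizes
// and names are illustrative.
#include "dnnl.hpp"

using namespace dnnl;

int main() {
    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    const memory::dim N = 2, IC = 128, OC = 64; // hypothetical shapes

    memory::desc src_md({N, IC}, memory::data_type::f32, memory::format_tag::any);
    memory::desc wei_md({OC, IC}, memory::data_type::bf16, memory::format_tag::any);
    memory::desc bia_md({OC}, memory::data_type::f32, memory::format_tag::a);
    memory::desc dst_md({N, OC}, memory::data_type::f32, memory::format_tag::any);

    // format_tag::any on the weights lets the selected implementation
    // (acl_inner_product_fwd_t here) choose its blocked bf16 layout, so the
    // reorder/pre-pack can happen once ahead of time rather than per call.
    inner_product_forward::primitive_desc ip_pd(eng,
            prop_kind::forward_inference, src_md, wei_md, bia_md, dst_md);

    memory src_mem(ip_pd.src_desc(), eng);
    memory wei_mem(ip_pd.weights_desc(), eng);
    memory bia_mem(ip_pd.bias_desc(), eng);
    memory dst_mem(ip_pd.dst_desc(), eng);

    inner_product_forward(ip_pd).execute(strm,
            {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, wei_mem},
                    {DNNL_ARG_BIAS, bia_mem}, {DNNL_ARG_DST, dst_mem}});
    strm.wait();
    return 0;
}

Note that this explicit f32/bf16 data-type combination is distinct from the attribute-based route to fastmath (primitive_attr::set_fpmath_mode with fpmath_mode::bf16 on plain f32 weights): keeping the weights bf16 in the memory descriptor is what allows the one-time conversion and packing at compile/load time instead of an on-the-fly downconversion.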