Skip to content

Commit

Permalink
aarch64: shuffle: fix segv for bf16 cases
Browse files Browse the repository at this point in the history
  • Loading branch information
kawakami-k authored and vpirogov committed Jul 17, 2024
1 parent 0013e8c commit 9116681
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 20 deletions.
10 changes: 7 additions & 3 deletions src/cpu/aarch64/shuffle/jit_uni_shuffle.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
* Copyright 2022 FUJITSU LIMITED
* Copyright 2020-2024 Intel Corporation
* Copyright 2022-2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,6 +34,7 @@ template <cpu_isa_t isa>
status_t jit_uni_shuffle_t<isa>::pd_t::init(engine_t *engine) {
using namespace format_tag;
using namespace data_type;
using namespace types;

const memory_desc_wrapper src_d(is_fwd() ? src_md() : diff_src_md());
const memory_desc_wrapper dst_d(is_fwd() ? dst_md() : diff_dst_md());
Expand All @@ -58,7 +59,10 @@ status_t jit_uni_shuffle_t<isa>::pd_t::init(engine_t *engine) {
if (blocked_format == format_tag::undef) return status::unimplemented;

conf_.blk_size = src_d.blocking_desc().strides[ndims() - 1];
conf_.simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
/* Because "ST1H { <Zt>.S }, <Pg>, [<Xn|SP>, <Zm>.S, UXTW #1]" is used
to gather data for bf16, simd_w must be calculated
with sizeof(uint32_t). */
conf_.simd_w = cpu_isa_traits<isa>::vlen / sizeof(uint32_t);

const bool has_spatial = utils::one_of(ndims(), 3, 4, 5);
const dim_t HW = H() * W();
Expand Down
46 changes: 29 additions & 17 deletions src/cpu/aarch64/shuffle/jit_uni_shuffle_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
* Copyright 2022 FUJITSU LIMITED
* Copyright 2021-2024 Intel Corporation
* Copyright 2022-2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -47,9 +47,12 @@ jit_uni_shuffle_kernel_t<isa>::jit_uni_shuffle_kernel_t(
template <cpu_isa_t isa>
void jit_uni_shuffle_kernel_t<isa>::prepare_mask() {
using namespace data_type;
using namespace types;
if (conf_.simd_tail > 0) {
assert(utils::one_of(conf_.data_type, f32, s32));
assert(conf_.simd_tail < isa_sveLen / sizeof(float));
/* Because "ST1H { <Zt>.S }, <Pg>, [<Xn|SP>, <Zm>.S, UXTW #1]" is used
to gather data for bf16, simd_tail must be evaluated
with sizeof(unsigned). */
assert(conf_.simd_tail < isa_sveLen / sizeof(uint32_t));
index(vmm_tmp_.s, 0, 1);
cmplt(k_tail_mask_.s, P_ALL_ONE / T_z, vmm_tmp_.s, conf_.simd_tail);
}
Expand All @@ -68,13 +71,17 @@ void jit_uni_shuffle_kernel_t<asimd>::prepare_mask() {}
template <cpu_isa_t isa>
void jit_uni_shuffle_kernel_t<isa>::gather_data(const XReg &reg_src_addr,
const int indices_idx, const int data_idx, const bool is_tail) {
if (conf_.dt_size == sizeof(float)) {
const PReg &mask = is_tail ? k_tail_mask_ : k_full_mask_;
using namespace data_type;
const PReg &mask = is_tail ? k_tail_mask_ : k_full_mask_;

if (utils::one_of(conf_.data_type, f32, s32)) {
lsr(TRegS(indices_idx), TRegS(indices_idx), 2);
ld1w(TRegS(data_idx), mask / T_z,
ptr(reg_src_addr, TRegS(indices_idx), UXTW, 2));
} else {
assert(!"unsupported emu_gather_data");
} else if (conf_.data_type == bf16) {
lsr(TRegS(indices_idx), TRegS(indices_idx), 1);
ld1h(TRegS(data_idx), mask / T_z,
ptr(reg_src_addr, TRegS(indices_idx), UXTW, 1));
}
}

Expand All @@ -97,21 +104,26 @@ void jit_uni_shuffle_kernel_t<asimd>::gather_data(const XReg &addr,
template <cpu_isa_t isa>
void jit_uni_shuffle_kernel_t<isa>::store_data(const int data_idx,
const XReg &reg_dst_addr, const int offset, const bool is_tail) {
using namespace data_type;
const auto extend_for_padding
= is_tail && padding_size_ + conf_.simd_tail >= conf_.simd_w;
const PReg &mask = is_tail ? k_tail_mask_ : P_ALL_ONE;

add_imm(X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);

if (extend_for_padding) {
sel(vmm_tmp_.s, k_tail_mask_, TRegS(data_idx), vmm_zero_.s);
add_imm(X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);
st1w(vmm_tmp_.s, P_ALL_ONE, ptr(X_DEFAULT_ADDR));
if (utils::one_of(conf_.data_type, f32, s32))
st1w(vmm_tmp_.s, P_ALL_ONE, ptr(X_DEFAULT_ADDR));
else // bf16
st1h(vmm_tmp_.s, P_ALL_ONE, ptr(X_DEFAULT_ADDR));
} else {
if (is_tail) {
add_imm(X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);
st1w(TRegS(data_idx), k_tail_mask_, ptr(X_DEFAULT_ADDR));
} else {
add_imm(X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);
st1w(TRegS(data_idx), P_ALL_ONE, ptr(X_DEFAULT_ADDR));
}
if (utils::one_of(conf_.data_type, f32, s32))
st1w(TRegS(data_idx), mask, ptr(X_DEFAULT_ADDR));
else // bf16
st1h(TRegS(data_idx), mask, ptr(X_DEFAULT_ADDR));
}

append_zero_padding(
reg_dst_, isa_sveLen > 128 ? extend_for_padding : false);
}
Expand Down

0 comments on commit 9116681

Please sign in to comment.