Skip to content

Commit

Permalink
x64: dw_conv: int8: fix postop ngroup tail processing
Browse files Browse the repository at this point in the history
  • Loading branch information
msotoflo authored and vpirogov committed Mar 1, 2024
1 parent 4efc0ad commit 0c922e0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
13 changes: 8 additions & 5 deletions src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2016-2023 Intel Corporation
* Copyright 2016-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -65,10 +65,13 @@ _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::_jit_avx512_core_x8s8s32x_fwd_kernel(
static constexpr bool preserve_gpr = true;
static constexpr bool preserve_vmm = false;
static constexpr size_t helper_vmm_idx = 31;
const size_t oc_block_tail = jcp.oc_block % isa_simd_width_;
const size_t tail_size = oc_block_tail
? oc_block_tail
: jcp.oc_without_padding % isa_simd_width_;
const size_t block_tail
= (jcp.is_depthwise ? jcp.ch_block : jcp.oc_block)
% isa_simd_width_;
const size_t tail_size = block_tail
? block_tail
: (jcp.is_depthwise ? jcp.ngroups : jcp.oc_without_padding)
% isa_simd_width_;
static constexpr bool use_exact_tail_scalar_bcast = false;

const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
Expand Down
13 changes: 8 additions & 5 deletions src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2023 Intel Corporation
* Copyright 2019-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -64,10 +64,13 @@ _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::_jit_uni_x8s8s32x_fwd_kernel(
static constexpr bool preserve_gpr = true;
static constexpr bool preserve_vmm = false;
static constexpr size_t helper_vmm_idx = 15;
const size_t oc_block_tail = jcp.oc_block % isa_simd_width_;
const size_t tail_size = oc_block_tail
? oc_block_tail
: jcp.oc_without_padding % isa_simd_width_;
const size_t block_tail
= (jcp.is_depthwise ? jcp.ch_block : jcp.oc_block)
% isa_simd_width_;
const size_t tail_size = block_tail
? block_tail
: (jcp.is_depthwise ? jcp.ngroups : jcp.oc_without_padding)
% isa_simd_width_;

const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
r13, r14, r15, preserve_gpr, preserve_vmm,
Expand Down

0 comments on commit 0c922e0

Please sign in to comment.