Skip to content

Commit

Permalink
[GPU] fix permute fy (#24710)
Browse files Browse the repository at this point in the history
### Details:
 - when tile_size < vector_size then JTIMES == 0

### Tickets:
 - 142398
  • Loading branch information
michal-miotk committed May 28, 2024
1 parent 1981b3e commit 185783b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ size_t GetTileWidth(const permute_params& params) {

// i64 only supports tile size 4
if ((input_type == Datatype::INT64) || (output_type == Datatype::INT64)) {
min_divisor = min_divisor / 2;
min_divisor = min_divisor >= 4 ? min_divisor / 2 : min_divisor;
}
if (input_type == Datatype::F16) {
min_divisor = min_divisor * 2;
Expand All @@ -77,7 +77,7 @@ size_t GetTileWidth(const permute_params& params) {
if (params.inputs[0].X().v == 1) {
return std::min(params.inputs[0].Y().v, min_divisor);
}
return std::min(params.inputs[0].X().v, min_divisor);
return std::min(GetDivisor(params.inputs[0].X().v), min_divisor);
}

size_t GetTileSize(const permute_params& params) {
Expand Down Expand Up @@ -129,8 +129,8 @@ JitConstants PermuteKernel_f_y_axes::GetJitConstants(const permute_params& param
}

const size_t tile_width = GetTileWidth(params);
const size_t vector_size = std::min(tile_width, static_cast<size_t>(4));
const size_t tile_size = GetTileSize(params);
const size_t vector_size = IsSimpleMemCopyOperation(params) ? std::min(tile_width, static_cast<size_t>(4)): std::min(tile_size, static_cast<size_t>(4));
const size_t j_times = IsSimpleMemCopyOperation(params) ? tile_width / vector_size : tile_size / vector_size;
const size_t feature_block_size = GetFeatureBlockSize(params);
jit.AddConstant(MakeJitConstant("BLOCK_SIZE", tile_width));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2278,6 +2278,8 @@ INSTANTIATE_TEST_SUITE_P(smoke_permute_f_y_axes_tile,
{{32, 16, 8, 32}, format::bfyx},
{{32, 8, 16, 32}, format::bfyx},
{{32, 196, 8, 64}, format::bfyx}, // permute_f_y_axes
{{1, 512, 30, 1}, format::bfyx}, // fix for JTIMES=0
{{1, 2, 512, 10}, format::bfyx}, //case trying to set vec size(4) bigger than x divisor(2) in case of f16
}),
TiledPermuteTest::PrintToStringParamName);

Expand Down

0 comments on commit 185783b

Please sign in to comment.