From 05b2ff452ae8585ed1d191d907e7e95c448bc318 Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Fri, 16 Aug 2024 11:24:12 -0500 Subject: [PATCH] Fix transposed padding (#3342) --- src/include/migraphx/op/pad.hpp | 3 +-- src/targets/gpu/jit/pad.cpp | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/include/migraphx/op/pad.hpp b/src/include/migraphx/op/pad.hpp index c24b83edf83..7fad02b8375 100644 --- a/src/include/migraphx/op/pad.hpp +++ b/src/include/migraphx/op/pad.hpp @@ -79,8 +79,7 @@ struct pad { rdims[i] += pads[i] + pads[i + num_dims]; } - shape s{s0.type(), rdims}; - return s; + return s0.with_lens(rdims); } } diff --git a/src/targets/gpu/jit/pad.cpp b/src/targets/gpu/jit/pad.cpp index 02654dd46b1..48f43359e76 100644 --- a/src/targets/gpu/jit/pad.cpp +++ b/src/targets/gpu/jit/pad.cpp @@ -79,7 +79,7 @@ struct pad_compiler : compiler auto vinputs = inputs; vinputs.push_back(inputs.front().with_lens(offset_lens)); - auto rinputs = reduce_dims(vinputs); + auto rinputs = reduce_dims(normalize_permutation(vinputs)); auto rinput_lens = rinputs.front().lens(); auto roffset_lens = rinputs.back().lens();