Skip to content

Commit

Permalink
Improve performance of pointwise/reduction kernels when using NHWC la…
Browse files Browse the repository at this point in the history
…youts
  • Loading branch information
pfultz2 committed Jul 13, 2023
1 parent 4edf119 commit 24359d3
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 4 deletions.
4 changes: 4 additions & 0 deletions src/include/migraphx/permutation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ MIGRAPHX_EXPORT std::vector<int64_t> invert_permutation(const std::vector<int64_
MIGRAPHX_EXPORT std::vector<int64_t> find_permutation(const shape& s);
MIGRAPHX_EXPORT std::vector<int64_t> find_permutation(const std::vector<shape>& shapes);

/// Normalize the shapes so the order of dimensions will be in the order it is
/// in memory as much as possible.
MIGRAPHX_EXPORT std::vector<shape> normalize_permutation(const std::vector<shape>& shapes);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

Expand Down
10 changes: 10 additions & 0 deletions src/permutation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,15 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
return it->first;
}

std::vector<shape> normalize_permutation(const std::vector<shape>& shapes)
{
auto result = shapes;
auto perm = find_permutation(shapes);
std::transform(result.begin(), result.end(), result.begin(), [&](auto s) {
return reorder_shape(s, perm);
});
return result;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
2 changes: 1 addition & 1 deletion src/targets/gpu/jit/pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
options.virtual_inputs = reduce_dims(normalize_permutation(inputs));
options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs);
auto vec = vectorize::elements(ctx, axis, options.virtual_inputs);
Expand Down
6 changes: 3 additions & 3 deletions src/targets/gpu/jit/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
std::fill(lens.begin(), lens.end(), 1);
for(const auto& axis : axes)
lens[axis] = s.lens()[axis];
return shape{s.type(), lens};
return s.with_lens(lens);
}

template <class T>
Expand All @@ -93,7 +93,7 @@ static shape get_output_shape(const shape& s, const std::vector<T>& axes)
auto lens = s.lens();
for(const auto& axis : axes)
lens[axis] = 1;
return shape{s.type(), lens};
return s.with_lens(lens);
}

template <class ReduceLens>
Expand Down Expand Up @@ -228,7 +228,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
auto virtual_inputs = inputs;
virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes));
virtual_inputs.push_back(get_output_shape(inputs.front(), axes));
virtual_inputs = reduce_dims(virtual_inputs);
virtual_inputs = reduce_dims(normalize_permutation(virtual_inputs));
auto reduce_output_shape = virtual_inputs.back();
virtual_inputs.pop_back();
auto reduction_shape = virtual_inputs.back();
Expand Down

0 comments on commit 24359d3

Please sign in to comment.