handle review comments
lakhinderwalia committed Sep 18, 2024
1 parent 28612c4 commit 7c643e6
Showing 4 changed files with 16 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/quantize_int4.cpp
@@ -67,8 +67,8 @@ static void int4_quantize_module(module& m)
if(not inp->can_eval())
return inp;

std::vector<float> val;
inp->eval().visit([&](auto in_data) { val.assign(in_data.begin(), in_data.end()); });
std::vector<float> val = inp->eval().to_vector<float>();

auto [min, max] = std::minmax_element(val.begin(), val.end());
*min = *min > 0 ? 0 : *min;
*max = *max < 0 ? 0 : *max;
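
The new line replaces the two-step visit-and-assign with the equivalent to_vector<float>() helper on the evaluated argument. A minimal sketch of what that helper amounts to, assuming it wraps the same visit-based copy as the code it replaces (the to_vector_sketch name and signature below are illustrative, not the MIGraphX implementation):

    #include <migraphx/argument.hpp>
    #include <vector>

    // Visit the typed buffer of an evaluated argument and copy-convert every
    // element into a std::vector<T>, as the removed two-line version did.
    template <class T>
    std::vector<T> to_vector_sketch(const migraphx::argument& arg)
    {
        std::vector<T> result;
        arg.visit([&](auto data) { result.assign(data.begin(), data.end()); });
        return result;
    }

Folding the copy into the initializer also removes the default-constructed intermediate state of val.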
9 changes: 1 addition & 8 deletions src/targets/gpu/target.cpp
@@ -41,7 +41,6 @@
#include <migraphx/optimize_module.hpp>
#include <migraphx/preallocate_param.hpp>
#include <migraphx/promote_literals.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/replace_allocate.hpp>
#include <migraphx/rewrite_gelu.hpp>
@@ -162,20 +161,14 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
normalize_ops{},
dead_code_elimination{},
// 1. For int4 ops: first remove the identity op next to packed int4 data.
eliminate_identity{},
// 2. Next: (pack/unpack)_int4 handling in simplify_qdq.
simplify_qdq{},
// 3. Next: pruning of pack_int4 related const branches.
propagate_constant{{}},
// 4. Last for int4: dce.
dead_code_elimination{},
enable_pass(not mlir_enabled(), rewrite_quantization{}),
dead_code_elimination{},
// workaround for rocBLAS unsupported error when using uint8 in quant_dot, quant_convolution & pooling
eliminate_data_type{{migraphx::shape::uint8_type}, shape::float_type, {"quant_convolution", "quant_dot", "pooling"}},
eliminate_data_type{unsupported_types, shape::type_t::float_type},
simplify_reshapes{},
eliminate_identity{},
eliminate_pad{},
dead_code_elimination{},
insert_pad{{"convolution"}},
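
These passes run in order when a program is compiled for the GPU target. A minimal usage sketch, assuming a program has already been built or parsed (the parse_onnx call and model path here are illustrative, not part of this diff):

    #include <migraphx/onnx.hpp>
    #include <migraphx/program.hpp>
    #include <migraphx/register_target.hpp>

    int main()
    {
        // Compiling for the "gpu" target invokes target::get_passes(), which
        // supplies the pass list shown above.
        migraphx::program p = migraphx::parse_onnx("model.onnx"); // hypothetical model file
        p.compile(migraphx::make_target("gpu"));
    }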
8 changes: 7 additions & 1 deletion test/gpu/mlir.cpp
@@ -483,7 +483,13 @@ module {
TEST_CASE(int4_unpack_conv)
{
std::string mlir_output = R"__migraphx__(
module { func.func @mlir_unpack_int4_quant_convolution(%arg0: !migraphx.shaped<2x8x2x1xi8, 16x2x1x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>) -> !migraphx.shaped<1x2x3x3xi32, 18x9x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { %0 = migraphx.unpack %arg0 {axis = 3 : i64, isUnsigned = false} : <2x8x2x1xi8, 16x2x1x1> -> <2x8x2x2xi8, 32x4x2x1> %1 = migraphx.quant_convolution %arg1, %0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x2x2xi8, 32x4x2x1> -> <1x2x3x3xi32, 18x9x3x1> return %1 : !migraphx.shaped<1x2x3x3xi32, 18x9x3x1> } }
module {
func.func @mlir_unpack_int4_quant_convolution(%arg0: !migraphx.shaped<2x8x2x1xi8, 16x2x1x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>) -> !migraphx.shaped<1x2x3x3xi32, 18x9x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.unpack %arg0 {axis = 3 : i64, isUnsigned = false} : <2x8x2x1xi8, 16x2x1x1> -> <2x8x2x2xi8, 32x4x2x1>
%1 = migraphx.quant_convolution %arg1, %0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x2x2xi8, 32x4x2x1> -> <1x2x3x3xi32, 18x9x3x1>
return %1 : !migraphx.shaped<1x2x3x3xi32, 18x9x3x1>
}
}
)__migraphx__";
migraphx::module m;
auto x = m.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 4, 4}});
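
The migraphx.unpack op in the expected MLIR widens the packed axis (axis = 3) from 1 to 2 because each int8 byte carries two signed int4 values, and isUnsigned = false means each nibble is sign-extended. A standalone sketch of the per-byte semantics, assuming the low nibble holds the first element (my reading of the layout, not something this test asserts):

    #include <cstdint>
    #include <vector>

    // Split each int8 byte into two signed 4-bit values, low nibble first.
    // Shifting left then arithmetic-shifting right sign-extends the nibble.
    std::vector<int8_t> unpack_int4(const std::vector<int8_t>& packed)
    {
        std::vector<int8_t> out;
        out.reserve(packed.size() * 2);
        for(int8_t byte : packed)
        {
            out.push_back(static_cast<int8_t>(static_cast<int8_t>(byte << 4) >> 4)); // low nibble
            out.push_back(static_cast<int8_t>(byte >> 4));                           // high nibble
        }
        return out;
    }

This doubling is what turns the <2x8x2x1> weight into the <2x8x2x2> tensor consumed by quant_convolution above.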
9 changes: 6 additions & 3 deletions test/onnx/gen_onnx.py
@@ -5671,7 +5671,8 @@ def int4_const_identity_qdq_test():
outputs=['i_y_zp'],
)

data_values = np.array([[-3, -4, -5, 2], [2, 2, 4, 4], [2, -2, 4, 6], [2, 6, 6, 8]])
data_values = np.array([[-3, -4, -5, 2], [2, 2, 4, 4], [2, -2, 4, 6],
[2, 6, 6, 8]])
data_t = helper.make_tensor(name='data',
data_type=TensorProto.FLOAT16,
dims=data_values.shape,
@@ -5710,7 +5711,8 @@ def int4_const_identity_qdq_test():

y_t = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [4, 4])

return ([i_node, q_node, dq_node, t_node], [], [y_t], [x_t, data_t, sc_t, sc_2_t])
return ([i_node, q_node, dq_node,
t_node], [], [y_t], [x_t, data_t, sc_t, sc_2_t])


@onnx_test()
@@ -5767,7 +5769,8 @@ def int4_const_identity_block_sz_1_qdq_test():

y_t = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [4, 2])

return ([i_node, q_node, dq_node, t_node], [], [y_t], [x_t, data_t, sc_t, sc_2_t])
return ([i_node, q_node, dq_node,
t_node], [], [y_t], [x_t, data_t, sc_t, sc_2_t])


@onnx_test()
