From 7c643e6e487ca6737cc6d5ab937f73fc6fdaba19 Mon Sep 17 00:00:00 2001 From: lakhinderwalia Date: Tue, 17 Sep 2024 16:17:52 -0700 Subject: [PATCH] handle review comments --- src/quantize_int4.cpp | 4 ++-- src/targets/gpu/target.cpp | 9 +-------- test/gpu/mlir.cpp | 8 +++++++- test/onnx/gen_onnx.py | 9 ++++++--- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/quantize_int4.cpp b/src/quantize_int4.cpp index 326790e8e2..5da7cf16c8 100644 --- a/src/quantize_int4.cpp +++ b/src/quantize_int4.cpp @@ -67,8 +67,8 @@ static void int4_quantize_module(module& m) if(not inp->can_eval()) return inp; - std::vector val; - inp->eval().visit([&](auto in_data) { val.assign(in_data.begin(), in_data.end()); }); + std::vector val = inp->eval().to_vector(); + auto [min, max] = std::minmax_element(val.begin(), val.end()); *min = *min > 0 ? 0 : *min; *max = *max < 0 ? 0 : *max; diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 9aa44a541d..b190d11a40 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -41,7 +41,6 @@ #include #include #include -#include #include #include #include @@ -162,20 +161,14 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, normalize_ops{}, dead_code_elimination{}, - // 1. For int4 ops: first remove the identity op next to packed int4 data. - eliminate_identity{}, - // 2. Next: (pack/unpack)_int4 handling in simplyfy_qdq. simplify_qdq{}, - // 3. Next: pruning of pack_int4 related const branches. - propagate_constant{{}}, - // 4. Last for int4: dce. - dead_code_elimination{}, enable_pass(not mlir_enabled(), rewrite_quantization{}), dead_code_elimination{}, // workaround for rocBLAS unsupported error when using uint8 in quant_dot, quant_convolution & pooling eliminate_data_type{{migraphx::shape::uint8_type}, shape::float_type, {"quant_convolution", "quant_dot", "pooling"}}, eliminate_data_type{unsupported_types, shape::type_t::float_type}, simplify_reshapes{}, + eliminate_identity{}, eliminate_pad{}, dead_code_elimination{}, insert_pad{{"convolution"}}, diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp index 3cab97b509..67e78697a3 100644 --- a/test/gpu/mlir.cpp +++ b/test/gpu/mlir.cpp @@ -483,7 +483,13 @@ module { TEST_CASE(int4_unpack_conv) { std::string mlir_output = R"__migraphx__( - module { func.func @mlir_unpack_int4_quant_convolution(%arg0: !migraphx.shaped<2x8x2x1xi8, 16x2x1x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>) -> !migraphx.shaped<1x2x3x3xi32, 18x9x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { %0 = migraphx.unpack %arg0 {axis = 3 : i64, isUnsigned = false} : <2x8x2x1xi8, 16x2x1x1> -> <2x8x2x2xi8, 32x4x2x1> %1 = migraphx.quant_convolution %arg1, %0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x2x2xi8, 32x4x2x1> -> <1x2x3x3xi32, 18x9x3x1> return %1 : !migraphx.shaped<1x2x3x3xi32, 18x9x3x1> } } + module { + func.func @mlir_unpack_int4_quant_convolution(%arg0: !migraphx.shaped<2x8x2x1xi8, 16x2x1x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>) -> !migraphx.shaped<1x2x3x3xi32, 18x9x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { + %0 = migraphx.unpack %arg0 {axis = 3 : i64, isUnsigned = false} : <2x8x2x1xi8, 16x2x1x1> -> <2x8x2x2xi8, 32x4x2x1> + %1 = migraphx.quant_convolution %arg1, %0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x2x2xi8, 32x4x2x1> -> <1x2x3x3xi32, 18x9x3x1> + return %1 : !migraphx.shaped<1x2x3x3xi32, 18x9x3x1> + } +} )__migraphx__"; migraphx::module m; auto x = m.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 4, 4}}); diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 0d82ac4f30..7b05f961b8 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -5671,7 +5671,8 @@ def int4_const_identity_qdq_test(): outputs=['i_y_zp'], ) - data_values = np.array([[-3, -4, -5, 2], [2, 2, 4, 4], [2, -2, 4, 6], [2, 6, 6, 8]]) + data_values = np.array([[-3, -4, -5, 2], [2, 2, 4, 4], [2, -2, 4, 6], + [2, 6, 6, 8]]) data_t = helper.make_tensor(name='data', data_type=TensorProto.FLOAT16, dims=data_values.shape, @@ -5710,7 +5711,8 @@ def int4_const_identity_qdq_test(): y_t = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [4, 4]) - return ([i_node, q_node, dq_node, t_node], [], [y_t], [x_t, data_t, sc_t, sc_2_t]) + return ([i_node, q_node, dq_node, + t_node], [], [y_t], [x_t, data_t, sc_t, sc_2_t]) @onnx_test() @@ -5767,7 +5769,8 @@ def int4_const_identity_block_sz_1_qdq_test(): y_t = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [4, 2]) - return ([i_node, q_node, dq_node, t_node], [], [y_t], [x_t, data_t, sc_t, sc_2_t]) + return ([i_node, q_node, dq_node, + t_node], [], [y_t], [x_t, data_t, sc_t, sc_2_t]) @onnx_test()