From 5e8b60aa7439481fdd442fde617c72b8bd85fcc0 Mon Sep 17 00:00:00 2001 From: Lakhinder Walia <139581206+lakhinderwalia@users.noreply.github.com> Date: Wed, 18 Sep 2024 05:57:22 -0700 Subject: [PATCH] signed-int4 support for Pack/Unpack (#3359) --- src/include/migraphx/op/pack_int4.hpp | 59 ++++++++++++++++++------- src/include/migraphx/op/unpack_int4.hpp | 59 ++++++++++++++++++------- test/ref/pack_int4.cpp | 47 +++++++++++++------- test/ref/pack_unpack_int4.cpp | 48 ++++++++++++++------ 4 files changed, 154 insertions(+), 59 deletions(-) diff --git a/src/include/migraphx/op/pack_int4.hpp b/src/include/migraphx/op/pack_int4.hpp index c00931d8c6b..69d8ef71a47 100644 --- a/src/include/migraphx/op/pack_int4.hpp +++ b/src/include/migraphx/op/pack_int4.hpp @@ -62,9 +62,10 @@ struct pack_int4 { check_shapes{inputs, *this}.same_dims().has(1); auto in_shape = inputs.front(); - if(in_shape.type() != migraphx::shape::uint8_type) + if(in_shape.type() != migraphx::shape::int8_type and + in_shape.type() != migraphx::shape::uint8_type) { - MIGRAPHX_THROW("PACK_INT4: Only Unsigned Int8 type is supported for packing"); + MIGRAPHX_THROW("PACK_INT4: Only Int8 or Uint8 is supported for packing"); } auto new_lens = in_shape.lens(); if(new_lens[axis] % 2 != 0) @@ -72,25 +73,51 @@ struct pack_int4 MIGRAPHX_THROW("PACK_INT4: Can not pack axis that has odd lengths"); } new_lens[axis] /= 2; - return {migraphx::shape::uint8_type, new_lens}; + return {in_shape.type(), new_lens}; } argument compute(const shape& output_shape, std::vector args) const { + auto input = args.front(); + auto in_shape = input.get_shape(); + argument result{output_shape}; - auto in_shape = args.front().get_shape(); - auto input = args.at(0).get(); - auto output = result.get(); - par_for(output_shape.elements(), [&](auto i) { - auto data_idx = output_shape.multi(i); - auto in_data_multi_idx = data_idx; - in_data_multi_idx[axis] *= 2; - auto input_val = input[in_data_multi_idx]; - // mask first 4 bits, keep it little endian. - output[i] = uint8_t(0x0F) & input_val; - in_data_multi_idx[axis] += 1; - input_val = input[in_data_multi_idx]; - output[i] = (input_val << 4u) | output[i]; // NOLINT(hicpp-signed-bitwise) + + visit_all(result, input)([&](auto out, auto inp) { + par_for(output_shape.elements(), [&](auto i) { + using type = typename decltype(inp)::value_type; + type min_4bit; // clip min value + type max_4bit; // clip max value + + if constexpr(std::is_signed{}) + { + min_4bit = -8; + max_4bit = 7; + } + else + { + min_4bit = 0; + max_4bit = 15; + } + + auto data_idx = output_shape.multi(i); + auto in_data_multi_idx = data_idx; + in_data_multi_idx[axis] *= 2; + type val1 = inp[in_data_multi_idx]; + in_data_multi_idx[axis] += 1; + type val2 = inp[in_data_multi_idx]; + + // clip: + val1 = std::min(std::max(val1, min_4bit), max_4bit); + val2 = std::min(std::max(val2, min_4bit), max_4bit); + + // pack: + // the bit operations are forced into uint8_t mode, + // and this would avoid compiler warnings as well. + uint8_t val_ui8_1 = static_cast(val1); + uint8_t val_ui8_2 = static_cast(val2); + out[i] = (val_ui8_2 << 4) | (val_ui8_1 & 0xf); // NOLINT(hicpp-signed-bitwise) + }); }); return result; } diff --git a/src/include/migraphx/op/unpack_int4.hpp b/src/include/migraphx/op/unpack_int4.hpp index df7938ffea2..f8e44654bb8 100644 --- a/src/include/migraphx/op/unpack_int4.hpp +++ b/src/include/migraphx/op/unpack_int4.hpp @@ -62,30 +62,59 @@ struct unpack_int4 { check_shapes{inputs, *this}.same_dims().has(1); auto in_shape = inputs.front(); - if(in_shape.type() != migraphx::shape::uint8_type) + if(in_shape.type() != migraphx::shape::int8_type and + in_shape.type() != migraphx::shape::uint8_type) { - MIGRAPHX_THROW("UNPACK_INT4: Only Unsigned Int8 type is supported for unpacking"); + MIGRAPHX_THROW("UNPACK_INT4: Only Int8 or Uint8 is supported for unpacking"); } auto new_lens = in_shape.lens(); new_lens[axis] *= 2; - return {migraphx::shape::uint8_type, new_lens}; + return {in_shape.type(), new_lens}; } argument compute(const shape& output_shape, std::vector args) const { + auto input = args.front(); + auto in_shape = input.get_shape(); + argument result{output_shape}; - auto in_shape = args.front().get_shape(); - auto input = args.at(0).get(); - auto output = result.get(); - par_for(in_shape.elements(), [&](auto i) { - auto data_idx = in_shape.multi(i); - auto out_data_multi_idx = data_idx; - out_data_multi_idx[axis] *= 2; - auto input_val = input[data_idx]; - // mask first 4 bits, packing is assumed to be little endian - output[out_data_multi_idx] = uint8_t(0x0F) & input_val; - out_data_multi_idx[axis] += 1; - output[out_data_multi_idx] = input_val >> 4; // NOLINT(hicpp-signed-bitwise) + + visit_all(result, input)([&](auto out, auto inp) { + par_for(in_shape.elements(), [&](auto i) { + using type = typename decltype(out)::value_type; + auto data_idx = in_shape.multi(i); + data_idx[axis] *= 2; + if constexpr(std::is_signed{}) + { + // signed input: [Most significant nibble | Least significant nibble] + int8_t val1 = inp[i]; + int8_t val2 = val1; + + // Step1: move the LSN to MSN: + // However avoid doing a left shift of signed quantity + // due to its possible run time error. + uint8_t u_tmp = static_cast(val1); + u_tmp <<= 4; // NOLINT(hicpp-signed-bitwise) + val1 = static_cast(u_tmp); + + // Step2: the sign bit is copied in a right signed-shift: + val1 >>= 4; // NOLINT(hicpp-signed-bitwise) + out[data_idx] = val1; + + data_idx[axis] += 1; + val2 >>= 4; // NOLINT(hicpp-signed-bitwise) + out[data_idx] = val2; + } + else + { + // unpacking of 2 unsigned nibbles: + uint8_t val = inp[i]; + out[data_idx] = val & 0xf; // NOLINT(hicpp-signed-bitwise) + + data_idx[axis] += 1; + out[data_idx] = val >> 4; // NOLINT(hicpp-signed-bitwise) + } + }); }); return result; } diff --git a/test/ref/pack_int4.cpp b/test/ref/pack_int4.cpp index c1e7cc509ac..797fb5f9bfc 100644 --- a/test/ref/pack_int4.cpp +++ b/test/ref/pack_int4.cpp @@ -30,7 +30,7 @@ #include -TEST_CASE(pack_int4) +TEST_CASE(pack_uint4) { migraphx::program p; auto* mm = p.get_main_module(); @@ -45,7 +45,23 @@ TEST_CASE(pack_int4) EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); } -TEST_CASE(pack_int4_transposed) +TEST_CASE(pack_int4) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int8_type, {2, 2}}; + auto l0 = mm->add_literal(migraphx::literal{s, {0x0A, 0x0B, 0x0C, 0x0D}}); + mm->add_instruction(migraphx::make_op("pack_int4"), l0); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector results_vector(2); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0x77, 0x77}; + // clipped values + EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); +} + +TEST_CASE(pack_uint4_transposed) { migraphx::program p; auto* mm = p.get_main_module(); @@ -60,7 +76,7 @@ TEST_CASE(pack_int4_transposed) EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); } -TEST_CASE(pack_int4_broadcasted) +TEST_CASE(pack_uint4_broadcasted) { migraphx::program p; auto* mm = p.get_main_module(); @@ -76,7 +92,7 @@ TEST_CASE(pack_int4_broadcasted) EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); } -TEST_CASE(pack_int4_axis_0) +TEST_CASE(pack_uint4_axis_0) { migraphx::program p; auto* mm = p.get_main_module(); @@ -91,10 +107,10 @@ TEST_CASE(pack_int4_axis_0) EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); } -TEST_CASE(pack_int4_nchw) +TEST_CASE(pack_uint4_nchw) { - // test with literal values such as 0x18 in which first 4 bits will be dropped, ideally - // quantizer should produce values that fit into 4 bits. + // input values >= 0x10 would be clipped to 0xf (the maximum for uint4) + // As seen in the bottom half of the expected results (gold) below. migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::uint8_type, {1, 2, 4, 4}}; @@ -102,6 +118,7 @@ TEST_CASE(pack_int4_nchw) migraphx::literal{s, {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F}}); + mm->add_instruction(migraphx::make_op("pack_int4", {{"axis", -1}}), l0); p.compile(migraphx::make_target("ref")); auto result = p.eval({}).back(); @@ -115,13 +132,13 @@ TEST_CASE(pack_int4_nchw) 0xBA, 0xDC, 0xFE, - 0x10, - 0x32, - 0x54, - 0x76, - 0x98, - 0xBA, - 0xDC, - 0xFE}; + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF}; EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); } diff --git a/test/ref/pack_unpack_int4.cpp b/test/ref/pack_unpack_int4.cpp index e79555855ef..e40f38c9b5f 100644 --- a/test/ref/pack_unpack_int4.cpp +++ b/test/ref/pack_unpack_int4.cpp @@ -30,7 +30,7 @@ #include -TEST_CASE(pack_unpack_int4) +TEST_CASE(pack_unpack_uint4) { migraphx::program p; auto* mm = p.get_main_module(); @@ -40,13 +40,13 @@ TEST_CASE(pack_unpack_int4) mm->add_instruction(migraphx::make_op("unpack_int4"), pack_ins); p.compile(migraphx::make_target("ref")); auto result = p.eval({}).back(); - std::vector results_vector(4); + std::vector results_vector(s.elements()); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); std::vector gold{0x0A, 0x0B, 0x0C, 0x0D}; EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); } -TEST_CASE(pack_unpack_int4_transposed) +TEST_CASE(pack_unpack_uint4_transposed) { migraphx::program p; auto* mm = p.get_main_module(); @@ -56,13 +56,13 @@ TEST_CASE(pack_unpack_int4_transposed) mm->add_instruction(migraphx::make_op("unpack_int4"), pack_ins); p.compile(migraphx::make_target("ref")); auto result = p.eval({}).back(); - std::vector results_vector(4); + std::vector results_vector(s.elements()); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); std::vector gold{0x0A, 0x0B, 0x0C, 0x0D}; EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); } -TEST_CASE(pack_multibroadcast_unpack_int4) +TEST_CASE(pack_multibroadcast_unpack_uint4) { migraphx::program p; auto* mm = p.get_main_module(); @@ -95,7 +95,7 @@ TEST_CASE(pack_multibroadcast_unpack_int4) EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); } -TEST_CASE(pack_unpack_int4_axis_0) +TEST_CASE(pack_unpack_uint4_axis_0) { migraphx::program p; auto* mm = p.get_main_module(); @@ -105,16 +105,14 @@ TEST_CASE(pack_unpack_int4_axis_0) mm->add_instruction(migraphx::make_op("unpack_int4", {{"axis", 0}}), pack_ins); p.compile(migraphx::make_target("ref")); auto result = p.eval({}).back(); - std::vector results_vector(4); + std::vector results_vector(s.elements()); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); std::vector gold{0x0A, 0x0B, 0x0C, 0x0D}; EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); } -TEST_CASE(pack_unpack_int4_nchw) +TEST_CASE(pack_unpack_uint4_nchw) { - // test with literal values such as 0x18 in which first 4 bits will be dropped, ideally - // quantizer should produce values that fit into 4 bits. migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::uint8_type, {1, 2, 4, 4}}; @@ -123,13 +121,37 @@ TEST_CASE(pack_unpack_int4_nchw) 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F}}); auto pack_ins = mm->add_instruction(migraphx::make_op("pack_int4", {{"axis", -1}}), l0); + // The result of packing also includes a clip. Max clip value = 0xf for UINT4. mm->add_instruction(migraphx::make_op("unpack_int4", {{"axis", -1}}), pack_ins); p.compile(migraphx::make_target("ref")); auto result = p.eval({}).back(); - std::vector results_vector(32); + std::vector results_vector(s.elements()); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + // The gold unpacked values should contain input values clipped during pack std::vector gold{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, - 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, - 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F}; + 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, + 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F}; + EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); +} + +TEST_CASE(pack_unpack_int4_nchw) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int8_type, {1, 2, 2, 4}}; + auto l0 = mm->add_literal( + migraphx::literal{s, {-10, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 10}}); + + auto pack_ins = mm->add_instruction(migraphx::make_op("pack_int4", {{"axis", -1}}), l0); + // Packing also includes a clip: + // Max clipped and packed nibble value = +7 for INT4. + // Min clipped and packed nibble value = -8 for INT4. + // Both the outer values of l0 should be clipped during pack_int4: -10, 10 + mm->add_instruction(migraphx::make_op("unpack_int4", {{"axis", -1}}), pack_ins); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector results_vector(s.elements()); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7}; EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); }