Skip to content

Commit

Permalink
signed-int4 support for Pack/Unpack (#3359)
Browse files Browse the repository at this point in the history
  • Loading branch information
lakhinderwalia committed Sep 18, 2024
1 parent e25d77b commit 5e8b60a
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 59 deletions.
59 changes: 43 additions & 16 deletions src/include/migraphx/op/pack_int4.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,35 +62,62 @@ struct pack_int4
{
check_shapes{inputs, *this}.same_dims().has(1);
auto in_shape = inputs.front();
if(in_shape.type() != migraphx::shape::uint8_type)
if(in_shape.type() != migraphx::shape::int8_type and
in_shape.type() != migraphx::shape::uint8_type)
{
MIGRAPHX_THROW("PACK_INT4: Only Unsigned Int8 type is supported for packing");
MIGRAPHX_THROW("PACK_INT4: Only Int8 or Uint8 is supported for packing");
}
auto new_lens = in_shape.lens();
if(new_lens[axis] % 2 != 0)
{
MIGRAPHX_THROW("PACK_INT4: Can not pack axis that has odd lengths");
}
new_lens[axis] /= 2;
return {migraphx::shape::uint8_type, new_lens};
return {in_shape.type(), new_lens};
}

argument compute(const shape& output_shape, std::vector<argument> args) const
{
auto input = args.front();
auto in_shape = input.get_shape();

argument result{output_shape};
auto in_shape = args.front().get_shape();
auto input = args.at(0).get<uint8_t>();
auto output = result.get<uint8_t>();
par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i);
auto in_data_multi_idx = data_idx;
in_data_multi_idx[axis] *= 2;
auto input_val = input[in_data_multi_idx];
// mask first 4 bits, keep it little endian.
output[i] = uint8_t(0x0F) & input_val;
in_data_multi_idx[axis] += 1;
input_val = input[in_data_multi_idx];
output[i] = (input_val << 4u) | output[i]; // NOLINT(hicpp-signed-bitwise)

visit_all(result, input)([&](auto out, auto inp) {
par_for(output_shape.elements(), [&](auto i) {
using type = typename decltype(inp)::value_type;
type min_4bit; // clip min value
type max_4bit; // clip max value

if constexpr(std::is_signed<type>{})
{
min_4bit = -8;
max_4bit = 7;
}
else
{
min_4bit = 0;
max_4bit = 15;
}

auto data_idx = output_shape.multi(i);
auto in_data_multi_idx = data_idx;
in_data_multi_idx[axis] *= 2;
type val1 = inp[in_data_multi_idx];
in_data_multi_idx[axis] += 1;
type val2 = inp[in_data_multi_idx];

// clip:
val1 = std::min(std::max(val1, min_4bit), max_4bit);
val2 = std::min(std::max(val2, min_4bit), max_4bit);

// pack:
// the bit operations are forced into uint8_t mode,
// and this would avoid compiler warnings as well.
uint8_t val_ui8_1 = static_cast<uint8_t>(val1);
uint8_t val_ui8_2 = static_cast<uint8_t>(val2);
out[i] = (val_ui8_2 << 4) | (val_ui8_1 & 0xf); // NOLINT(hicpp-signed-bitwise)
});
});
return result;
}
Expand Down
59 changes: 44 additions & 15 deletions src/include/migraphx/op/unpack_int4.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,30 +62,59 @@ struct unpack_int4
{
check_shapes{inputs, *this}.same_dims().has(1);
auto in_shape = inputs.front();
if(in_shape.type() != migraphx::shape::uint8_type)
if(in_shape.type() != migraphx::shape::int8_type and
in_shape.type() != migraphx::shape::uint8_type)
{
MIGRAPHX_THROW("UNPACK_INT4: Only Unsigned Int8 type is supported for unpacking");
MIGRAPHX_THROW("UNPACK_INT4: Only Int8 or Uint8 is supported for unpacking");
}
auto new_lens = in_shape.lens();
new_lens[axis] *= 2;
return {migraphx::shape::uint8_type, new_lens};
return {in_shape.type(), new_lens};
}

argument compute(const shape& output_shape, std::vector<argument> args) const
{
auto input = args.front();
auto in_shape = input.get_shape();

argument result{output_shape};
auto in_shape = args.front().get_shape();
auto input = args.at(0).get<uint8_t>();
auto output = result.get<uint8_t>();
par_for(in_shape.elements(), [&](auto i) {
auto data_idx = in_shape.multi(i);
auto out_data_multi_idx = data_idx;
out_data_multi_idx[axis] *= 2;
auto input_val = input[data_idx];
// mask first 4 bits, packing is assumed to be little endian
output[out_data_multi_idx] = uint8_t(0x0F) & input_val;
out_data_multi_idx[axis] += 1;
output[out_data_multi_idx] = input_val >> 4; // NOLINT(hicpp-signed-bitwise)

visit_all(result, input)([&](auto out, auto inp) {
par_for(in_shape.elements(), [&](auto i) {
using type = typename decltype(out)::value_type;
auto data_idx = in_shape.multi(i);
data_idx[axis] *= 2;
if constexpr(std::is_signed<type>{})
{
// signed input: [Most significant nibble | Least significant nibble]
int8_t val1 = inp[i];
int8_t val2 = val1;

// Step1: move the LSN to MSN:
// However avoid doing a left shift of signed quantity
// due to its possible run time error.
uint8_t u_tmp = static_cast<uint8_t>(val1);
u_tmp <<= 4; // NOLINT(hicpp-signed-bitwise)
val1 = static_cast<int8_t>(u_tmp);

// Step2: the sign bit is copied in a right signed-shift:
val1 >>= 4; // NOLINT(hicpp-signed-bitwise)
out[data_idx] = val1;

data_idx[axis] += 1;
val2 >>= 4; // NOLINT(hicpp-signed-bitwise)
out[data_idx] = val2;
}
else
{
// unpacking of 2 unsigned nibbles:
uint8_t val = inp[i];
out[data_idx] = val & 0xf; // NOLINT(hicpp-signed-bitwise)

data_idx[axis] += 1;
out[data_idx] = val >> 4; // NOLINT(hicpp-signed-bitwise)
}
});
});
return result;
}
Expand Down
47 changes: 32 additions & 15 deletions test/ref/pack_int4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

#include <test.hpp>

TEST_CASE(pack_int4)
TEST_CASE(pack_uint4)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand All @@ -45,7 +45,23 @@ TEST_CASE(pack_int4)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_int4_transposed)
TEST_CASE(pack_int4)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int8_type, {2, 2}};
auto l0 = mm->add_literal(migraphx::literal{s, {0x0A, 0x0B, 0x0C, 0x0D}});
mm->add_instruction(migraphx::make_op("pack_int4"), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<uint8_t> results_vector(2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold{0x77, 0x77};
// clipped values
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_uint4_transposed)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand All @@ -60,7 +76,7 @@ TEST_CASE(pack_int4_transposed)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_int4_broadcasted)
TEST_CASE(pack_uint4_broadcasted)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand All @@ -76,7 +92,7 @@ TEST_CASE(pack_int4_broadcasted)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_int4_axis_0)
TEST_CASE(pack_uint4_axis_0)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand All @@ -91,17 +107,18 @@ TEST_CASE(pack_int4_axis_0)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_int4_nchw)
TEST_CASE(pack_uint4_nchw)
{
// test with literal values such as 0x18 in which first 4 bits will be dropped, ideally
// quantizer should produce values that fit into 4 bits.
// input values >= 0x10 would be clipped to 0xf (the maximum for uint4)
// As seen in the bottom half of the expected results (gold) below.
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::uint8_type, {1, 2, 4, 4}};
auto l0 = mm->add_literal(
migraphx::literal{s, {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F}});

mm->add_instruction(migraphx::make_op("pack_int4", {{"axis", -1}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
Expand All @@ -115,13 +132,13 @@ TEST_CASE(pack_int4_nchw)
0xBA,
0xDC,
0xFE,
0x10,
0x32,
0x54,
0x76,
0x98,
0xBA,
0xDC,
0xFE};
0xFF,
0xFF,
0xFF,
0xFF,
0xFF,
0xFF,
0xFF,
0xFF};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
48 changes: 35 additions & 13 deletions test/ref/pack_unpack_int4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

#include <test.hpp>

TEST_CASE(pack_unpack_int4)
TEST_CASE(pack_unpack_uint4)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand All @@ -40,13 +40,13 @@ TEST_CASE(pack_unpack_int4)
mm->add_instruction(migraphx::make_op("unpack_int4"), pack_ins);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<uint8_t> results_vector(4);
std::vector<uint8_t> results_vector(s.elements());
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<uint8_t> gold{0x0A, 0x0B, 0x0C, 0x0D};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_unpack_int4_transposed)
TEST_CASE(pack_unpack_uint4_transposed)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand All @@ -56,13 +56,13 @@ TEST_CASE(pack_unpack_int4_transposed)
mm->add_instruction(migraphx::make_op("unpack_int4"), pack_ins);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<uint8_t> results_vector(4);
std::vector<uint8_t> results_vector(s.elements());
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<uint8_t> gold{0x0A, 0x0B, 0x0C, 0x0D};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_multibroadcast_unpack_int4)
TEST_CASE(pack_multibroadcast_unpack_uint4)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand Down Expand Up @@ -95,7 +95,7 @@ TEST_CASE(pack_multibroadcast_unpack_int4)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_unpack_int4_axis_0)
TEST_CASE(pack_unpack_uint4_axis_0)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand All @@ -105,16 +105,14 @@ TEST_CASE(pack_unpack_int4_axis_0)
mm->add_instruction(migraphx::make_op("unpack_int4", {{"axis", 0}}), pack_ins);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<uint8_t> results_vector(4);
std::vector<uint8_t> results_vector(s.elements());
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<uint8_t> gold{0x0A, 0x0B, 0x0C, 0x0D};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_unpack_int4_nchw)
TEST_CASE(pack_unpack_uint4_nchw)
{
// test with literal values such as 0x18 in which first 4 bits will be dropped, ideally
// quantizer should produce values that fit into 4 bits.
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::uint8_type, {1, 2, 4, 4}};
Expand All @@ -123,13 +121,37 @@ TEST_CASE(pack_unpack_int4_nchw)
0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F}});
auto pack_ins = mm->add_instruction(migraphx::make_op("pack_int4", {{"axis", -1}}), l0);
// The result of packing also includes a clip. Max clip value = 0xf for UINT4.
mm->add_instruction(migraphx::make_op("unpack_int4", {{"axis", -1}}), pack_ins);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<uint8_t> results_vector(32);
std::vector<uint8_t> results_vector(s.elements());
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
// The gold unpacked values should contain input values clipped during pack
std::vector<uint8_t> gold{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F};
0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_unpack_int4_nchw)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int8_type, {1, 2, 2, 4}};
auto l0 = mm->add_literal(
migraphx::literal{s, {-10, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 10}});

auto pack_ins = mm->add_instruction(migraphx::make_op("pack_int4", {{"axis", -1}}), l0);
// Packing also includes a clip:
// Max clipped and packed nibble value = +7 for INT4.
// Min clipped and packed nibble value = -8 for INT4.
// Both the outer values of l0 should be clipped during pack_int4: -10, 10
mm->add_instruction(migraphx::make_op("unpack_int4", {{"axis", -1}}), pack_ins);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int8_t> results_vector(s.elements());
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold{-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

0 comments on commit 5e8b60a

Please sign in to comment.