Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

signed-int4 support for Pack/Unpack #3359

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 43 additions & 16 deletions src/include/migraphx/op/pack_int4.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,35 +62,62 @@
{
check_shapes{inputs, *this}.same_dims().has(1);
auto in_shape = inputs.front();
if(in_shape.type() != migraphx::shape::uint8_type)
if(in_shape.type() != migraphx::shape::int8_type and
in_shape.type() != migraphx::shape::uint8_type)
{
MIGRAPHX_THROW("PACK_INT4: Only Unsigned Int8 type is supported for packing");
MIGRAPHX_THROW("PACK_INT4: Only Int8 or Uint8 is supported for packing");
}
auto new_lens = in_shape.lens();
if(new_lens[axis] % 2 != 0)
{
MIGRAPHX_THROW("PACK_INT4: Can not pack axis that has odd lengths");
}
new_lens[axis] /= 2;
return {migraphx::shape::uint8_type, new_lens};
return {in_shape.type(), new_lens};
}

argument compute(const shape& output_shape, std::vector<argument> args) const
{
auto input = args.front();
auto in_shape = input.get_shape();

argument result{output_shape};
auto in_shape = args.front().get_shape();
auto input = args.at(0).get<uint8_t>();
auto output = result.get<uint8_t>();
par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i);
auto in_data_multi_idx = data_idx;
in_data_multi_idx[axis] *= 2;
auto input_val = input[in_data_multi_idx];
// mask first 4 bits, keep it little endian.
output[i] = uint8_t(0x0F) & input_val;
in_data_multi_idx[axis] += 1;
input_val = input[in_data_multi_idx];
output[i] = (input_val << 4u) | output[i]; // NOLINT(hicpp-signed-bitwise)

visit_all(result, input)([&](auto out, auto inp) {
par_for(output_shape.elements(), [&](auto i) {
using type = typename decltype(inp)::value_type;
lakhinderwalia marked this conversation as resolved.
Show resolved Hide resolved
type min_4bit; // clip min value
type max_4bit; // clip max value

Check warning on line 90 in src/include/migraphx/op/pack_int4.hpp

View check run for this annotation

Codecov / codecov/patch

src/include/migraphx/op/pack_int4.hpp#L89-L90

Added lines #L89 - L90 were not covered by tests

if constexpr(std::is_signed<type>{})
{
min_4bit = -8;
max_4bit = 7;
}
else
{
min_4bit = 0;
max_4bit = 15;
}

auto data_idx = output_shape.multi(i);
auto in_data_multi_idx = data_idx;
in_data_multi_idx[axis] *= 2;
type val1 = inp[in_data_multi_idx];
in_data_multi_idx[axis] += 1;
type val2 = inp[in_data_multi_idx];

// clip:
val1 = std::min(std::max(val1, min_4bit), max_4bit);
val2 = std::min(std::max(val2, min_4bit), max_4bit);

Check warning on line 112 in src/include/migraphx/op/pack_int4.hpp

View check run for this annotation

Codecov / codecov/patch

src/include/migraphx/op/pack_int4.hpp#L112

Added line #L112 was not covered by tests

// pack:
// the bit operations are forced into uint8_t mode,
// and this would avoid compiler warnings as well.
uint8_t val_ui8_1 = static_cast<uint8_t>(val1);

Check warning on line 117 in src/include/migraphx/op/pack_int4.hpp

View check run for this annotation

Codecov / codecov/patch

src/include/migraphx/op/pack_int4.hpp#L117

Added line #L117 was not covered by tests
uint8_t val_ui8_2 = static_cast<uint8_t>(val2);
out[i] = (val_ui8_2 << 4) | (val_ui8_1 & 0xf); // NOLINT(hicpp-signed-bitwise)
});
});
return result;
}
Expand Down
59 changes: 44 additions & 15 deletions src/include/migraphx/op/unpack_int4.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,30 +62,59 @@ struct unpack_int4
{
check_shapes{inputs, *this}.same_dims().has(1);
auto in_shape = inputs.front();
if(in_shape.type() != migraphx::shape::uint8_type)
if(in_shape.type() != migraphx::shape::int8_type and
in_shape.type() != migraphx::shape::uint8_type)
{
MIGRAPHX_THROW("UNPACK_INT4: Only Unsigned Int8 type is supported for unpacking");
MIGRAPHX_THROW("UNPACK_INT4: Only Int8 or Uint8 is supported for unpacking");
}
auto new_lens = in_shape.lens();
new_lens[axis] *= 2;
return {migraphx::shape::uint8_type, new_lens};
return {in_shape.type(), new_lens};
}

argument compute(const shape& output_shape, std::vector<argument> args) const
{
auto input = args.front();
auto in_shape = input.get_shape();

argument result{output_shape};
auto in_shape = args.front().get_shape();
auto input = args.at(0).get<uint8_t>();
auto output = result.get<uint8_t>();
par_for(in_shape.elements(), [&](auto i) {
auto data_idx = in_shape.multi(i);
auto out_data_multi_idx = data_idx;
out_data_multi_idx[axis] *= 2;
auto input_val = input[data_idx];
// mask first 4 bits, packing is assumed to be little endian
output[out_data_multi_idx] = uint8_t(0x0F) & input_val;
out_data_multi_idx[axis] += 1;
output[out_data_multi_idx] = input_val >> 4; // NOLINT(hicpp-signed-bitwise)

visit_all(result, input)([&](auto out, auto inp) {
par_for(in_shape.elements(), [&](auto i) {
using type = typename decltype(out)::value_type;
auto data_idx = in_shape.multi(i);
lakhinderwalia marked this conversation as resolved.
Show resolved Hide resolved
data_idx[axis] *= 2;
if constexpr(std::is_signed<type>{})
{
// signed input: [Most significant nibble | Least significant nibble]
int8_t val1 = inp[i];
int8_t val2 = val1;

// Step1: move the LSN to MSN:
// However avoid doing a left shift of signed quantity
// due to its possible run time error.
uint8_t u_tmp = static_cast<uint8_t>(val1);
u_tmp <<= 4; // NOLINT(hicpp-signed-bitwise)
val1 = static_cast<int8_t>(u_tmp);

// Step2: the sign bit is copied in a right signed-shift:
val1 >>= 4; // NOLINT(hicpp-signed-bitwise)
out[data_idx] = val1;

data_idx[axis] += 1;
val2 >>= 4; // NOLINT(hicpp-signed-bitwise)
out[data_idx] = val2;
}
else
{
// unpacking of 2 unsigned nibbles:
uint8_t val = inp[i];
out[data_idx] = val & 0xf; // NOLINT(hicpp-signed-bitwise)

data_idx[axis] += 1;
out[data_idx] = val >> 4; // NOLINT(hicpp-signed-bitwise)
}
});
});
return result;
}
Expand Down
47 changes: 32 additions & 15 deletions test/ref/pack_int4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

#include <test.hpp>

TEST_CASE(pack_int4)
TEST_CASE(pack_uint4)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand All @@ -45,7 +45,23 @@ TEST_CASE(pack_int4)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_int4_transposed)
TEST_CASE(pack_int4)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int8_type, {2, 2}};
auto l0 = mm->add_literal(migraphx::literal{s, {0x0A, 0x0B, 0x0C, 0x0D}});
mm->add_instruction(migraphx::make_op("pack_int4"), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<uint8_t> results_vector(2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold{0x77, 0x77};
// clipped values
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_uint4_transposed)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand All @@ -60,7 +76,7 @@ TEST_CASE(pack_int4_transposed)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_int4_broadcasted)
TEST_CASE(pack_uint4_broadcasted)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand All @@ -76,7 +92,7 @@ TEST_CASE(pack_int4_broadcasted)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_int4_axis_0)
TEST_CASE(pack_uint4_axis_0)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand All @@ -91,17 +107,18 @@ TEST_CASE(pack_int4_axis_0)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_int4_nchw)
TEST_CASE(pack_uint4_nchw)
{
// test with literal values such as 0x18 in which first 4 bits will be dropped, ideally
// quantizer should produce values that fit into 4 bits.
lakhinderwalia marked this conversation as resolved.
Show resolved Hide resolved
// input values >= 0x10 would be clipped to 0xf (the maximum for uint4)
// As seen in the bottom half of the expected results (gold) below.
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::uint8_type, {1, 2, 4, 4}};
auto l0 = mm->add_literal(
migraphx::literal{s, {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F}});

mm->add_instruction(migraphx::make_op("pack_int4", {{"axis", -1}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
Expand All @@ -115,13 +132,13 @@ TEST_CASE(pack_int4_nchw)
0xBA,
0xDC,
0xFE,
0x10,
0x32,
0x54,
0x76,
0x98,
0xBA,
0xDC,
0xFE};
0xFF,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add tests "ref" tests for "unpack" as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There already is a test for pack_unpack. Isn't that enough?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify further, there are ref tests that check for pack operations, and then followed by its unpack counterpart. And those initial and final values are compared appropriately. Thanks.

0xFF,
0xFF,
0xFF,
0xFF,
0xFF,
0xFF,
0xFF};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
48 changes: 35 additions & 13 deletions test/ref/pack_unpack_int4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

#include <test.hpp>

TEST_CASE(pack_unpack_int4)
TEST_CASE(pack_unpack_uint4)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand All @@ -40,13 +40,13 @@ TEST_CASE(pack_unpack_int4)
mm->add_instruction(migraphx::make_op("unpack_int4"), pack_ins);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<uint8_t> results_vector(4);
std::vector<uint8_t> results_vector(s.elements());
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<uint8_t> gold{0x0A, 0x0B, 0x0C, 0x0D};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_unpack_int4_transposed)
TEST_CASE(pack_unpack_uint4_transposed)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand All @@ -56,13 +56,13 @@ TEST_CASE(pack_unpack_int4_transposed)
mm->add_instruction(migraphx::make_op("unpack_int4"), pack_ins);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<uint8_t> results_vector(4);
std::vector<uint8_t> results_vector(s.elements());
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<uint8_t> gold{0x0A, 0x0B, 0x0C, 0x0D};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_multibroadcast_unpack_int4)
TEST_CASE(pack_multibroadcast_unpack_uint4)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand Down Expand Up @@ -95,7 +95,7 @@ TEST_CASE(pack_multibroadcast_unpack_int4)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_unpack_int4_axis_0)
TEST_CASE(pack_unpack_uint4_axis_0)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand All @@ -105,16 +105,14 @@ TEST_CASE(pack_unpack_int4_axis_0)
mm->add_instruction(migraphx::make_op("unpack_int4", {{"axis", 0}}), pack_ins);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<uint8_t> results_vector(4);
std::vector<uint8_t> results_vector(s.elements());
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<uint8_t> gold{0x0A, 0x0B, 0x0C, 0x0D};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_unpack_int4_nchw)
TEST_CASE(pack_unpack_uint4_nchw)
{
// test with literal values such as 0x18 in which first 4 bits will be dropped, ideally
// quantizer should produce values that fit into 4 bits.
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::uint8_type, {1, 2, 4, 4}};
Expand All @@ -123,13 +121,37 @@ TEST_CASE(pack_unpack_int4_nchw)
0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F}});
auto pack_ins = mm->add_instruction(migraphx::make_op("pack_int4", {{"axis", -1}}), l0);
// The result of packing also includes a clip. Max clip value = 0xf for UINT4.
mm->add_instruction(migraphx::make_op("unpack_int4", {{"axis", -1}}), pack_ins);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<uint8_t> results_vector(32);
std::vector<uint8_t> results_vector(s.elements());
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
// The gold unpacked values should contain input values clipped during pack
std::vector<uint8_t> gold{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F};
0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(pack_unpack_int4_nchw)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int8_type, {1, 2, 2, 4}};
auto l0 = mm->add_literal(
migraphx::literal{s, {-10, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 10}});

auto pack_ins = mm->add_instruction(migraphx::make_op("pack_int4", {{"axis", -1}}), l0);
// Packing also includes a clip:
// Max clipped and packed nibble value = +7 for INT4.
// Min clipped and packed nibble value = -8 for INT4.
// Both the outer values of l0 should be clipped during pack_int4: -10, 10
mm->add_instruction(migraphx::make_op("unpack_int4", {{"axis", -1}}), pack_ins);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int8_t> results_vector(s.elements());
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold{-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
Loading