parse onnx graph for int4 #3408

Merged · 1 commit · Sep 11, 2024
53 changes: 48 additions & 5 deletions src/onnx/onnx_parser.cpp
@@ -320,6 +320,11 @@
mod->debug_print(added_instructions);
}

static bool is_type_packed_int4(const onnx::TensorProto& t)
{
return t.data_type() == onnx::TensorProto::INT4 or t.data_type() == onnx::TensorProto::UINT4;
}

std::unordered_map<std::string, instruction_ref>
parse_intializer(const onnx_parser& parser, module* mod, const onnx::GraphProto& graph)
{
@@ -329,7 +334,13 @@
if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
std::cout << "initializer: " << f.name() << std::endl;
// backup instructions in parent mod
mod_insts[f.name()] = mod->add_literal(parser.parse_tensor(f));
auto pt = parser.parse_tensor(f);
auto lit = mod->add_literal(pt);

if(is_type_packed_int4(f))
lit = mod->add_instruction(migraphx::make_op("unpack_int4"), lit);

[Codecov] Added line #L341 was not covered by tests
mod_insts[f.name()] = lit;
if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
mod->debug_print(mod_insts[f.name()]);
}
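Note on the hunk above: ONNX packs INT4/UINT4 tensors two elements per byte, with the first element in the low nibble, so the literal built from the raw buffer holds packed bytes and the `unpack_int4` instruction widens each nibble back out to a full int8/uint8 element. A minimal standalone sketch of that packing convention, assuming low-nibble-first order; `unpack_nibbles` is an illustrative helper, not MIGraphX's implementation:

```cpp
#include <cstdint>
#include <initializer_list>
#include <vector>

// Illustrative helper: each byte carries two 4-bit elements, low nibble
// first; signed nibbles are sign-extended into the int8 result.
std::vector<int8_t> unpack_nibbles(const std::vector<uint8_t>& packed, bool is_signed)
{
    std::vector<int8_t> result;
    result.reserve(packed.size() * 2);
    for(uint8_t byte : packed)
    {
        for(uint8_t nibble : {static_cast<uint8_t>(byte & 0x0fu), static_cast<uint8_t>(byte >> 4)})
        {
            if(is_signed and nibble > 7)
                result.push_back(static_cast<int8_t>(nibble - 16)); // e.g. 0xF -> -1
            else
                result.push_back(static_cast<int8_t>(nibble));
        }
    }
    return result;
}
```

Under this convention a packed byte 0x2F unpacks to {-1, 2} when signed and {15, 2} when unsigned.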
@@ -362,7 +373,7 @@
if(parser.map_input_dims.count(name) > 0)
{
std::vector<std::size_t> dims = parser.map_input_dims.at(name);
s = parser.parse_type(input.type(), dims);
s = parser.parse_type(input.type(), dims);
}
else if(parser.map_dyn_input_dims.count(name) > 0)
{
@@ -493,12 +504,27 @@
MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type()));
}

literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
static shape parse_tensor_shape(const onnx::TensorProto& t)
{
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
auto type = get_type(t.data_type());
shape tensor_shape(type, dims);
if(is_type_packed_int4(t))
{
auto dim_n = dims.back();
if(dim_n > 0 and (dim_n % 2 == 0))
dims.back() = dim_n / 2; // int4-packed dimension converted to int8-sized units

[Codecov] Added lines #L512 - L514 were not covered by tests
else
MIGRAPHX_THROW("Int4: currently supports only even-sized packed tensors");

[Codecov] Added line #L516 was not covered by tests
}
return shape{get_type(t.data_type()), dims};
}

literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
{
auto tensor_shape = parse_tensor_shape(t);
auto dims = tensor_shape.lens();

[GitHub Actions / tidy] src/onnx/onnx_parser.cpp:524: the variable 'dims' is copy-constructed from a const reference but is only used as const reference; consider making it a const reference [performance-unnecessary-copy-initialization]
auto type = tensor_shape.type();
auto external_data = t.external_data();

if(not external_data.empty())
{
const std::string& data_file = external_data.at(0).value();
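Note on the `parse_tensor_shape` hunk above: only the innermost dimension is packed, so only `dims.back()` is halved, and an odd trailing dimension is rejected. A small self-contained sketch of the arithmetic, with illustrative values:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    // Logical INT4 dims from the TensorProto: 3 x 8 = 24 nibbles.
    std::vector<std::size_t> dims = {3, 8};

    // parse_tensor_shape halves the innermost (packed) dimension,
    // yielding the byte-level int8 shape of the stored buffer.
    assert(dims.back() % 2 == 0); // odd trailing dims trigger MIGRAPHX_THROW
    dims.back() /= 2;             // {3, 8} -> {3, 4}: 12 bytes of raw data

    // unpack_int4 later restores the logical extent {3, 8},
    // one int8 element per nibble.
    assert(dims == std::vector<std::size_t>({3, 4}));
}
```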
@@ -526,6 +552,7 @@
std::string s(raw_buffer.begin(), raw_buffer.end());
return create_literal(type, dims, s.data());
}

if(t.has_raw_data())
{
const std::string& s = t.raw_data();
@@ -535,16 +562,25 @@
switch(t.data_type())
{
case onnx::TensorProto::BOOL: return create_literal(shape::bool_type, dims, t.int32_data());

// INT4 or UINT4 operate as 8-bit buffers:
case onnx::TensorProto::INT4: return create_literal(shape::int8_type, dims, t.int32_data());
case onnx::TensorProto::UINT4: return create_literal(shape::uint8_type, dims, t.int32_data());

[Codecov] Added lines #L567 - L568 were not covered by tests

case onnx::TensorProto::INT8: return create_literal(shape::int8_type, dims, t.int32_data());
case onnx::TensorProto::UINT8: return create_literal(shape::uint8_type, dims, t.int32_data());

case onnx::TensorProto::INT16: return create_literal(shape::int16_type, dims, t.int32_data());
case onnx::TensorProto::UINT16: return create_literal(shape::uint16_type, dims, t.int32_data());

case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::UINT32:
return create_literal(shape::uint32_type, dims, t.uint64_data());

case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data());
case onnx::TensorProto::UINT64:
return create_literal(shape::uint64_type, dims, t.uint64_data());

case onnx::TensorProto::FLOAT16: {
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
std::vector<half> data_half;
@@ -554,9 +590,12 @@
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return create_literal(shape::half_type, dims, data_half);
}

case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, t.double_data());

case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data());

case onnx::TensorProto::FLOAT8E4M3FNUZ: {
std::vector<int32_t> data_int32(t.int32_data().begin(), t.int32_data().end());
std::vector<migraphx::fp8::fp8e4m3fnuz> data_fp8;
@@ -566,6 +605,7 @@
[](float raw_val) { return migraphx::fp8::fp8e4m3fnuz{raw_val}; });
return create_literal(shape::fp8e4m3fnuz_type, dims, data_fp8);
}

case onnx::TensorProto::FLOAT8E5M2FNUZ:
case onnx::TensorProto::FLOAT8E5M2:
case onnx::TensorProto::FLOAT8E4M3FN:
@@ -576,6 +616,7 @@
}
MIGRAPHX_THROW("PARSE_TENSOR: Invalid tensor type");
}

shape onnx_parser::parse_type(const onnx::TypeProto& t) const
{
shape::type_t shape_type = get_type(t.tensor_type().elem_type());
@@ -646,6 +687,8 @@
"incorrect final outputs\n";
return shape::fp8e4m3fnuz_type;
}
case 21: return shape::uint8_type;
case 22: return shape::int8_type;

[Codecov] Added lines #L690 - L691 were not covered by tests
case 14:
case 15:
case 16:
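A closing note on the final hunk: in the ONNX `TensorProto::DataType` enum, value 21 is UINT4 and value 22 is INT4, so `get_type` maps them to `uint8_type` and `int8_type`; the 4-bit payload travels in byte-sized storage until `unpack_int4` runs. Putting the pieces together, the initializer path added by this PR behaves roughly as follows (paraphrasing the diff with an example shape, not quoting the parser verbatim):

```cpp
// Given an INT4 initializer f with logical dims {3, 8}:
auto pt  = parser.parse_tensor(f);    // int8_type literal, lens {3, 4} (packed bytes)
auto lit = mod->add_literal(pt);
if(is_type_packed_int4(f))            // true for both INT4 and UINT4
    lit = mod->add_instruction(migraphx::make_op("unpack_int4"), lit);
// lit now produces int8_type with lens {3, 8}, one element per nibble.
```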