Revert "parse onnx graph for int4 (#3408)"
This reverts commit 3aa03ff.
pfultz2 committed Sep 11, 2024
1 parent 3aa03ff commit 8b18dc4
Showing 1 changed file with 5 additions and 48 deletions.
diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp
--- a/src/onnx/onnx_parser.cpp
+++ b/src/onnx/onnx_parser.cpp
@@ -320,11 +320,6 @@ void print_added_instructions(module* mod,
     mod->debug_print(added_instructions);
 }
 
-static bool is_type_packed_int4(const onnx::TensorProto& t)
-{
-    return t.data_type() == onnx::TensorProto::INT4 or t.data_type() == onnx::TensorProto::UINT4;
-}
-
 std::unordered_map<std::string, instruction_ref>
 parse_intializer(const onnx_parser& parser, module* mod, const onnx::GraphProto& graph)
 {
@@ -334,13 +329,7 @@ parse_intializer(const onnx_parser& parser, module* mod, const onnx::GraphProto&
         if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
             std::cout << "initializer: " << f.name() << std::endl;
         // backup instructions in parent mod
-        auto pt  = parser.parse_tensor(f);
-        auto lit = mod->add_literal(pt);
-
-        if(is_type_packed_int4(f))
-            lit = mod->add_instruction(migraphx::make_op("unpack_int4"), lit);
-
-        mod_insts[f.name()] = lit;
+        mod_insts[f.name()] = mod->add_literal(parser.parse_tensor(f));
         if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
             mod->debug_print(mod_insts[f.name()]);
     }
@@ -373,7 +362,7 @@ parse_inputs(const onnx_parser& parser,
         if(parser.map_input_dims.count(name) > 0)
         {
             std::vector<std::size_t> dims = parser.map_input_dims.at(name);
-            s = parser.parse_type(input.type(), dims);
+            s                             = parser.parse_type(input.type(), dims);
         }
         else if(parser.map_dyn_input_dims.count(name) > 0)
         {
@@ -504,27 +493,12 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
     MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type()));
 }
 
-static shape parse_tensor_shape(const onnx::TensorProto& t)
-{
-    std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
-    if(is_type_packed_int4(t))
-    {
-        auto dim_n = dims.back();
-        if(dim_n > 0 and (dim_n % 2 == 0))
-            dims.back() = dim_n / 2; // int4-packed dimension converted to int8-sized units
-        else
-            MIGRAPHX_THROW("Int4: currently supports only even-sized packed tensors");
-    }
-    return shape{get_type(t.data_type()), dims};
-}
-
 literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
 {
-    auto tensor_shape = parse_tensor_shape(t);
-    auto dims = tensor_shape.lens();
-    auto type = tensor_shape.type();
+    std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
+    auto type = get_type(t.data_type());
+    shape tensor_shape(type, dims);
     auto external_data = t.external_data();
-
     if(not external_data.empty())
     {
         const std::string& data_file = external_data.at(0).value();
@@ -552,7 +526,6 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
         std::string s(raw_buffer.begin(), raw_buffer.end());
         return create_literal(type, dims, s.data());
     }
-
     if(t.has_raw_data())
     {
         const std::string& s = t.raw_data();
@@ -562,25 +535,16 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
     switch(t.data_type())
     {
     case onnx::TensorProto::BOOL: return create_literal(shape::bool_type, dims, t.int32_data());
-
-    // INT4 or UINT4 operate as 8-bit buffers:
-    case onnx::TensorProto::INT4: return create_literal(shape::int8_type, dims, t.int32_data());
-    case onnx::TensorProto::UINT4: return create_literal(shape::uint8_type, dims, t.int32_data());
-
     case onnx::TensorProto::INT8: return create_literal(shape::int8_type, dims, t.int32_data());
     case onnx::TensorProto::UINT8: return create_literal(shape::uint8_type, dims, t.int32_data());
-
     case onnx::TensorProto::INT16: return create_literal(shape::int16_type, dims, t.int32_data());
     case onnx::TensorProto::UINT16: return create_literal(shape::uint16_type, dims, t.int32_data());
-
     case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, t.int32_data());
     case onnx::TensorProto::UINT32:
         return create_literal(shape::uint32_type, dims, t.uint64_data());
-
     case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data());
     case onnx::TensorProto::UINT64:
         return create_literal(shape::uint64_type, dims, t.uint64_data());
-
     case onnx::TensorProto::FLOAT16: {
         std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
         std::vector<half> data_half;
@@ -590,12 +554,9 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
                        [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
         return create_literal(shape::half_type, dims, data_half);
     }
-
     case onnx::TensorProto::DOUBLE:
         return create_literal(shape::double_type, dims, t.double_data());
-
     case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data());
-
     case onnx::TensorProto::FLOAT8E4M3FNUZ: {
         std::vector<int32_t> data_int32(t.int32_data().begin(), t.int32_data().end());
         std::vector<migraphx::fp8::fp8e4m3fnuz> data_fp8;
@@ -605,7 +566,6 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
                        [](float raw_val) { return migraphx::fp8::fp8e4m3fnuz{raw_val}; });
         return create_literal(shape::fp8e4m3fnuz_type, dims, data_fp8);
     }
-
     case onnx::TensorProto::FLOAT8E5M2FNUZ:
     case onnx::TensorProto::FLOAT8E5M2:
     case onnx::TensorProto::FLOAT8E4M3FN:
@@ -616,7 +576,6 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
     }
     MIGRAPHX_THROW("PARSE_TENSOR: Invalid tensor type");
 }
-
 shape onnx_parser::parse_type(const onnx::TypeProto& t) const
 {
     shape::type_t shape_type = get_type(t.tensor_type().elem_type());
@@ -687,8 +646,6 @@ shape::type_t get_type(int dtype)
                      "incorrect final outputs\n";
         return shape::fp8e4m3fnuz_type;
     }
-    case 21: return shape::uint8_type;
-    case 22: return shape::int8_type;
     case 14:
     case 15:
     case 16:

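Context on the reverted change: ONNX stores INT4/UINT4 tensors two elements per byte, so commit 3aa03ff parsed such initializers as int8/uint8 buffers with the last dimension halved (even sizes only, per the removed parse_tensor_shape) and then inserted an unpack_int4 instruction to recover the logical shape. The following is a minimal standalone sketch of that packing convention, not MIGraphX source; the helper names are hypothetical, and it assumes ONNX's low-nibble-first ordering.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>

// Packed shape: the trailing dimension is stored at half length, mirroring the
// removed parse_tensor_shape(); odd trailing sizes were rejected with MIGRAPHX_THROW.
std::vector<std::size_t> packed_int4_dims(std::vector<std::size_t> dims)
{
    if(dims.empty() or dims.back() == 0 or dims.back() % 2 != 0)
        throw std::runtime_error("int4: only even-sized packed tensors supported");
    dims.back() /= 2;
    return dims;
}

// Unpack one byte into two signed 4-bit values, low nibble first
// (the ONNX packing convention; the unpack_int4 op's exact layout is not shown in this diff).
std::pair<int8_t, int8_t> unpack_int4_byte(uint8_t b)
{
    auto sign_extend = [](uint8_t nibble) {
        return static_cast<int8_t>(static_cast<int8_t>(nibble << 4) >> 4);
    };
    return {sign_extend(b & 0x0F), sign_extend(b >> 4)};
}

int main()
{
    // A logical 3x8 INT4 tensor is stored as a 3x4 byte buffer.
    auto dims = packed_int4_dims({3, 8});
    std::cout << dims[0] << "x" << dims[1] << "\n"; // prints 3x4

    // 0xF2 holds the pair (2, -1): low nibble 0x2, high nibble 0xF.
    auto [lo, hi] = unpack_int4_byte(0xF2);
    std::cout << int(lo) << " " << int(hi) << "\n"; // prints 2 -1
}

After this revert, parse_tensor again reads dims and type directly from the TensorProto, and INT4/UINT4 data types (21 and 22) are no longer mapped in get_type, so int4 models fall through to the unsupported-type path.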