From 8b18dc4566a983dfaf30c024a4f69d4e24df2f60 Mon Sep 17 00:00:00 2001
From: Paul Fultz II
Date: Wed, 11 Sep 2024 09:20:24 -0500
Subject: [PATCH] Revert "parse onnx graph for int4 (#3408)"

This reverts commit 3aa03ff5c22603b26297ff371440a99540cf0f93.
---
 src/onnx/onnx_parser.cpp | 53 ++++------------------------------------
 1 file changed, 5 insertions(+), 48 deletions(-)

diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp
index 102760fefc1..07d7f6a52d7 100644
--- a/src/onnx/onnx_parser.cpp
+++ b/src/onnx/onnx_parser.cpp
@@ -320,11 +320,6 @@ void print_added_instructions(module* mod,
         mod->debug_print(added_instructions);
 }
 
-static bool is_type_packed_int4(const onnx::TensorProto& t)
-{
-    return t.data_type() == onnx::TensorProto::INT4 or t.data_type() == onnx::TensorProto::UINT4;
-}
-
 std::unordered_map<std::string, instruction_ref>
 parse_intializer(const onnx_parser& parser, module* mod, const onnx::GraphProto& graph)
 {
@@ -334,13 +329,7 @@ parse_intializer(const onnx_parser& parser, module* mod, const onnx::GraphProto&
         if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
             std::cout << "initializer: " << f.name() << std::endl;
         // backup instructions in parent mod
-        auto pt  = parser.parse_tensor(f);
-        auto lit = mod->add_literal(pt);
-
-        if(is_type_packed_int4(f))
-            lit = mod->add_instruction(migraphx::make_op("unpack_int4"), lit);
-
-        mod_insts[f.name()] = lit;
+        mod_insts[f.name()] = mod->add_literal(parser.parse_tensor(f));
         if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
             mod->debug_print(mod_insts[f.name()]);
     }
@@ -373,7 +362,7 @@ parse_inputs(const onnx_parser& parser,
         if(parser.map_input_dims.count(name) > 0)
         {
             std::vector<std::size_t> dims = parser.map_input_dims.at(name);
-            s = parser.parse_type(input.type(), dims);
+            s                             = parser.parse_type(input.type(), dims);
         }
         else if(parser.map_dyn_input_dims.count(name) > 0)
         {
@@ -504,27 +493,12 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
     MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type()));
 }
 
-static shape parse_tensor_shape(const onnx::TensorProto& t)
-{
-    std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
-    if(is_type_packed_int4(t))
-    {
-        auto dim_n = dims.back();
-        if(dim_n > 0 and (dim_n % 2 == 0))
-            dims.back() = dim_n / 2; // int4-packed dimension converted to int8-sized units
-        else
-            MIGRAPHX_THROW("Int4: currently supports only even-sized packed tensors");
-    }
-    return shape{get_type(t.data_type()), dims};
-}
-
 literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
 {
-    auto tensor_shape = parse_tensor_shape(t);
-    auto dims         = tensor_shape.lens();
-    auto type         = tensor_shape.type();
+    std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
+    auto type = get_type(t.data_type());
+    shape tensor_shape(type, dims);
     auto external_data = t.external_data();
-
     if(not external_data.empty())
     {
         const std::string& data_file = external_data.at(0).value();
@@ -552,7 +526,6 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
         std::string s(raw_buffer.begin(), raw_buffer.end());
         return create_literal(type, dims, s.data());
     }
-
     if(t.has_raw_data())
     {
         const std::string& s = t.raw_data();
@@ -562,25 +535,16 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
     switch(t.data_type())
    {
     case onnx::TensorProto::BOOL: return create_literal(shape::bool_type, dims, t.int32_data());
-
-    // INT4 or UINT4 operate as 8-bit buffers:
-    case onnx::TensorProto::INT4: return create_literal(shape::int8_type, dims, t.int32_data());
-    case onnx::TensorProto::UINT4: return create_literal(shape::uint8_type, dims, t.int32_data());
-
     case onnx::TensorProto::INT8: return create_literal(shape::int8_type, dims, t.int32_data());
     case onnx::TensorProto::UINT8: return create_literal(shape::uint8_type, dims, t.int32_data());
-
     case onnx::TensorProto::INT16: return create_literal(shape::int16_type, dims, t.int32_data());
     case onnx::TensorProto::UINT16:
         return create_literal(shape::uint16_type, dims, t.int32_data());
-
     case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, t.int32_data());
     case onnx::TensorProto::UINT32:
         return create_literal(shape::uint32_type, dims, t.uint64_data());
-
     case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data());
     case onnx::TensorProto::UINT64:
         return create_literal(shape::uint64_type, dims, t.uint64_data());
-
     case onnx::TensorProto::FLOAT16: {
         std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
         std::vector<half> data_half;
@@ -590,12 +554,9 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
         std::transform(data_uint16.begin(),
                        data_uint16.end(),
                        std::back_inserter(data_half),
                        [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
         return create_literal(shape::half_type, dims, data_half);
     }
-
     case onnx::TensorProto::DOUBLE:
         return create_literal(shape::double_type, dims, t.double_data());
-
     case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data());
-
     case onnx::TensorProto::FLOAT8E4M3FNUZ: {
         std::vector<int32_t> data_int32(t.int32_data().begin(), t.int32_data().end());
         std::vector<migraphx::fp8::fp8e4m3fnuz> data_fp8;
         std::transform(data_int32.begin(),
                        data_int32.end(),
                        std::back_inserter(data_fp8),
                        [](float raw_val) { return migraphx::fp8::fp8e4m3fnuz{raw_val}; });
         return create_literal(shape::fp8e4m3fnuz_type, dims, data_fp8);
     }
-
     case onnx::TensorProto::FLOAT8E5M2FNUZ:
     case onnx::TensorProto::FLOAT8E5M2:
     case onnx::TensorProto::FLOAT8E4M3FN:
@@ -616,7 +576,6 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
     }
     MIGRAPHX_THROW("PARSE_TENSOR: Invalid tensor type");
 }
-
 shape onnx_parser::parse_type(const onnx::TypeProto& t) const
 {
     shape::type_t shape_type = get_type(t.tensor_type().elem_type());
@@ -687,8 +646,6 @@ shape::type_t get_type(int dtype)
                      "incorrect final outputs\n";
         return shape::fp8e4m3fnuz_type;
     }
-    case 21: return shape::uint8_type;
-    case 22: return shape::int8_type;
     case 14:
     case 15:
     case 16:
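
For reference, the arithmetic behind what this revert removes: ONNX stores two 4-bit elements per byte, so the reverted parse_tensor_shape halved the last dimension of a packed int4 initializer (throwing on odd sizes), and parse_intializer appended an unpack_int4 instruction to restore the logical element count. Below is a minimal standalone sketch of that convention, assuming ONNX's first-element-in-the-low-nibble packing; the helper names are illustrative only, not MIGraphX APIs.

    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    // Shape math mirroring the reverted parse_tensor_shape: two int4
    // elements share one byte, so the packed (int8-viewed) storage has
    // half the size along the last dimension.
    std::vector<std::size_t> packed_int4_dims(std::vector<std::size_t> dims)
    {
        if(dims.empty() or dims.back() == 0 or dims.back() % 2 != 0)
            throw std::runtime_error("int4: only even-sized packed tensors supported");
        dims.back() /= 2;
        return dims;
    }

    // What an unpack step has to do: split each byte into two signed
    // 4-bit values, low nibble first, doubling the element count back.
    std::vector<int8_t> unpack_int4(const std::vector<uint8_t>& packed)
    {
        std::vector<int8_t> out;
        out.reserve(packed.size() * 2);
        for(uint8_t b : packed)
        {
            out.push_back(static_cast<int8_t>(b << 4) >> 4); // low nibble, sign-extended
            out.push_back(static_cast<int8_t>(b) >> 4);      // high nibble, sign-extended
        }
        return out;
    }

Under this convention an int4 initializer with logical shape {2, 4} arrives as {2, 2} bytes of storage; the reverted parser emitted the literal with the halved shape and relied on unpack_int4 to recover {2, 4}, which is exactly the literal/unpack pairing this commit removes.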