diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index 07d7f6a52d7..102760fefc1 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -320,6 +320,11 @@ void print_added_instructions(module* mod, mod->debug_print(added_instructions); } +static bool is_type_packed_int4(const onnx::TensorProto& t) +{ + return t.data_type() == onnx::TensorProto::INT4 or t.data_type() == onnx::TensorProto::UINT4; +} + std::unordered_map parse_intializer(const onnx_parser& parser, module* mod, const onnx::GraphProto& graph) { @@ -329,7 +334,13 @@ parse_intializer(const onnx_parser& parser, module* mod, const onnx::GraphProto& if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{})) std::cout << "initializer: " << f.name() << std::endl; // backup instructions in parent mod - mod_insts[f.name()] = mod->add_literal(parser.parse_tensor(f)); + auto pt = parser.parse_tensor(f); + auto lit = mod->add_literal(pt); + + if(is_type_packed_int4(f)) + lit = mod->add_instruction(migraphx::make_op("unpack_int4"), lit); + + mod_insts[f.name()] = lit; if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{})) mod->debug_print(mod_insts[f.name()]); } @@ -362,7 +373,7 @@ parse_inputs(const onnx_parser& parser, if(parser.map_input_dims.count(name) > 0) { std::vector dims = parser.map_input_dims.at(name); - s = parser.parse_type(input.type(), dims); + s = parser.parse_type(input.type(), dims); } else if(parser.map_dyn_input_dims.count(name) > 0) { @@ -493,12 +504,27 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type())); } -literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const +static shape parse_tensor_shape(const onnx::TensorProto& t) { std::vector dims(t.dims().begin(), t.dims().end()); - auto type = get_type(t.data_type()); - shape tensor_shape(type, dims); + if(is_type_packed_int4(t)) + { + auto dim_n = dims.back(); + if(dim_n > 0 and (dim_n % 2 == 0)) + dims.back() = dim_n / 2; // int4-packed dimension converted to int8-sized units + else + MIGRAPHX_THROW("Int4: currently supports only even-sized packed tensors"); + } + return shape{get_type(t.data_type()), dims}; +} + +literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const +{ + auto tensor_shape = parse_tensor_shape(t); + auto dims = tensor_shape.lens(); + auto type = tensor_shape.type(); auto external_data = t.external_data(); + if(not external_data.empty()) { const std::string& data_file = external_data.at(0).value(); @@ -526,6 +552,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const std::string s(raw_buffer.begin(), raw_buffer.end()); return create_literal(type, dims, s.data()); } + if(t.has_raw_data()) { const std::string& s = t.raw_data(); @@ -535,16 +562,25 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const switch(t.data_type()) { case onnx::TensorProto::BOOL: return create_literal(shape::bool_type, dims, t.int32_data()); + + // INT4 or UINT4 operate as 8-bit buffers: + case onnx::TensorProto::INT4: return create_literal(shape::int8_type, dims, t.int32_data()); + case onnx::TensorProto::UINT4: return create_literal(shape::uint8_type, dims, t.int32_data()); + case onnx::TensorProto::INT8: return create_literal(shape::int8_type, dims, t.int32_data()); case onnx::TensorProto::UINT8: return create_literal(shape::uint8_type, dims, t.int32_data()); + case onnx::TensorProto::INT16: return create_literal(shape::int16_type, dims, t.int32_data()); case onnx::TensorProto::UINT16: return create_literal(shape::uint16_type, dims, t.int32_data()); + case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, t.int32_data()); case onnx::TensorProto::UINT32: return create_literal(shape::uint32_type, dims, t.uint64_data()); + case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data()); case onnx::TensorProto::UINT64: return create_literal(shape::uint64_type, dims, t.uint64_data()); + case onnx::TensorProto::FLOAT16: { std::vector data_uint16(t.int32_data().begin(), t.int32_data().end()); std::vector data_half; @@ -554,9 +590,12 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const [](uint16_t raw_val) { return *reinterpret_cast(&raw_val); }); return create_literal(shape::half_type, dims, data_half); } + case onnx::TensorProto::DOUBLE: return create_literal(shape::double_type, dims, t.double_data()); + case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data()); + case onnx::TensorProto::FLOAT8E4M3FNUZ: { std::vector data_int32(t.int32_data().begin(), t.int32_data().end()); std::vector data_fp8; @@ -566,6 +605,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const [](float raw_val) { return migraphx::fp8::fp8e4m3fnuz{raw_val}; }); return create_literal(shape::fp8e4m3fnuz_type, dims, data_fp8); } + case onnx::TensorProto::FLOAT8E5M2FNUZ: case onnx::TensorProto::FLOAT8E5M2: case onnx::TensorProto::FLOAT8E4M3FN: @@ -576,6 +616,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const } MIGRAPHX_THROW("PARSE_TENSOR: Invalid tensor type"); } + shape onnx_parser::parse_type(const onnx::TypeProto& t) const { shape::type_t shape_type = get_type(t.tensor_type().elem_type()); @@ -646,6 +687,8 @@ shape::type_t get_type(int dtype) "incorrect final outputs\n"; return shape::fp8e4m3fnuz_type; } + case 21: return shape::uint8_type; + case 22: return shape::int8_type; case 14: case 15: case 16: