diff --git a/src/api/api.cpp b/src/api/api.cpp index 4ecd0763225..519d29c9828 100644 --- a/src/api/api.cpp +++ b/src/api/api.cpp @@ -98,6 +98,10 @@ shape::type_t to_shape_type(migraphx_shape_datatype_t t) switch(t) { case migraphx_shape_tuple_type: return shape::tuple_type; + + // case migraphx_shape_uint4_type: return shape::uint4_type; + // case migraphx_shape_int4_type: return shape::int4_type; + #define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \ case migraphx_shape_##x: return shape::x; MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT) @@ -111,10 +115,17 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t) switch(t) { case shape::tuple_type: return migraphx_shape_tuple_type; + + // case shape::uint4_type: return migraphx_shape_uint4_type; + // case shape::int4_type: return migraphx_shape_int4_type; + #define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \ case shape::x: return migraphx_shape_##x; MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT) #undef MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT + case shape::uint4_type: + case shape::int4_type: + break; } MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type"); } diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 0c1e7b269d4..68df86d6eea 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -68,7 +68,9 @@ struct MIGRAPHX_EXPORT shape #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x, enum type_t { - MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) tuple_type + MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) tuple_type, + uint4_type, + int4_type }; #undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES @@ -381,6 +383,8 @@ struct MIGRAPHX_EXPORT shape { switch(t) { + case uint4_type: + case int4_type: case tuple_type: { tv(); return; diff --git a/src/shape.cpp b/src/shape.cpp index f9a42361465..2a9ee00a4f2 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -122,7 +122,7 @@ struct shape_impl { if(not m_dyn_dims.empty()) { - auto maxes = max_lens(); + auto maxes = max_lens(); std::size_t max_val = std::numeric_limits::max(); return std::accumulate( @@ -224,7 +224,9 @@ const std::vector& shape::types() { static const std::vector result = { #define MIGRAPHX_GENERATE_TYPE_VECTOR(x, t) x, - MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR) tuple_type}; + MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR) tuple_type, + int4_type, + uint4_type}; return result; } @@ -233,6 +235,8 @@ std::string shape::name(shape::type_t t) switch(t) { case tuple_type: return "tuple_type"; + case int4_type: return "int4_type"; + case uint4_type: return "uint4_type"; #define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(x, t) \ case x: return #x; MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE) @@ -246,6 +250,8 @@ std::string shape::cpp_type(shape::type_t t) switch(t) { case tuple_type: MIGRAPHX_THROW("No C++ type for tuple"); + case int4_type: MIGRAPHX_THROW("No C++ type for int4_type"); + case uint4_type: MIGRAPHX_THROW("No C++ type for uint4_type"); #define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \ case x: return #t; MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE) @@ -728,7 +734,11 @@ shape::type_t shape::parse_type(const std::string& s) #define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP(x, t) {#x, x}, {#t, x}, MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP){"tuple_type", tuple_type}, - {"tuple", tuple_type}}; + {"tuple", tuple_type}, + {"int4_type", int4_type}, + {"int4", int4_type}, + {"uint4_type", uint4_type}, + {"uint4", uint4_type}}; return m.at(s); } diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index a257fdba0e2..8257c340b2a 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -69,6 +69,8 @@ rocblas_datatype get_type(shape::type_t type) case shape::uint16_type: case shape::int16_type: case shape::int64_type: + case shape::int4_type: + case shape::uint4_type: case shape::uint64_type: MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!"); } diff --git a/test/gpu/jit.cpp b/test/gpu/jit.cpp index 249a5aea9fa..7b6f66b636f 100644 --- a/test/gpu/jit.cpp +++ b/test/gpu/jit.cpp @@ -350,7 +350,11 @@ TEST_CASE(compile_math) auto vec_sizes = {2, 4, 6}; for(auto&& t : migraphx::shape::types()) { - if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t)) + if(contains({migraphx::shape::bool_type, + migraphx::shape::tuple_type, + migraphx::shape::int4_type, + migraphx::shape::uint4_type}, + t)) continue; auto name = migraphx::shape::cpp_type(t); if(t == migraphx::shape::half_type) @@ -403,6 +407,8 @@ TEST_CASE(assert_type_min_max) for(auto&& t : migraphx::shape::types()) { if(contains({migraphx::shape::bool_type, + migraphx::shape::uint4_type, + migraphx::shape::int4_type, migraphx::shape::fp8e4m3fnuz_type, migraphx::shape::tuple_type}, t))