From fdbbefa11eaedc973fa5390356f6dee5dd2428ee Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Mon, 3 Apr 2023 09:36:16 -0700 Subject: [PATCH] Fix Python dtype conversion for int64 on Windows. (#12880) Fixes https://github.com/openxla/iree/issues/11080. The int64 and uint64 test cases here were failing on Windows as the element type mapping was routing via the code `l`, which is a "C long int" - not an explicitly 64 bit type. This changes the mapping to always use the explicit "type strings" (any string in `numpy.sctypeDict.keys()`, [shown in this gist](https://gist.github.com/ScottTodd/ec1f7906e9c644eb47f74280d6c26229)). Relates to https://github.com/openxla/iree/pull/12872 --- build_tools/cmake/ctest_all.sh | 2 - runtime/bindings/python/hal.cc | 37 +++++++++++-------- .../bindings/python/tests/vm_types_test.py | 19 ++++++---- 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/build_tools/cmake/ctest_all.sh b/build_tools/cmake/ctest_all.sh index b255e77e49b5..cba6b09aec31 100755 --- a/build_tools/cmake/ctest_all.sh +++ b/build_tools/cmake/ctest_all.sh @@ -90,8 +90,6 @@ if [[ "$OSTYPE" =~ ^msys ]]; then "iree/tests/e2e/tensor_ops/check_vmvx_ukernel_local-task_unpack.mlir" # TODO(#11070): Fix argument/result signature mismatch "iree/tests/e2e/tosa_ops/check_vmvx_local-sync_microkernels_fully_connected.mlir" - # TODO(#11080): Fix arrays not matching in test_variant_list_buffers - "iree/runtime/bindings/python/vm_types_test" ) elif [[ "$OSTYPE" =~ ^darwin ]]; then excluded_tests+=( diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc index 110f121be0a8..f970a709931c 100644 --- a/runtime/bindings/python/hal.cc +++ b/runtime/bindings/python/hal.cc @@ -411,54 +411,59 @@ HalDevice HalDriver::CreateDeviceByURI(std::string& device_uri, namespace { py::object MapElementTypeToDType(iree_hal_element_type_t element_type) { - // See: https://docs.python.org/3/c-api/arg.html#numbers - // TODO: Handle dtypes that do not map to a code (i.e. fp16). - const char* dtype_code; + // See: + // * https://numpy.org/doc/stable/reference/arrays.dtypes.html + // * https://docs.python.org/3/c-api/arg.html#numbers + // + // Single letter codes can be ambiguous across platforms, so prefer explicit + // bit depth values, ("Type strings: Any string in numpy.sctypeDict.keys()"). + // See https://github.com/pybind/pybind11/issues/1908 + const char* dtype_string; switch (element_type) { case IREE_HAL_ELEMENT_TYPE_BOOL_8: - dtype_code = "?"; + dtype_string = "?"; break; case IREE_HAL_ELEMENT_TYPE_INT_8: case IREE_HAL_ELEMENT_TYPE_SINT_8: - dtype_code = "b"; + dtype_string = "int8"; break; case IREE_HAL_ELEMENT_TYPE_UINT_8: - dtype_code = "B"; + dtype_string = "uint8"; break; case IREE_HAL_ELEMENT_TYPE_INT_16: case IREE_HAL_ELEMENT_TYPE_SINT_16: - dtype_code = "h"; + dtype_string = "int16"; break; case IREE_HAL_ELEMENT_TYPE_UINT_16: - dtype_code = "H"; + dtype_string = "uint16"; break; case IREE_HAL_ELEMENT_TYPE_INT_32: case IREE_HAL_ELEMENT_TYPE_SINT_32: - dtype_code = "i"; + dtype_string = "int32"; break; case IREE_HAL_ELEMENT_TYPE_UINT_32: - dtype_code = "I"; + dtype_string = "uint32"; break; case IREE_HAL_ELEMENT_TYPE_INT_64: case IREE_HAL_ELEMENT_TYPE_SINT_64: - dtype_code = "l"; + dtype_string = "int64"; break; case IREE_HAL_ELEMENT_TYPE_UINT_64: - dtype_code = "L"; + dtype_string = "uint64"; break; case IREE_HAL_ELEMENT_TYPE_FLOAT_16: - dtype_code = "e"; + dtype_string = "float16"; break; case IREE_HAL_ELEMENT_TYPE_FLOAT_32: - dtype_code = "f"; + dtype_string = "float32"; break; case IREE_HAL_ELEMENT_TYPE_FLOAT_64: - dtype_code = "d"; + dtype_string = "float64"; break; default: throw RaiseValueError("Unsupported VM Buffer -> numpy dtype mapping"); } - return py::dtype(dtype_code); + return py::dtype(dtype_string); } } // namespace diff --git a/runtime/bindings/python/tests/vm_types_test.py b/runtime/bindings/python/tests/vm_types_test.py index 782142e15f2e..671002649aaa 100644 --- a/runtime/bindings/python/tests/vm_types_test.py +++ b/runtime/bindings/python/tests/vm_types_test.py @@ -49,12 +49,18 @@ def test_variant_list_i64(self): def test_variant_list_buffers(self): device = rt.get_device("local-sync") ET = rt.HalElementType - for dt, et in ((np.int8, ET.SINT_8), (np.int16, ET.SINT_16), - (np.int32, ET.SINT_32), (np.int64, ET.SINT_64), - (np.uint8, ET.UINT_8), (np.uint16, ET.UINT_16), - (np.uint32, ET.UINT_32), (np.uint64, ET.UINT_64), - (np.float32, ET.FLOAT_32), (np.float64, ET.FLOAT_64)): - # TODO: Unimplemented: (np.float16, ET.FLOAT_16) + for dt, et in ( + (np.int8, ET.SINT_8), # + (np.int16, ET.SINT_16), # + (np.int32, ET.SINT_32), # + (np.int64, ET.SINT_64), # + (np.uint8, ET.UINT_8), # + (np.uint16, ET.UINT_16), # + (np.uint32, ET.UINT_32), # + (np.uint64, ET.UINT_64), # + (np.float16, ET.FLOAT_16), # + (np.float32, ET.FLOAT_32), # + (np.float64, ET.FLOAT_64)): lst = rt.VmVariantList(5) ary1 = np.asarray([1, 2, 3, 4], dtype=dt) bv1 = device.allocator.allocate_buffer_copy( @@ -65,7 +71,6 @@ def test_variant_list_buffers(self): lst.push_ref(bv1) ary2 = rt.DeviceArray(device, lst.get_as_object(0, rt.HalBufferView), - override_dtype=dt, implicit_host_transfer=True) np.testing.assert_array_equal(ary1, ary2) with self.assertRaises(IndexError):