From e4b0d6331d93e2872dec90fac201f2b96aa8d709 Mon Sep 17 00:00:00 2001 From: Toby Roseman Date: Mon, 18 Sep 2023 14:38:22 -0700 Subject: [PATCH] 7.0 Release (#1977) --- cmake/coreml-utils.cmake | 4 +- .../mil/_deployment_compatibility.py | 2 - .../converters/mil/backend/mil/helper.py | 19 +- .../converters/mil/backend/mil/load.py | 3 +- .../mil/backend/mil/passes/test_passes.py | 6 +- coremltools/converters/mil/converter.py | 2 +- .../mil/frontend/milproto/__init__.py | 2 +- .../converters/mil/frontend/milproto/load.py | 48 ++- .../mil/frontend/tensorflow/test/test_load.py | 4 +- .../mil/frontend/torch/converter.py | 11 +- .../mil/frontend/torch/internal_graph.py | 5 +- .../converters/mil/frontend/torch/ops.py | 95 +++++- .../mil/frontend/torch/quantization_ops.py | 97 +++++-- .../mil/frontend/torch/test/test_torch_ops.py | 94 ++++++ .../torch/test/test_torch_quantization_ops.py | 97 ++++++- coremltools/converters/mil/mil/block.py | 8 +- coremltools/converters/mil/mil/input_type.py | 11 +- coremltools/converters/mil/mil/operation.py | 8 +- .../converters/mil/mil/ops/defs/_utils.py | 87 ++++++ .../converters/mil/mil/ops/defs/iOS15/conv.py | 8 +- .../mil/mil/ops/defs/iOS15/linear.py | 7 +- .../mil/mil/ops/defs/iOS15/reduction.py | 4 +- .../mil/mil/ops/defs/iOS16/constexpr_ops.py | 43 +-- coremltools/converters/mil/mil/ops/helper.py | 13 +- .../mil/mil/ops/tests/iOS14/test_linear.py | 18 ++ .../mil/mil/ops/tests/iOS14/test_reduction.py | 5 + .../tests/iOS14/test_tensor_transformation.py | 6 +- .../mil/ops/tests/iOS16/test_constexpr_ops.py | 30 +- .../mil/mil/ops/tests/iOS16/test_conv.py | 77 +++++ .../mil/mil/ops/tests/test_utils.py | 54 +++- .../defs/cleanup/const_deduplication.py | 79 +++-- .../defs/cleanup/remove_redundant_ops.py | 70 ++++- .../mil/mil/passes/defs/optimize_conv.py | 4 + .../mil/mil/passes/defs/optimize_linear.py | 6 +- .../converters/mil/mil/passes/helper.py | 54 +--- .../mil/mil/passes/tests/test_passes.py | 274 ++++++++++-------- coremltools/converters/mil/mil/program.py | 38 ++- .../converters/mil/mil/tests/test_block.py | 6 +- .../converters/mil/mil/tests/test_programs.py | 73 ++++- .../converters/mil/mil/types/type_int.py | 7 +- .../converters/mil/mil/types/type_tensor.py | 7 +- coremltools/converters/mil/mil/var.py | 22 ++ .../coreml/_post_training_quantization.py | 29 +- .../optimize/coreml/_quantization_passes.py | 33 ++- .../test/ml_program/test_compression.py | 16 +- .../test/optimize/coreml/test_passes.py | 6 +- .../coreml/test_post_training_quantization.py | 4 +- coremltools/version.py | 2 +- docs-guides/make.bat | 70 ++--- milstoragepython/MilStoragePython.cpp | 4 +- .../docs/Format/ItemSimilarityRecommender.rst | 18 +- 51 files changed, 1237 insertions(+), 453 deletions(-) create mode 100644 coremltools/converters/mil/mil/ops/tests/iOS16/test_conv.py diff --git a/cmake/coreml-utils.cmake b/cmake/coreml-utils.cmake index a87a3bfeb..ab4679c85 100644 --- a/cmake/coreml-utils.cmake +++ b/cmake/coreml-utils.cmake @@ -37,7 +37,7 @@ function(coreml_add_build_proto proto_fn target_suffix) ${CMAKE_CURRENT_BINARY_DIR}/format/${proto_fn}_enum.h COMMENT "Generating c++ enums from ${proto_fn}.proto into ${CMAKE_CURRENT_BINARY_DIR}/format/" COMMAND ${CMAKE_BINARY_DIR}/deps/protobuf/cmake/protoc - --plugin=protoc-gen-enum=mlmodel${target_suffix}/enumgen + --plugin=protoc-gen-enum=mlmodel/enumgen --enum_out=${CMAKE_CURRENT_BINARY_DIR}/format/ -I${CMAKE_CURRENT_SOURCE_DIR}/format/ ${CMAKE_CURRENT_SOURCE_DIR}/format/${proto_fn}.proto @@ -77,7 +77,7 @@ 
function(coreml_add_build_proto proto_fn target_suffix) add_custom_target(tgt_${proto_fn}_enums ALL COMMENT "Generating c++ enums from ${proto_fn}.proto into ${CMAKE_CURRENT_SOURCE_DIR}/build/format/" COMMAND ${CMAKE_BINARY_DIR}/deps/protobuf/cmake/protoc - --plugin=protoc-gen-enum=mlmodel${target_suffix}/enumgen + --plugin=protoc-gen-enum=mlmodel/enumgen --enum_out=${CMAKE_CURRENT_SOURCE_DIR}/build/format/ -I${CMAKE_CURRENT_SOURCE_DIR}/format/ ${CMAKE_CURRENT_SOURCE_DIR}/format/${proto_fn}.proto diff --git a/coremltools/converters/mil/_deployment_compatibility.py b/coremltools/converters/mil/_deployment_compatibility.py index d5e5bc6e0..db0122111 100644 --- a/coremltools/converters/mil/_deployment_compatibility.py +++ b/coremltools/converters/mil/_deployment_compatibility.py @@ -23,8 +23,6 @@ class AvailableTarget(IntEnum): iOS17 = _SPECIFICATION_VERSION_IOS_17 # macOS versions (aliases of iOS versions) - macOS15 = _SPECIFICATION_VERSION_IOS_13 - macOS16 = _SPECIFICATION_VERSION_IOS_14 macOS10_15 = _SPECIFICATION_VERSION_IOS_13 macOS10_16 = _SPECIFICATION_VERSION_IOS_14 macOS11 = _SPECIFICATION_VERSION_IOS_14 diff --git a/coremltools/converters/mil/backend/mil/helper.py b/coremltools/converters/mil/backend/mil/helper.py index 880f4bda1..9a88b4fc9 100644 --- a/coremltools/converters/mil/backend/mil/helper.py +++ b/coremltools/converters/mil/backend/mil/helper.py @@ -19,6 +19,9 @@ from coremltools.converters.mil.mil.types.type_mapping import np_val_to_py_type from coremltools.models.utils import _WEIGHTS_DIR_NAME, _WEIGHTS_FILE_NAME +# For immediate values, those types are stored in bytes (MIL parser reads those types from bytes). +IMMEDIATE_VALUE_TYPES_IN_BYTES = (types.fp16, types.int8, types.uint8, types.uint32) + def create_valuetype_scalar(data_type): """ @@ -105,7 +108,7 @@ def _tensor_field_by_type(tensor_val, builtin_type): elif types.is_int(builtin_type): if builtin_type == types.int64 or builtin_type == types.uint64: return tensor_val.longInts.values - if builtin_type in (types.int8, types.uint8, types.uint32): + if builtin_type in IMMEDIATE_VALUE_TYPES_IN_BYTES: return tensor_val.bytes.values if builtin_type == types.int16 or builtin_type == types.uint16: # TODO (rdar://111797203): Serialize to byte after MIL changes to read from byte field. @@ -132,7 +135,7 @@ def _set_empty_tensor_field_by_type(tensor_val, builtin_type): elif types.is_int(builtin_type): if (builtin_type == types.int64 or builtin_type == types.uint64): tensor_val.longInts.SetInParent() - elif builtin_type in (types.int8, types.uint8, types.uint32): + elif builtin_type in IMMEDIATE_VALUE_TYPES_IN_BYTES: tensor_val.bytes.SetInParent() else: tensor_val.ints.SetInParent() @@ -167,7 +170,7 @@ def create_tensor_value(np_tensor): if builtin_type == types.str: for x in np.nditer(np_tensor): t_field.append(x.encode("utf-8")) - elif builtin_type in (types.fp16, types.int8, types.uint8, types.uint32): + elif builtin_type in IMMEDIATE_VALUE_TYPES_IN_BYTES: val.immediateValue.tensor.bytes.values = np_val_to_py_type(np_tensor) else: for x in np_tensor.flatten(): @@ -189,7 +192,7 @@ def create_scalar_value(py_scalar): # Set the tensor value t_field = _tensor_field_by_type(t_val, builtin_type) - if builtin_type in (types.fp16, types.int8, types.uint8, types.uint32): + if builtin_type in IMMEDIATE_VALUE_TYPES_IN_BYTES: # Serialize to bytes because MIL read them from the "bytes" field in TensorValue. 
val.immediateValue.tensor.bytes.values = np_val_to_py_type(py_scalar) else: @@ -295,7 +298,7 @@ def types_to_proto(valuetype): return create_valuetype_scalar(types_to_proto_primitive(valuetype)) -def create_file_value(output_var, blob_writer): +def _get_offset_by_writing_data(output_var, blob_writer): if output_var.val.dtype.kind == 'f' and output_var.val.dtype.itemsize == 4: offset = blob_writer.write_float_data(np.ascontiguousarray(output_var.val.flatten())) elif output_var.val.dtype.kind == "f" and output_var.val.dtype.itemsize == 2: @@ -316,6 +319,12 @@ def create_file_value(output_var, blob_writer): else: raise TypeError("Unsupported type, {}, for net buffer serialization.".format(output_var.val.dtype)) + return offset + + +def create_file_value(output_var, blob_writer): + offset = _get_offset_by_writing_data(output_var, blob_writer) + return create_file_value_tensor( file_name=os.path.join(os.path.join('@model_path', _WEIGHTS_DIR_NAME), _WEIGHTS_FILE_NAME), offset=offset, diff --git a/coremltools/converters/mil/backend/mil/load.py b/coremltools/converters/mil/backend/mil/load.py index 36691868b..8f2c9d2ed 100644 --- a/coremltools/converters/mil/backend/mil/load.py +++ b/coremltools/converters/mil/backend/mil/load.py @@ -42,7 +42,8 @@ try: from coremltools.libmilstoragepython import _BlobStorageWriter as BlobWriter -except: +except Exception as e: + logger.warning(f"Fail to import BlobWriter from libmilstoragepython. {e}") BlobWriter = None diff --git a/coremltools/converters/mil/backend/mil/passes/test_passes.py b/coremltools/converters/mil/backend/mil/passes/test_passes.py index 07e83648f..e1b8c5de7 100644 --- a/coremltools/converters/mil/backend/mil/passes/test_passes.py +++ b/coremltools/converters/mil/backend/mil/passes/test_passes.py @@ -1088,13 +1088,13 @@ def program(x): x = mb.pow(x=x, y=2.0) x = mb.sqrt(x=x) x = mb.reduce_argmax(x=x) - x = mb.reshape(x=x, shape=[*x_shape]) + x = mb.reshape(x=x, shape=[*x_shape[:-1]]) else: x = mb.mul(x=x, y=x) x = mb.sqrt(x=x) x = mb.pow(x=x, y=2.0) x = mb.reduce_argmax(x=x) - x = mb.reshape(x=x, shape=[*x_shape]) + x = mb.reshape(x=x, shape=[*x_shape[:-1]]) return x prev_prog, _, block = apply_pass_and_basic_check( @@ -1108,5 +1108,5 @@ def program(x): program=program, inputs={"x": x_shape}, backend=("mlprogram", "fp32"), - expected_output_shapes={block.outputs[0].name: tuple(x_shape)}, + expected_output_shapes={block.outputs[0].name: tuple(x_shape[:-1])}, ) diff --git a/coremltools/converters/mil/converter.py b/coremltools/converters/mil/converter.py index 9242d4354..3fae1cb94 100644 --- a/coremltools/converters/mil/converter.py +++ b/coremltools/converters/mil/converter.py @@ -288,7 +288,7 @@ def mil_convert_to_proto( PassPipelineManager.apply_pipeline(prog, main_pipeline) - prog._check_invalid_tensor_rank() + prog._check_invalid_program() if convert_to == 'milinternal': return None, prog diff --git a/coremltools/converters/mil/frontend/milproto/__init__.py b/coremltools/converters/mil/frontend/milproto/__init__.py index 34ab79f0b..3c3c0069c 100644 --- a/coremltools/converters/mil/frontend/milproto/__init__.py +++ b/coremltools/converters/mil/frontend/milproto/__init__.py @@ -3,4 +3,4 @@ # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause -from .load import load +from . 
import load diff --git a/coremltools/converters/mil/frontend/milproto/load.py b/coremltools/converters/mil/frontend/milproto/load.py index f7c3508b4..054ef871c 100644 --- a/coremltools/converters/mil/frontend/milproto/load.py +++ b/coremltools/converters/mil/frontend/milproto/load.py @@ -8,16 +8,22 @@ import numpy as np from coremltools import _logger as logger -from coremltools.converters.mil._deployment_compatibility import \ - AvailableTarget as _target +from coremltools.converters.mil._deployment_compatibility import AvailableTarget as _target +from coremltools.converters.mil.backend.mil import helper from coremltools.converters.mil.mil import Block from coremltools.converters.mil.mil import Builder as mb -from coremltools.converters.mil.mil import (Function, ListVar, Placeholder, - Program, TupleInputType, Var, - mil_list, types) +from coremltools.converters.mil.mil import ( + Function, + ListVar, + Placeholder, + Program, + TupleInputType, + Var, + mil_list, + types, +) from coremltools.converters.mil.mil.block import curr_block -from coremltools.converters.mil.mil.ops.registry import \ - SSAOpRegistry as _SSAOpRegistry +from coremltools.converters.mil.mil.ops.registry import SSAOpRegistry as _SSAOpRegistry from coremltools.proto import MIL_pb2 as pm from coremltools.proto import Model_pb2 as ml @@ -25,7 +31,8 @@ try: from coremltools.libmilstoragepython import _BlobStorageReader as BlobReader -except: +except Exception as e: + logger.warning(f"Fail to import BlobReader from libmilstoragepython. {e}") BlobReader = None @@ -145,7 +152,7 @@ def _load_value(context, value_spec): else: value = _load_file_value(context, value_spec.blobFileValue, dtype) - if dtype in (types.fp16, types.int8, types.uint8, types.uint32): + if dtype in helper.IMMEDIATE_VALUE_TYPES_IN_BYTES: value = np.frombuffer(value, types.nptype_from_builtin(dtype)).reshape( shape ) @@ -246,20 +253,23 @@ def _dummy_false_fn(*loop_vars): inputs["_false_fn"] = _dummy_false_fn +def _load_const_op(context, op_spec): + inputs = {k: _load_value(context, v) for k, v in op_spec.attributes.items()} + pymil_var = getattr(mb, op_spec.type)(**inputs) + context.register_var_with_name(op_spec.outputs[0].name, pymil_var) + + def _load_operation(context, op_spec): if not isinstance(op_spec, pm.Operation): raise TypeError("Invalid Operation spec object") op_type = op_spec.type - if op_type == "const" or op_type.startswith("constexpr_"): + if op_type == "const" or "constexpr_" in op_type: if op_spec.blocks: raise ValueError("const / constexpr operation can't have any block") if op_spec.inputs: raise ValueError("const / constexpr operation can't have any input") - - inputs = {k: _load_value(context, v) for k, v in op_spec.attributes.items()} - pymil_var = getattr(mb, op_type)(**inputs) - context.register_var_with_name(op_spec.outputs[0].name, pymil_var) + _load_const_op(context, op_spec) else: if op_type == "custom_layer": @@ -402,11 +412,19 @@ def _load_function(context, func_spec, spec_version): def load(model_spec, specification_version, file_weights_dir="", **kwargs): + """ + Load MILProto to Pymil. + + Set force_spec_version to force override the spec version. 
+ """ if not isinstance(model_spec, ml.Model): raise TypeError("Invalid Model sepc object") if specification_version < model_spec.specificationVersion: - raise ValueError("specification_version must be greater or equal to the input model spec version") + if not kwargs.get("force_spec_version", False): + raise ValueError( + "specification_version must be greater or equal to the input model spec version" + ) if model_spec.WhichOneof("Type") != "mlProgram": raise ValueError("Only MIL proto based mlmodels can be loaded") diff --git a/coremltools/converters/mil/frontend/tensorflow/test/test_load.py b/coremltools/converters/mil/frontend/tensorflow/test/test_load.py index cc4930e65..159066515 100644 --- a/coremltools/converters/mil/frontend/tensorflow/test/test_load.py +++ b/coremltools/converters/mil/frontend/tensorflow/test/test_load.py @@ -158,7 +158,7 @@ def build_model(x): @pytest.mark.parametrize( "target", - [ct.target.iOS13, ct.target.macOS15, ct.target.watchOS6, ct.target.tvOS13], + [ct.target.iOS13, ct.target.macOS10_15, ct.target.watchOS6, ct.target.tvOS13], ) def test_invalid_deployment_target_cumsum(self, target): x_shape = (3, 4, 5) @@ -179,7 +179,7 @@ def build_model(x): @pytest.mark.parametrize( "target", - [ct.target.iOS14, ct.target.macOS16, ct.target.watchOS7, ct.target.tvOS14], + [ct.target.iOS14, ct.target.macOS10_16, ct.target.watchOS7, ct.target.tvOS14], ) def test_valid_deployment_target_cumsum(self, target): x_shape = (3, 4, 5) diff --git a/coremltools/converters/mil/frontend/torch/converter.py b/coremltools/converters/mil/frontend/torch/converter.py index 98c279167..9e0be95c3 100644 --- a/coremltools/converters/mil/frontend/torch/converter.py +++ b/coremltools/converters/mil/frontend/torch/converter.py @@ -19,6 +19,7 @@ from .._utils import get_output_names from .internal_graph import InternalTorchIRGraph, InternalTorchIRNode from .ops import convert_nodes +from .quantization_ops import _dequantized_weight from .torch_op_registry import _TORCH_OPS_REGISTRY from .torchir_passes import ( flatten_graph_input_values, @@ -147,8 +148,6 @@ def get_dequantized_var(self, name: str, dequantized_name: str = None): # the MIL op. 
if dequantized_name is not None: self._context.add(original_var, dequantized_name) - if self._quant_dtype is None: - raise AssertionError("Trying to dequantize without quantization info") return original_var, self._quant_dtype quant_params = self.get_quantization_info(name) @@ -429,6 +428,10 @@ def convert_const(self): if isinstance(val, torch._C.ScriptObject): logger.info(f"Encountered constant {name} of type _torch._C.ScriptObject") continue + elif isinstance(val, torch.Tensor) and val.is_quantized: + const = _dequantized_weight(val.cpu(), name) + self.context.add(const) + continue elif not isinstance(val, np.ndarray): raise ValueError(f"unsupported class for {name} in PyTorch graph: {type(val)}") # TODO (rdar://107718371): support uint8 quantization @@ -623,10 +626,10 @@ def _lower_graph_block(graph): if is_tensor or is_quantized_tensor: if is_tensor and prefix in state_dict: assert torch.equal( - module, state_dict[prefix] + module.cpu(), state_dict[prefix].cpu() ), "tensor value not consistent between torch ir and state_dict" if prefix in params_dict: - assert torch.equal(module, params_dict[prefix]) + assert torch.equal(module.cpu(), params_dict[prefix].cpu()) replace_input[_output] = first_node_with_prefix[prefix] else: params_dict[prefix] = module diff --git a/coremltools/converters/mil/frontend/torch/internal_graph.py b/coremltools/converters/mil/frontend/torch/internal_graph.py index 76633f874..b6fd83507 100644 --- a/coremltools/converters/mil/frontend/torch/internal_graph.py +++ b/coremltools/converters/mil/frontend/torch/internal_graph.py @@ -269,7 +269,10 @@ def __init__( # Add params for name, param in params_dict.items(): if isinstance(param, torch.Tensor): - value = param.detach().cpu().numpy() + if param.is_quantized: + value = param + else: + value = param.detach().cpu().numpy() else: value = param self.params[name] = value diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index f56df369a..14c8bfc73 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -236,6 +236,35 @@ def _list_select(shape_var, index): res = mb.gather(x=shape_var, indices=index) return res +def _is_const(var, optional=False): + """ + Check if a var is a const. + It could be `const` or `constexpr_` ops. + """ + if optional and var is None: + return True + if isinstance(var, np.ndarray): + return True + return var is not None and (var.val is not None or var.op.op_type.startswith("constexpr_")) + +def _create_linear_layer(x, w, bias): + """ + Utility to translate linear layer. + Since the linear layer can only take `const` or `constexpr_` weight as input, + for other cases, we implement the linear layer through matmul. + + For instance, given a torch model with an int8 weight: + + int8_weight -> transpose -> reshape -> linear + + If we directly use `mb.linear`, it is going to produce compilation error at the runtime. + """ + if _is_const(w) and _is_const(bias, optional=True): + return mb.linear(x=x, weight=w, bias=bias) + res = mb.matmul(x=x, y=w, transpose_y=True) + if bias is not None: + res = mb.add(x=res, y=bias) + return res def _construct_constant(val, name): # Converter cannot handle torch tensors. 
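# A rough sketch (toy shapes, illustrative only, not part of the patch hunks) of the two
# paths taken by the `_create_linear_layer` helper added above: a const/constexpr weight
# lowers to a single `linear` op, while a weight produced at runtime falls back to
# `matmul(transpose_y=True)` (plus `add` when a bias is given), avoiding a non-const
# `linear` weight that the runtime cannot compile.
import numpy as np
from coremltools.converters.mil.mil import Builder as mb

@mb.program(input_specs=[mb.TensorSpec(shape=(2, 4)), mb.TensorSpec(shape=(3, 4))])
def _linear_layer_sketch(x, w_in):
    const_w = np.random.rand(3, 4).astype(np.float32)
    y_const = _create_linear_layer(x, const_w, None)   # const weight -> emits mb.linear
    dyn_w = mb.add(x=w_in, y=1.0)                       # weight only known at runtime
    y_dyn = _create_linear_layer(x, dyn_w, None)        # dynamic weight -> emits mb.matmul
    return y_const, y_dyn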
@@ -854,9 +883,10 @@ def linear(context, node): inputs = _get_inputs(context, node, expected=[2, 3]) x = inputs[0] W = inputs[1] + x, W = promote_input_dtypes([x, W]) bias = inputs[2] if len(node.inputs) == 3 else None - res = mb.linear(x=x, weight=W, bias=bias, name=node.name) - context.add(res) + res = _create_linear_layer(x, W, bias) + context.add(res, torch_name=node.name) @register_torch_op(torch_alias=["conv2d"]) @@ -1157,6 +1187,7 @@ def relu6(context, node): @register_torch_op def einsum(context, node): vars = context[node.inputs[1]] + vars = promote_input_dtypes(vars) equation = context[node.inputs[0]].val x = build_einsum_mil(vars, equation, node.name) context.add(x) @@ -1574,7 +1605,6 @@ def view(context, node): if ( isinstance(shape, list) and all([isinstance(dim, Var) and len(dim.shape) == 0 for dim in shape]) - and any([dim.val is None for dim in shape]) ): shape = mb.concat(values=shape, axis=0) @@ -1860,7 +1890,7 @@ def group_norm(context, node): x = mb.mul(x=x,y=weight) if bias is not None: bias = mb.reshape(x=bias, shape=bias_shape) - x = mb.add(x=x,y=bias) + x = mb.add(x=x, y=bias) context.add(x,node.name) @@ -3256,24 +3286,38 @@ def select(context, node): inputs = _get_inputs(context, node, expected=3) _input = inputs[0] dim = inputs[1].val - index = inputs[2].val + index = inputs[2] assert dim.shape == () - assert index.shape == () # NOTE: # Each index in @begin_array/@end_array corresponds to a dimension of @_input # Each val of those arrays corresponds to the start/end index to slice in that dimension rank = _input.rank + begin_array = [0] * rank - begin_array[dim] = index + if index.val is None: + # index value not known till runtime + begin_array[dim] = index + begin_array = mb.concat(values=begin_array, axis=0) + else: + # index value known now + assert index.val.shape == () + begin_array[dim] = index.val + end_array = [s if isinstance(s, int) else 0 for s in _input.shape] end_mask = [True] * rank squeeze_mask = [False] * rank squeeze_mask[dim] = True - if index != -1: - end_array[dim] = index + 1 + if index.val != -1: + if index.val is None: + # index value not know till runtime + temp = mb.add(x=index, y=1) + end_array[dim] = temp + end_array = mb.concat(values=end_array, axis=0) + else: + end_array[dim] = index.val + 1 end_mask[dim] = False slice_by_index = mb.slice_by_index( @@ -3326,7 +3370,9 @@ def _get_slice_params(context, data, inputs): for i in range(num_of_slice_set): if inputs[3 * i + 1] is None: # This is pure index select - idx = context[inputs[3 * i]].val + idx = context[inputs[3 * i]] + if idx.val is not None: + idx = idx.val begin[i] = idx squeeze_mask[i] = True else: @@ -3768,7 +3814,7 @@ def rand(context, node): @register_torch_op def randn(context, node): - inputs = _get_inputs(context, node, expected=5) + inputs = _get_inputs(context, node, expected=[5, 6]) shape = inputs[0] rand_normal = mb.random_normal(shape=shape) rand_fp32 = mb.cast(x=rand_normal, dtype="fp32", name=node.name) @@ -3811,6 +3857,17 @@ def bitwise_and(context, node): ) +@register_torch_op +def logical_not(context, node): + # There is an optional `out` parameter in torch.logical_not. 
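+    # Only the first input (the tensor being negated) is consumed below; the optional
+    # `out` tensor, when present, is simply ignored.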
+ inputs = _get_inputs(context, node, expected=[1, 2]) + x = inputs[0] + if not types.is_bool(x.dtype): + x = mb.cast(x=x, dtype="bool") + res = mb.logical_not(x=x, name=node.name) + context.add(res) + + def _avg_pool(context, node, inputs): x = inputs[0] kernel_sizes = inputs[1] @@ -5547,7 +5604,7 @@ def baddbmm(context, node): context.add(bias) baddbmm_node = mb.add(x=bias, y=bmm_node, name=node.name) - context.add(baddbmm_node) + context.add(baddbmm_node) else: bmm_node.name = node.name context.add(bmm_node) @@ -6204,3 +6261,17 @@ def scaled_dot_product_attention(context, node): # multiply attn_weights and value tensor res = mb.matmul(x=attn_weights_normalized, y=v, name=node.name) context.add(res) + + +@register_torch_op +def fliplr(context, node): + """ + Flip tensor in the left/right direction. + + Flip the entries in each row in the left/right direction. Columns are preserved, but appear in a + different order than before. + It's equivalent to TF's reverse op but with axes always be [1]. + """ + x = _get_inputs(context, node, expected=1)[0] + res = mb.reverse(x=x, axes=[1], name=node.name) + context.add(res) diff --git a/coremltools/converters/mil/frontend/torch/quantization_ops.py b/coremltools/converters/mil/frontend/torch/quantization_ops.py index a59e54b3f..dd320c75e 100644 --- a/coremltools/converters/mil/frontend/torch/quantization_ops.py +++ b/coremltools/converters/mil/frontend/torch/quantization_ops.py @@ -8,9 +8,9 @@ from coremltools import _logger as logger from coremltools.converters.mil.mil import Builder as mb -from coremltools.converters.mil.mil import Var +from coremltools.converters.mil.mil import Var, types -from .ops import NUM_TO_TORCH_DTYPE, _get_inputs +from .ops import NUM_TO_TORCH_DTYPE, _create_linear_layer, _get_inputs, promote_input_dtypes from .torch_op_registry import register_torch_op TORCH_QTYPE_TO_NP_TYPE = {_torch.qint8: _np.int8, _torch.quint8: _np.uint8} @@ -30,10 +30,15 @@ def _quantize_general( scale = scale_var.val if scale is None: raise ValueError("quantization scale must be const at compile time") + if len(scale.shape) > 0 and _np.prod(scale.shape) == 1: + scale = scale.reshape(-1)[0] + axis = None zero_point = zero_point_var.val if zero_point is None: raise ValueError("quantization zero point must be const at compile time") + if len(zero_point.shape) > 0 and _np.prod(zero_point.shape) == 1: + zero_point = zero_point.reshape(-1)[0] torch_dtype = NUM_TO_TORCH_DTYPE.get(torch_dtype_var.val) if torch_dtype is None: @@ -90,7 +95,7 @@ def dequantize(context, node): context.quant_context.get_dequantized_var(node.inputs[0], node.name) -def _dequantized_weight(qweight): +def _dequantized_weight(qweight, name:str = None): """ Given the first output (qweight) of torch.ops.quantized.conv2d/linear_unpack, this returns a dequantized version of the tensor to be added to the context. @@ -102,12 +107,15 @@ def _dequantized_weight(qweight): quantized_weights = _torch.int_repr(qweight).numpy() # Axis doesn't matter for per-tensor quantization. 
axis = _np.int32(0) - dequant_weights = mb.constexpr_affine_dequantize( - quantized_data=quantized_weights, - zero_point=zero_point, - scale=scale, - axis=axis, - ) + kwargs = { + "quantized_data": quantized_weights, + "zero_point": zero_point, + "scale": scale, + "axis": axis, + } + if name is not None: + kwargs["name"] = name + dequant_weights = mb.constexpr_affine_dequantize(**kwargs) # per_channel_affine_float_qparams is same as per_channel_affine except that it # expects both scale and zero point to be floating point values. elif qweight.qscheme() in {_torch.per_channel_affine, _torch.per_channel_affine_float_qparams}: @@ -133,12 +141,15 @@ def _dequantized_weight(qweight): quantized_weights = _torch.int_repr(qweight).numpy() # Axis doesn't matter for per-tensor quantization. axis = _np.int32(0) - dequant_weights = mb.constexpr_affine_dequantize( - quantized_data=quantized_weights, - zero_point=zero_point, - scale=scale, - axis=axis, - ) + kwargs = { + "quantized_data": quantized_weights, + "zero_point": zero_point, + "scale": scale, + "axis": axis, + } + if name is not None: + kwargs["name"] = name + dequant_weights = mb.constexpr_affine_dequantize(**kwargs) else: raise ValueError(f'Unsupported quant scheme "{qweight.qscheme()}"') return dequant_weights @@ -196,7 +207,7 @@ def _process_conv(context, node, add_relu=False): out_scale = context[node.inputs[2]] out_zero_point = context[node.inputs[3]].val - _ = context.quant_context.get_quantized_per_tensor( + context.quant_context.get_quantized_per_tensor( res.name, x_dtype, out_scale, out_zero_point, node.name ) @@ -209,24 +220,31 @@ def _process_linear(context, node, add_relu=False): # 4. output zero-point # Unpack PyTorch's packed params. - packed_params = context.torch_graph.params[node.inputs[1]] - qweight, bias = _torch.ops.quantized.linear_unpack(packed_params) - dequant_weights = _dequantized_weight(qweight) - context.add(dequant_weights) - # Bias can be fed as-is. 
- bias = bias.detach().numpy() + if node.inputs[1] not in context: + packed_params = context.torch_graph.params[node.inputs[1]] + qweight, bias = _torch.ops.quantized.linear_unpack(packed_params) + dequant_weights = _dequantized_weight(qweight) + context.add(dequant_weights) + bias = bias.detach().numpy() + else: + dequant_weights, bias = context[node.inputs[1]] x, x_dtype = context.quant_context.get_dequantized_var(node.inputs[0]) - res = mb.linear(x=x, weight=dequant_weights, bias=bias) + + x, dequant_weights = promote_input_dtypes([x, dequant_weights]) + res = _create_linear_layer(x, dequant_weights, bias) if add_relu: res = mb.relu(x=res) context.add(res) out_scale = context[node.inputs[2]] out_zero_point = context[node.inputs[3]].val - _ = context.quant_context.get_quantized_per_tensor( - res.name, x_dtype, out_scale, out_zero_point, node.name - ) + if out_scale.val != 0 or out_zero_point != 0: + context.quant_context.get_quantized_per_tensor( + res.name, x_dtype, out_scale, out_zero_point, node.name + ) + else: + context.add(res, node.name) def _process_binary(context, node, binary_op, add_relu=False): @@ -250,11 +268,36 @@ def _process_binary(context, node, binary_op, add_relu=False): out_scale = context[node.inputs[2]] out_zero_point = context[node.inputs[3]].val - _ = context.quant_context.get_quantized_per_tensor( + context.quant_context.get_quantized_per_tensor( res.name, lhs_dtype, out_scale, out_zero_point, node.name ) +@register_torch_op(torch_alias=["quantized::matmul"]) +def quantized_matmul(context, node): + inputs = _get_inputs(context, node, expected=4) + assert types.is_float(inputs[0].dtype) + assert types.is_float(inputs[1].dtype) + x, y = promote_input_dtypes([inputs[0], inputs[1]]) + assert ( + inputs[2].val == 0 and inputs[3].val == 0 + ), "non zero scale / zero-point not supported in quantized_matmul op." 
+ res = mb.matmul(x=x, y=y, name=node.name) + context.add(res) + + +# Defines all the quantization-related nodes that are noOps +@register_torch_op( + torch_alias=[ + "quantized::linear_prepack", + ] +) +def quant_noop(context, node): + logger.info("Setting pytorch op: {} to no-op.".format(node)) + inputs = _get_inputs(context, node) + context.add(inputs, torch_name=node.name) + + @register_torch_op(torch_alias=["quantized::linear"]) def quantized_linear(context, node): _process_linear(context, node) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index 1ce0087d4..bbf030d62 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -6870,6 +6870,34 @@ def forward(self, x): ) + @pytest.mark.parametrize( + "compute_unit, backend", + itertools.product(compute_units, backends) + ) + def test_dynamic_index(self, compute_unit, backend): + class M(torch.nn.Module): + def forward(self, float_arr, int_arr): + dynamic_index = int_arr[1] + float_arr[dynamic_index] = 12.95 + return float_arr + + a = torch.Tensor([1., 2., 4., 5]) + i = torch.Tensor([0, 1, 2]).long() + inputs_types=[ + ct.TensorType(name="a", shape=a.shape), + ct.TensorType(name="i", shape=i.shape, dtype=np.int32) + ] + + self.run_compare_torch( + [a, i], + M(), + input_as_shape=False, + converter_input_type=inputs_types, + backend=backend, + compute_unit=compute_unit + ) + + class TestNonZero(TorchBaseTest): @pytest.mark.parametrize( "compute_unit, backend, rank, as_tuple", @@ -9616,6 +9644,59 @@ def forward(self, x, y): ) +class TestLogicalNot(TorchBaseTest): + @pytest.mark.parametrize( + "compute_unit, backend, input_dtype", + itertools.product( + compute_units, + backends, + [torch.int32, torch.float32, torch.bool], + ), + ) + def test_logical_not(self, compute_unit, backend, input_dtype): + class TestModel(torch.nn.Module): + def forward(self, x): + return torch.logical_not(x) + + input_data = torch.randint( + low=0, high=2 if input_dtype == torch.bool else 4, size=(2, 3, 4), dtype=input_dtype + ) + self.run_compare_torch( + input_data, + TestModel(), + backend=backend, + compute_unit=compute_unit, + input_as_shape=False, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, input_dtype, output_dtype", + itertools.product( + compute_units, + backends, + [torch.int32, torch.float32, torch.bool], + [torch.int16, torch.float16, torch.bool], + ), + ) + def test_logical_not_with_out(self, compute_unit, backend, input_dtype, output_dtype): + class TestModel(torch.nn.Module): + def forward(self, x): + out_tensor = torch.empty((2, 3, 4), dtype=output_dtype) + torch.logical_not(x, out=out_tensor) + return out_tensor + + input_data = torch.randint( + low=0, high=2 if input_dtype == torch.bool else 4, size=(2, 3, 4), dtype=input_dtype + ) + self.run_compare_torch( + input_data, + TestModel(), + backend=backend, + compute_unit=compute_unit, + input_as_shape=False, + ) + + class TestUnfold(TorchBaseTest): @pytest.mark.parametrize( "compute_unit, backend, input_shape, kernel_size, padding, stride", @@ -9996,3 +10077,16 @@ def forward(self, x): model.eval() self.run_compare_torch((3, 32), model, backend=backend, compute_unit=compute_unit) + + +class TestFliplr(TorchBaseTest): + @pytest.mark.parametrize( + "compute_unit, backend, input_shape", + itertools.product(compute_units, backends, [(2, 3), (3, 4, 5), (8, 2, 6, 4)]), + ) + def test_fliplr(self, 
compute_unit, backend, input_shape): + class TestModel(nn.Module): + def forward(self, x): + return torch.fliplr(x) + + self.run_compare_torch(input_shape, TestModel(), backend=backend, compute_unit=compute_unit) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py index 91f2a5798..5a09606ff 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py @@ -18,6 +18,8 @@ MSG_TORCH_NOT_FOUND, MSG_TORCH_VISION_NOT_FOUND, ) +from coremltools.converters.mil.testing_utils import get_op_types_in_program + from .testing_utils import TorchBaseTest pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND) @@ -27,6 +29,24 @@ torch.backends.quantized.engine = "qnnpack" +def _force_quantize_model(model, q_dtype): + """ + In torch, the quantized model can only be obtained from PTQ. + This utility allows us to produce an int8 quantized model. + """ + # modify the parameter to int8 + with torch.no_grad(): + for name, param in model.named_parameters(): + shape = param.shape + new_value = torch.quantize_per_tensor( + torch.rand(*shape), scale=1.0, zero_point=0, dtype=q_dtype + ) + param_cls = type(param) + kwargs = param.__dict__ + new_value = param_cls(new_value, requires_grad=False).to(torch.device("cpu")) + model._parameters[name] = new_value + return model + class TorchQuantizationBaseTest(TorchBaseTest): @staticmethod def run_compare_torch( @@ -38,7 +58,7 @@ def run_compare_torch( ): # TODO(rdar://108472419): properly design a random input if input_as_shape: - input_data = torch.ones(*input_data) + input_data = [torch.ones(*shape) for shape in input_data] return TorchBaseTest.run_compare_torch( input_data, @@ -56,15 +76,23 @@ def run_compare_torch( # TODO(rdar://107430678): test stand-alone quantize and dequantize when cast is ready class TestPyTorchQuantizationOps(TorchQuantizationBaseTest): @pytest.mark.parametrize( - "quant_dtype, input_rank, is_zp_present, zp_dtype", + "quant_dtype, input_rank, is_zp_present, zp_dtype, are_params_tensors", itertools.product( (torch.qint8, torch.quint8, torch.qint32), (1, 3, 5), (True, False), (np.int8, np.uint8, np.int32), + (True, False), ), ) - def test_quantize_dequantize_per_tensor(self, quant_dtype, input_rank, is_zp_present, zp_dtype): + def test_quantize_dequantize_per_tensor( + self, + quant_dtype, + input_rank, + is_zp_present, + zp_dtype, + are_params_tensors, + ): input_shape = [*np.random.randint(low=1, high=5, size=(input_rank,))] scale = np.random.rand() zero_point = 0 @@ -72,6 +100,9 @@ def test_quantize_dequantize_per_tensor(self, quant_dtype, input_rank, is_zp_pre low = 0 if quant_dtype == torch.quint8 or zp_dtype == np.uint8 else -128 high = 128 if quant_dtype == torch.qint8 or zp_dtype == np.int8 else 256 zero_point = np.random.randint(low, high, dtype=zp_dtype) + if are_params_tensors: + scale = torch.tensor([scale]) + zero_point = torch.tensor([zero_point]) class Model(torch.nn.Module): def forward(self, x): @@ -85,9 +116,9 @@ def forward(self, x): ValueError, match=r"MIL quantization dtype must be int8 or uint8", ): - self.run_compare_torch(input_shape, model) + self.run_compare_torch([input_shape], model) else: - self.run_compare_torch(input_shape, model, atol=5e-4, rtol=5e-4) + self.run_compare_torch([input_shape], model, atol=5e-4, rtol=5e-4) @pytest.mark.parametrize( "quant_dtype, input_rank, is_zp_present, 
zp_dtype", @@ -122,9 +153,9 @@ def forward(self, x): ValueError, match=r"MIL quantization dtype must be int8 or uint8", ): - self.run_compare_torch(input_shape, model) + self.run_compare_torch([input_shape], model) else: - self.run_compare_torch(input_shape, model, atol=5e-4, rtol=5e-4) + self.run_compare_torch([input_shape], model, atol=5e-4, rtol=5e-4) # TODO(rdar://108463675): refactor torch op tests later to parametrize quantized vs standard ops @@ -157,7 +188,7 @@ def forward(self, x): input_shape = (1, 3, 5) elif input_rank == 4: input_shape = (1, 2, 3, 5) - self.run_compare_torch(input_shape, model) + self.run_compare_torch([input_shape], model) @pytest.mark.parametrize( ",".join( @@ -234,7 +265,7 @@ def forward(self, x): model = Model() self.run_compare_torch( - (1, in_channels, height, width), + [(1, in_channels, height, width)], model, ) @@ -266,7 +297,7 @@ def forward(self, x): input_data = torch.from_numpy(input_data) model = EmbeddingModel() self.run_compare_torch( - input_data, model, input_as_shape=False, converter_input_type=converter_input_type + [input_data], model, input_as_shape=False, converter_input_type=converter_input_type ) # Tests for add, add_relu, mul @@ -290,8 +321,50 @@ def forward(self, x): return torch.dequantize(x) model = Model() + self.run_compare_torch([(2, 3)], model) + + @pytest.mark.xfail( + reason="torch.ops.quantized.matmul is not suporting mixed precision computation.", + strict=True, + ) + @pytest.mark.parametrize( + "quant_dtype", + [torch.quint8, torch.qint8], + ) + def test_quantized_matmul(self, quant_dtype): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.parameter.Parameter(torch.rand(5, 4)) + + def forward(self, x): + return torch.ops.quantized.matmul(x, self.weight, 0, 0) + + model = Model() + model = _force_quantize_model(model, q_dtype=quant_dtype) + input_shape = [(3, 5)] + self.run_compare_torch(input_shape, model) - self.run_compare_torch((2, 3), model) + @pytest.mark.parametrize( + "quant_dtype", + [torch.quint8, torch.qint8], + ) + def test_quantized_params(self, quant_dtype): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.parameter.Parameter(torch.rand(5, 4)) + + def forward(self, x): + dequanitized_weight = torch.dequantize(self.weight) + return torch.matmul(x, dequanitized_weight) + + model = Model() + model = _force_quantize_model(model, q_dtype=quant_dtype) + input_shape = [(3, 5)] + res = self.run_compare_torch(input_shape, model) + prog = res[1]._mil_program + assert get_op_types_in_program(prog) == ["constexpr_affine_dequantize", "matmul"] @pytest.mark.skipif(not _HAS_TORCH_VISION, reason=MSG_TORCH_VISION_NOT_FOUND) @@ -306,4 +379,4 @@ class TestTorchvisionQuantizedModels(TorchQuantizationBaseTest): def test_quantized_mobilenetv2(self): model = torchvision.models.quantization.mobilenet_v2(pretrained=True, quantize=True) - self.run_compare_torch((1, 3, 224, 224), model, atol=1.0) + self.run_compare_torch([(1, 3, 224, 224)], model, atol=1.0) diff --git a/coremltools/converters/mil/mil/block.py b/coremltools/converters/mil/mil/block.py index aef4e6f0c..399e65628 100644 --- a/coremltools/converters/mil/mil/block.py +++ b/coremltools/converters/mil/mil/block.py @@ -356,13 +356,11 @@ def _insert_op_before(self, new_op, before_op=None): if not isinstance(v, (Var, tuple)): continue vs = [v] if isinstance(v, Var) else v - for s in vs: - if not self.is_var_visible_in_block(s, upto_op_with_id=idx): + for v in vs: + if not 
self.is_var_visible_in_block(v, upto_op_with_id=idx): before_op_name = before_op.name if before_op is not None else "None" msg = "Op '{}' input {}={} is not in scope of {} before {}" - raise ValueError( - msg.format(new_op.name, k, s.name, self.name, before_op_name) - ) + raise ValueError(msg.format(new_op.name, k, v.name, self.name, before_op_name)) # add new_op if before_op is None: diff --git a/coremltools/converters/mil/mil/input_type.py b/coremltools/converters/mil/mil/input_type.py index 267037f51..f3b57e491 100644 --- a/coremltools/converters/mil/mil/input_type.py +++ b/coremltools/converters/mil/mil/input_type.py @@ -148,11 +148,12 @@ def validate_inputs(self, op_name, op_type, candidate_kvs): input_type = self.input_types[name] # Check constness # Don't check InternalInputType (so _const_symbolic can work) - if input_type.const and \ - not isinstance(input_type, InternalInputType) \ - and var.val is None: - msg = msg_prefix + \ - 'Input {} must be const at compile time' + if ( + input_type.const + and not isinstance(input_type, InternalInputType) + and not var.is_descendant_of_const + ): + msg = msg_prefix + "Input {} must be const at compile time" raise ValueError(msg.format(name), name, var.name) if not isinstance(var, InternalVar) and \ diff --git a/coremltools/converters/mil/mil/operation.py b/coremltools/converters/mil/mil/operation.py index 19157b035..5fdb6add8 100644 --- a/coremltools/converters/mil/mil/operation.py +++ b/coremltools/converters/mil/mil/operation.py @@ -403,6 +403,10 @@ def value_inference(self): in `self.input_var`. Return a builtin value (single output) or a tuple of builtin values (multi-outputs) of the same length as returned by ` type_inference` + + Please note that, for ``constexpr_`` (compression) ops, we implement + ``materialized_val_inference`` instead, so that we don't compute the actual + values for those ops, which might potentially results in memory issue. """ msg = "value_inference() is not implemented by op {}" raise NotImplementedError(msg.format(self.op_type)) @@ -487,8 +491,8 @@ def check_and_detach(v_new, v_old, op, no_check_var_types): # existing's sym_type. if not is_compatible_type(v_new.sym_type, v_old.sym_type) and not no_check_var_types: raise ValueError( - f"New var type `{types.builtin_to_string(v_new.sym_type)}` not a " - f"subtype of existing var type `{types.builtin_to_string(v_old.sym_type)}`." + f"New var type `{v_new.sym_type}` not a " + f"subtype of existing var type `{v_old.sym_type}`." ) v_old.remove_child_op(op, no_check_var_types) diff --git a/coremltools/converters/mil/mil/ops/defs/_utils.py b/coremltools/converters/mil/mil/ops/defs/_utils.py index 57fe572f9..fe2c1e074 100644 --- a/coremltools/converters/mil/mil/ops/defs/_utils.py +++ b/coremltools/converters/mil/mil/ops/defs/_utils.py @@ -547,3 +547,90 @@ def solve_slice_by_index_shape(x_shape, begin, end, stride, begin_mask, end_mask ret_shape.append(max(0, num)) return ret_shape + + +def pack_elements_into_bits(elements: np.ndarray, nbits: int) -> np.ndarray: + """ + Pack elements into nbits representation, by starting with the least significant bit (LSB) and + moving upward to the most significant bit (MSB). + + Returns packed elements as np.uint8. + """ + if not np.issubdtype(elements.dtype, np.integer): + raise ValueError(f"Only support packing integers elements, but got {elements.dtype}") + + # Adjust allowed value range based on if the input is signed or unsigned. 
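+    # (e.g. for nbits=4, signed elements must lie in [-8, 7] and unsigned elements in [0, 15])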
+ if np.issubdtype(elements.dtype, np.signedinteger): + max_val = 2 ** (nbits - 1) - 1 + min_val = -max_val - 1 + else: + max_val = 2**nbits - 1 + min_val = 0 + if np.max(elements) > max_val: + raise ValueError( + f"To pack elements into {nbits}-bit, the max value is {max_val}, but got {np.max(elements)}" + ) + if np.min(elements) < min_val: + raise ValueError( + f"To pack elements into {nbits}-bit, the min value is {min_val}, but got {np.min(elements)}" + ) + + # As np.unpackbits only supports uint8, convert to uint8 first. + # Notice that it will not lose information, because the bits are unchanged when converting int8 + # to uint8. For example, the signed int -6 has bit representation '11111010', and when we unpackbits + # we get [0, 1, 0, 1, 1, 1, 1, 1], where only first 4 elements are needed for 4-bit representation. + elements = elements.astype(np.uint8) + bitarray = np.unpackbits(elements.reshape(-1, 1), bitorder="little", axis=-1)[:, :nbits] + return np.packbits(bitarray.flatten(), bitorder="little") + + +def restore_elements_from_packed_bits( + packed_values: np.ndarray, nbits: int, element_num: int, are_packed_values_signed: bool = False +) -> np.ndarray: + """ + Restore elements from packed bits. Requires values that are packed by starting with the + least significant bit (LSB) and moving upward to the most significant bit (MSB), which is the + method used in `pack_elements_into_bits`. + + are_packed_values_signed: Indicates if the packed_values were packed from signed integers. If + True, the n-bit number unpacked from packed_values will be interpreted as signed integers, + and the returned ndarray will have dtype np.int8. Otherwise, np.uint8 will be used. + """ + if len(packed_values.shape) != 1: + raise NotImplementedError( + f"Only support 1-rank packed_values. But got {len(packed_values.shape)}" + ) + + if packed_values.dtype == np.int8: + # As np.unpackbits only supports uint8, need to convert first. + packed_values = packed_values.astype(np.uint8) + elif packed_values.dtype != np.uint8: + raise NotImplementedError( + f"Only support int8 or uint8 packed_values, but got {packed_values.dtype}" + ) + + bitarray = np.unpackbits(packed_values, bitorder="little") + pad_required = bitarray.size % nbits != 0 + if pad_required: + bitarray = np.concatenate([bitarray, np.zeros(nbits - bitarray.size % nbits)]).astype( + bitarray.dtype + ) + if bitarray.size % nbits != 0: + raise ValueError( + f"The length of bitarray ({bitarray.size}) should be divisible by " + f"nbits ({nbits})." + ) + bitarray = bitarray.reshape(-1, nbits)[:element_num, :] + # The np.packbits doesn't work well for signed int if we feed `bitarray` to it directly. + # For example, the original signed int is -6, which is packed as 1010 for 4-bit representation, + # and here `bitarray` is [[0, 1, 0, 1]], where the value will be interpreted as 10 (b'1010') + # by np.packbits. + # To make np.packbits work correctly, we need to repeat the sign bit. For example, 1010 will + # become 11111010, where np.packbits can correctly handle and after converting to int8 it's -6. + if are_packed_values_signed: + # Repeat the sign bit to make uint8 to int8 works. 
+ bitarray = np.repeat(bitarray, [1] * (nbits - 1) + [8 - nbits + 1], axis=1) + restored_elements = np.packbits(bitarray, bitorder="little", axis=-1).reshape(-1) + if are_packed_values_signed: + restored_elements = restored_elements.astype(np.int8) + return restored_elements diff --git a/coremltools/converters/mil/mil/ops/defs/iOS15/conv.py b/coremltools/converters/mil/mil/ops/defs/iOS15/conv.py index ee0ffd800..2fc2e0840 100644 --- a/coremltools/converters/mil/mil/ops/defs/iOS15/conv.py +++ b/coremltools/converters/mil/mil/ops/defs/iOS15/conv.py @@ -157,10 +157,9 @@ def type_inference(self): C_in = self.x.shape[1] groups = self.groups.val - if self.bias is not None and \ - (len(self.bias.val.shape) > 1 or self.bias.val.shape[0] != C_out): + if self.bias is not None and (len(self.bias.shape) > 1 or self.bias.shape[0] != C_out): msg = "# of bias values {} not equal to # output channels {}" - raise ValueError(msg.format(self.bias.val.shape[0], C_out)) + raise ValueError(msg.format(self.bias.shape[0], C_out)) if C_in % groups != 0: msg = "# of input channels {} not divisible by groups {}" raise ValueError(msg.format(C_in, groups)) @@ -179,7 +178,8 @@ def type_inference(self): # Ignore self.pad if pad_type != custom custom_pad = None if self.pad_type.val != 'custom' else self.pad.val - if self.weight.val is None and any([True if d > 1 else False for d in dilations]): + is_weight_dynamic = not self.weight.is_descendant_of_const + if is_weight_dynamic and any([True if d > 1 else False for d in dilations]): raise ValueError("Convolution with dynamic weights does not support dilations!") N = inshape[0] diff --git a/coremltools/converters/mil/mil/ops/defs/iOS15/linear.py b/coremltools/converters/mil/mil/ops/defs/iOS15/linear.py index 87479ab84..ffa806739 100644 --- a/coremltools/converters/mil/mil/ops/defs/iOS15/linear.py +++ b/coremltools/converters/mil/mil/ops/defs/iOS15/linear.py @@ -17,6 +17,7 @@ from coremltools.converters.mil.mil.operation import VALUE from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op from coremltools.converters.mil.mil.ops.defs._utils import broadcast_shapes, parse_einsum_equation +from coremltools.converters.mil.mil.types import nptype_from_builtin from coremltools.converters.mil.mil.types.symbolic import is_symbolic @@ -57,7 +58,7 @@ class linear(Operation): def default_inputs(self): Dout = self.weight.shape[0] return DefaultInputs( - bias=[0.]*Dout, + bias=np.array([0.0] * Dout, dtype=nptype_from_builtin(self.x.dtype)), ) def type_inference(self): @@ -75,10 +76,10 @@ def type_inference(self): raise ValueError(msg.format(self.name, x_shape[-1], weight_shape[-1])) if self.bias is not None: assert len(self.bias.shape) == 1 - if len(self.bias.val) != weight_shape[-2]: + if self.bias.shape[0] != weight_shape[-2]: msg = "Op '{}' (linear op): Size of the bias, which is {}, " \ "does not match the first dimension of weights, which is {}" - raise ValueError(msg.format(self.name, len(self.bias.val), weight_shape[-2])) + raise ValueError(msg.format(self.name, self.bias.shape[0], weight_shape[-2])) shape = list(x_shape) shape[-1] = weight_shape[0] return types.tensor(x_type, tuple(shape)) diff --git a/coremltools/converters/mil/mil/ops/defs/iOS15/reduction.py b/coremltools/converters/mil/mil/ops/defs/iOS15/reduction.py index 7748904f5..321fd6c11 100644 --- a/coremltools/converters/mil/mil/ops/defs/iOS15/reduction.py +++ b/coremltools/converters/mil/mil/ops/defs/iOS15/reduction.py @@ -10,6 +10,7 @@ TensorInputType) from coremltools.converters.mil.mil.operation 
import VALUE from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op +from coremltools.converters.mil.mil.types import nptype_from_builtin class ReductionAxes(Operation): @@ -57,7 +58,8 @@ def type_inference(self): @precondition(allow=VALUE) def value_inference(self): axes = tuple(self.axes.val) if self.axes is not None else None - return self.get_operator()(self.x.val, axis=axes, keepdims=self.keep_dims.val) + res = self.get_operator()(self.x.val, axis=axes, keepdims=self.keep_dims.val) + return res.astype(nptype_from_builtin(self.x.dtype)) def get_operator(self): raise NotImplementedError() diff --git a/coremltools/converters/mil/mil/ops/defs/iOS16/constexpr_ops.py b/coremltools/converters/mil/mil/ops/defs/iOS16/constexpr_ops.py index 9b2953fc7..2dfb47775 100644 --- a/coremltools/converters/mil/mil/ops/defs/iOS16/constexpr_ops.py +++ b/coremltools/converters/mil/mil/ops/defs/iOS16/constexpr_ops.py @@ -6,10 +6,10 @@ import numpy as np from coremltools.converters.mil.mil import types -from coremltools.converters.mil.mil.input_type import (InputSpec, - TensorInputType) +from coremltools.converters.mil.mil.input_type import InputSpec, TensorInputType from coremltools.converters.mil.mil.operation import Operation from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op +from coremltools.converters.mil.mil.ops.defs._utils import restore_elements_from_packed_bits from coremltools.converters.mil.mil.ops.defs.iOS16 import _IOS16_TARGET @@ -34,7 +34,7 @@ class constexpr_affine_dequantize(Operation): quantized_data: const tensor (Required) zero_point: const tensor (Required) - * ``zero_point`` can be either a scalar or a vector. + * ``zero_point`` can be either a scalar or a vector. * ``zero_point`` follows similar broadcasting rules and size constraints as ``scale``. scale: const tensor (Required) @@ -109,12 +109,9 @@ def assert_vector_size_same_as_axial_dimension(param, axis_dim_size, name): shape = self.quantized_data.shape return types.tensor(dtype, shape) - def value_inference(self): + def materialized_val_inference(self): return self.decompress( - self.quantized_data.val, - self.zero_point.val, - self.scale.val, - self.axis.val + self.quantized_data.val, self.zero_point.val, self.scale.val, self.axis.val ) @staticmethod @@ -172,14 +169,14 @@ def type_inference(self): shape = self.source_val.shape return types.tensor(dtype, shape) - def value_inference(self): + def materialized_val_inference(self): return np.float32(self.source_val.val) @register_op(opset_version=_IOS16_TARGET) class constexpr_lut_to_dense(Operation): """ - A compile-time operation that returns a constant output value upon decompressing + A compile-time operation that returns a constant output value upon decompressing a look-up table (LUT) to a dense tensor. This operation is used to store constant weights in a LUT format (also known as @@ -195,9 +192,8 @@ class constexpr_lut_to_dense(Operation): shape: const tensor (Required) - Notes - ----- - + Notes + ----- * Any data is packed and read in a row-major order. * ``NUM_PALETTES`` can be one of ``{2, 4, 16, 64 or 256}``. * ``n_bits = log2(NUM_PALETTES)`` can thus be one of ``{1, 2, 4, 6, 8}``. 
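# A small worked sketch (toy values, illustrative only) of the LUT decompression described
# above, for NUM_PALETTES = 4 (n_bits = 2): one packed uint8 byte, 0b11100100, stores the
# indices [0, 1, 2, 3] LSB-first, and each index selects a palette entry.
import numpy as np

lut = np.array([1.0, 5.0, -3.0, 0.5], dtype=np.float32)         # 4 palette entries
packed_indices = np.array([0b11100100], dtype=np.uint8)          # indices 0, 1, 2, 3
bits = np.unpackbits(packed_indices, bitorder="little").reshape(-1, 2)
indices = np.packbits(bits, bitorder="little", axis=-1).reshape(-1)   # -> [0, 1, 2, 3]
print(lut[indices].reshape(2, 2))   # [[ 1.   5. ], [-3.   0.5]]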
@@ -232,7 +228,7 @@ class constexpr_lut_to_dense(Operation): lut=TensorInputType(const=True, type_domain="T"), shape=TensorInputType(const=True, type_domain=types.uint32), ) - + type_domains = { "T": (types.int8, types.uint8, types.fp16, types.fp32) } @@ -261,7 +257,7 @@ def assert_is_vector(param, name): shape = self.shape.val return types.tensor(dtype, shape) - def value_inference(self): + def materialized_val_inference(self): return self.decompress( self.lut.val, self.indices.val, @@ -270,19 +266,8 @@ def value_inference(self): @staticmethod def decompress(lut, indices, shape): - bitarray = np.unpackbits(indices, bitorder="little") nbits = np.log2(lut.size).astype(np.int32) - - pad_required = bitarray.size % nbits != 0 - if pad_required: - bitarray = np.concatenate([bitarray, np.zeros(nbits - bitarray.size % nbits)]).astype(bitarray.dtype) - - assert bitarray.size % nbits == 0 - - size = np.prod(shape) - bitarray = bitarray.reshape(-1, nbits)[:size, :] - - indices = np.packbits(bitarray, bitorder="little", axis=-1).reshape(-1) + indices = restore_elements_from_packed_bits(indices, nbits, np.prod(shape)) flatten_val = lut[indices] return flatten_val.reshape(shape) @@ -339,7 +324,7 @@ class constexpr_sparse_to_dense(Operation): mask=TensorInputType(const=True, type_domain=types.uint8), shape=TensorInputType(const=True, type_domain=types.uint32), ) - + type_domains = { "T": (types.int8, types.uint8, types.fp16, types.fp32) } @@ -371,7 +356,7 @@ def assert_is_vector(param, name): shape = self.shape.val return types.tensor(dtype, shape) - def value_inference(self): + def materialized_val_inference(self): return self.decompress(self.nonzero_data.val, self.mask.val, self.shape.val) @staticmethod diff --git a/coremltools/converters/mil/mil/ops/helper.py b/coremltools/converters/mil/mil/ops/helper.py index ea65697f4..a456725dd 100644 --- a/coremltools/converters/mil/mil/ops/helper.py +++ b/coremltools/converters/mil/mil/ops/helper.py @@ -3,8 +3,15 @@ # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +from typing import Dict, Type -def _get_version_of_op(op_variants, opset_version): +from coremltools.converters.mil._deployment_compatibility import AvailableTarget +from coremltools.converters.mil.mil.operation import Operation + + +def _get_version_of_op( + op_variants: Dict[AvailableTarget, Type[Operation]], opset_version: AvailableTarget +) -> Type[Operation]: """ A utility function that retrieves an op cls given a dictionary of op variants and target version """ @@ -13,6 +20,10 @@ def _get_version_of_op(op_variants, opset_version): opset_versions.sort() if opset_version is None: op_cls = op_variants[opset_versions[0]] + elif opset_version > opset_versions[-1] and opset_version > AvailableTarget.iOS17: + # TODO(rdar://111114658): Remove when no longer required. + # Inherit ops from the latest opset by default. 
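+        # e.g. an op whose newest registered variant is iOS16 resolves to that variant
+        # when converting for any opset version beyond iOS17.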
+ op_cls = op_variants[opset_versions[-1]] else: if opset_version not in op_variants: op_type = list(op_variants.values())[0].__name__ diff --git a/coremltools/converters/mil/mil/ops/tests/iOS14/test_linear.py b/coremltools/converters/mil/mil/ops/tests/iOS14/test_linear.py index 00ce6956c..c3fafa42c 100644 --- a/coremltools/converters/mil/mil/ops/tests/iOS14/test_linear.py +++ b/coremltools/converters/mil/mil/ops/tests/iOS14/test_linear.py @@ -13,6 +13,7 @@ from coremltools.converters.mil.mil import types from coremltools.converters.mil.mil.ops.tests.iOS14 import backends from coremltools.converters.mil.mil.ops.tests.testing_utils import run_compare_builder +from coremltools.converters.mil.mil.types import builtin_to_string, nptype_from_builtin from coremltools.converters.mil.testing_reqs import compute_units from coremltools.converters.mil.testing_utils import random_gen, ssa_fn @@ -103,6 +104,23 @@ def build(x): backend=backend, ) + @pytest.mark.parametrize( + "compute_unit, backend, input_type", + itertools.product(compute_units, backends, [types.int32, types.fp16, types.fp32]), + ) + def test_default_bias_type(self, compute_unit, backend, input_type): + # Test the default bias matches the dtype of x + @mb.program( + input_specs=[mb.TensorSpec(shape=(1, 2), dtype=types.fp32)], + opset_version=backend.opset_version, + ) + def prog(x): + x = mb.cast(x=x, dtype=builtin_to_string(input_type)) + weight = np.random.rand(3, 2).astype(nptype_from_builtin(input_type)) + res = mb.linear(x=x, weight=weight) + assert res.op.bias.val.dtype == nptype_from_builtin(input_type) + return res + class TestMatMul: @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) diff --git a/coremltools/converters/mil/mil/ops/tests/iOS14/test_reduction.py b/coremltools/converters/mil/mil/ops/tests/iOS14/test_reduction.py index 54bac467c..60ecf6b7c 100644 --- a/coremltools/converters/mil/mil/ops/tests/iOS14/test_reduction.py +++ b/coremltools/converters/mil/mil/ops/tests/iOS14/test_reduction.py @@ -254,10 +254,15 @@ def test_reduce_min(): @ssa_fn def test_reduce_prod(): + # test value res = mb.reduce_prod(x=x_val, axes=[axis], keep_dims=keep_dims).val ref = np.prod(x_val, axis=axis, keepdims=keep_dims) np.testing.assert_allclose(ref, res, atol=1e-04, rtol=1e-05) + # test dtype for int input + res = mb.reduce_prod(x=x_val.astype(np.int32), axes=[axis], keep_dims=keep_dims).val + assert res.dtype == np.int32 + @ssa_fn def test_reduce_sum(): res = mb.reduce_sum(x=x_val, axes=[axis], keep_dims=keep_dims).val diff --git a/coremltools/converters/mil/mil/ops/tests/iOS14/test_tensor_transformation.py b/coremltools/converters/mil/mil/ops/tests/iOS14/test_tensor_transformation.py index 93f21ac53..ab24a6eea 100644 --- a/coremltools/converters/mil/mil/ops/tests/iOS14/test_tensor_transformation.py +++ b/coremltools/converters/mil/mil/ops/tests/iOS14/test_tensor_transformation.py @@ -1022,8 +1022,7 @@ def test_builder_eval(self): np.testing.assert_allclose(ans[idx], v[idx].val, atol=1e-04, rtol=1e-05) @staticmethod - @pytest.mark.skipif(ct.utils._macos_version() < (14, 0), - reason="Bug fixed in macOS 14") + @pytest.mark.skipif(ct.utils._macos_version() < (14, 0), reason="Bug fixed in macOS 14") def test_slice_by_index(): INPUT_SHAPE = (1, 2, 8, 16) @@ -1061,8 +1060,7 @@ def prog(x): np.testing.assert_allclose(y_numpy, y_mlprogram) @staticmethod - @pytest.mark.skipif(ct.utils._macos_version() < (14, 0), - reason="Bug fixed in macOS 14") + @pytest.mark.skipif(ct.utils._macos_version() < (14, 0), 
reason="Bug fixed in macOS 14") def test_slice_by_index_slice_squeeze_separate(): INPUT_SHAPE = (1, 2, 8, 16) diff --git a/coremltools/converters/mil/mil/ops/tests/iOS16/test_constexpr_ops.py b/coremltools/converters/mil/mil/ops/tests/iOS16/test_constexpr_ops.py index 452925e49..efd157257 100644 --- a/coremltools/converters/mil/mil/ops/tests/iOS16/test_constexpr_ops.py +++ b/coremltools/converters/mil/mil/ops/tests/iOS16/test_constexpr_ops.py @@ -68,7 +68,10 @@ def test_builder_eval(self): scale=np.float32(2), axis=0, ) - np.testing.assert_allclose(np.float32([[0, 2, 4], [0, 2, 4]]), v.val) + assert v.val is None + np.testing.assert_allclose( + np.float32([[0, 2, 4], [0, 2, 4]]), v.op.materialized_val_inference() + ) # vector zero-point & scalar scale v = mb.constexpr_affine_dequantize( @@ -77,7 +80,9 @@ def test_builder_eval(self): scale=np.float32(2), axis=0, ) - np.testing.assert_allclose(np.float32([[0, 2, 4], [-2, 0, 2]]), v.val) + np.testing.assert_allclose( + np.float32([[0, 2, 4], [-2, 0, 2]]), v.op.materialized_val_inference() + ) # scalar zero-point & vector scale v = mb.constexpr_affine_dequantize( @@ -86,7 +91,9 @@ def test_builder_eval(self): scale=np.array([2, 4]).astype(np.float32), axis=0, ) - np.testing.assert_allclose(np.float32([[0, 2, 4], [0, 4, 8]]), v.val) + np.testing.assert_allclose( + np.float32([[0, 2, 4], [0, 4, 8]]), v.op.materialized_val_inference() + ) # vector zero-point & vector scale v = mb.constexpr_affine_dequantize( @@ -95,7 +102,9 @@ def test_builder_eval(self): scale=np.array([2, 4]).astype(np.float32), axis=0, ) - np.testing.assert_allclose(np.float32([[0, 2, 4], [-4, 0, 4]]), v.val) + np.testing.assert_allclose( + np.float32([[0, 2, 4], [-4, 0, 4]]), v.op.materialized_val_inference() + ) @staticmethod def affine_dequant_config_generator(): @@ -225,7 +234,8 @@ def build(x): @ssa_fn def test_builder_eval(self): v = mb.constexpr_cast(source_val=np.float16([1, 2]), output_dtype="fp32") - np.testing.assert_allclose(np.float32([1, 2]), v.val) + assert v.val is None + np.testing.assert_allclose(np.float32([1, 2]), v.op.materialized_val_inference()) @staticmethod def cast_config_generator(): @@ -351,7 +361,10 @@ def test_builder_eval(self): ] ).astype(np.uint32), ) - np.testing.assert_allclose(np.float32([3, 3, 1, 1, 1]).astype(np.float32), v.val) + assert v.val is None + np.testing.assert_allclose( + np.float32([3, 3, 1, 1, 1]).astype(np.float32), v.op.materialized_val_inference() + ) @staticmethod def lut_config_generator(): @@ -480,7 +493,10 @@ def test_builder_eval(self): ] ).astype(np.uint32), ) - np.testing.assert_allclose(np.float32([1.0, 2.0, 0.0, 4.0]), v.val) + assert v.val is None + np.testing.assert_allclose( + np.float32([1.0, 2.0, 0.0, 4.0]), v.op.materialized_val_inference() + ) @staticmethod def sparse_config_generator(): diff --git a/coremltools/converters/mil/mil/ops/tests/iOS16/test_conv.py b/coremltools/converters/mil/mil/ops/tests/iOS16/test_conv.py new file mode 100644 index 000000000..5782f39c9 --- /dev/null +++ b/coremltools/converters/mil/mil/ops/tests/iOS16/test_conv.py @@ -0,0 +1,77 @@ +# Copyright (c) 2023, Apple Inc. All rights reserved. 
+# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import itertools + +import numpy as np +import pytest + +from coremltools.converters.mil import testing_reqs +from coremltools.converters.mil.mil import Builder as mb +from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.ops.tests.iOS16 import backends +from coremltools.converters.mil.testing_utils import get_op_types_in_program + +compute_units = testing_reqs.compute_units + + +class TestConvolution: + @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) + def test_type_inference_with_constexpr_ops(self, compute_unit, backend): + # Test the type inference of the conv op doesn't error out for constexpr bias + @mb.program( + input_specs=[mb.TensorSpec(shape=(1, 3, 4, 4), dtype=types.fp32)], + opset_version=backend.opset_version, + ) + def prog(x): + weight = np.random.rand(2, 3, 2, 2) + bias = mb.constexpr_affine_dequantize( + quantized_data=np.array([1, 2]).astype(np.uint8), + zero_point=np.uint8(1), + scale=np.float32(2), + axis=0, + ) + return mb.conv(x=x, weight=weight, bias=bias) + + assert get_op_types_in_program(prog) == ["constexpr_affine_dequantize", "conv"] + + # Test conv op can have dilations with constexpr weight + @mb.program( + input_specs=[mb.TensorSpec(shape=(1, 3, 4, 4), dtype=types.fp32)], + opset_version=backend.opset_version, + ) + def prog(x): + weight = mb.constexpr_affine_dequantize( + quantized_data=np.array(range(24)).astype(np.uint8).reshape(2, 3, 2, 2), + zero_point=np.uint8(1), + scale=np.float32(2), + axis=0, + ) + return mb.conv(x=x, weight=weight, dilations=[2, 2]) + + assert get_op_types_in_program(prog) == ["constexpr_affine_dequantize", "conv"] + + # Test conv op can have dilations with constexpr weight with casts + @mb.program( + input_specs=[mb.TensorSpec(shape=(1, 3, 4, 4), dtype=types.fp16)], + opset_version=backend.opset_version, + ) + def prog(x): + weight = mb.constexpr_affine_dequantize( + quantized_data=np.array(range(24)).astype(np.uint8).reshape(2, 3, 2, 2), + zero_point=np.uint8(1), + scale=np.float16(2), + axis=0, + ) + cast_weight = mb.cast(x=weight, dtype="fp32") + cast_weight = mb.cast(x=weight, dtype="fp16") + return mb.conv(x=x, weight=cast_weight, dilations=[2, 2]) + + assert get_op_types_in_program(prog) == [ + "constexpr_affine_dequantize", + "cast", + "cast", + "conv", + ] diff --git a/coremltools/converters/mil/mil/ops/tests/test_utils.py b/coremltools/converters/mil/mil/ops/tests/test_utils.py index 82e22c743..69a36c124 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_utils.py +++ b/coremltools/converters/mil/mil/ops/tests/test_utils.py @@ -3,10 +3,18 @@ # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +import itertools + import numpy as np +import pytest from coremltools.converters.mil.mil.ops.defs._utils import ( - aggregated_pad, effective_kernel, spatial_dimensions_out_shape) + aggregated_pad, + effective_kernel, + pack_elements_into_bits, + restore_elements_from_packed_bits, + spatial_dimensions_out_shape, +) class TestDilation: @@ -260,3 +268,47 @@ def test_same_padding_shape_dilation_2(self): expected = [5, 5] np.testing.assert_equal(actual, expected) + + +class TestPackUnpackBits: + def test_pack_basic(self): + """ + Original data: [-8, 7, 3, 4, -2]. 
+ The 4-bit binary representation for those elements are: + -8: 1000; + 7: 0111; + 3: 0011 + 4: 0100 + -2: 1110 + Hence the packed quantized_data will be 3 bytes long, i.e., 24 bits long, which is: + 0111 1000 0100 0011 0000 1110 + So the packed data is represented by 3 uint8 values: [120, 67, 14]. + """ + original_data = np.array([-8, 7, 3, 4, -2], dtype=np.int8) + expected_packed_data = np.array([120, 67, 14], dtype=np.uint8) + packed_data = pack_elements_into_bits(original_data, nbits=4) + np.testing.assert_array_equal(packed_data, expected_packed_data) + + def test_pack_basic_2(self): + original_data = np.array([1, 2, 3, 4, 5], dtype=np.int8) + expected_packed_data = np.array([33, 67, 5], dtype=np.uint8) + packed_data = pack_elements_into_bits(original_data, nbits=4) + np.testing.assert_array_equal(packed_data, expected_packed_data) + + @pytest.mark.parametrize( + "nbits, data_dtype, element_num", + itertools.product(list(range(1, 9)), [np.int8, np.uint8], [1, 3, 20]), + ) + def test_round_trip_pack_unpack(self, nbits, data_dtype, element_num): + is_data_signed = np.issubdtype(data_dtype, np.signedinteger) + low, high = 0, 2**nbits + if is_data_signed: + low, high = -(2 ** (nbits - 1)), 2 ** (nbits - 1) + original_data = np.random.randint(low=low, high=high, size=(element_num,)).astype( + data_dtype + ) + packed_data = pack_elements_into_bits(original_data, nbits) + restored_data = restore_elements_from_packed_bits( + packed_data, nbits, element_num, are_packed_values_signed=is_data_signed + ) + np.testing.assert_array_equal(restored_data, original_data) diff --git a/coremltools/converters/mil/mil/passes/defs/cleanup/const_deduplication.py b/coremltools/converters/mil/mil/passes/defs/cleanup/const_deduplication.py index 2a8373d8f..5875fed55 100644 --- a/coremltools/converters/mil/mil/passes/defs/cleanup/const_deduplication.py +++ b/coremltools/converters/mil/mil/passes/defs/cleanup/const_deduplication.py @@ -37,26 +37,18 @@ class const_deduplication(AbstractGraphPass): q_embedding = linear(x=q, weight=weight_q, bias=bias_q) k_embedding = linear(x=k, weight=weight_q, bias=bias_q) - Concretely, we consider a constant as duplicated if there exists such a previous constant that: + Concretely, this graph pass consists of two stages: - 1. has same dtype and value + (1) Deduplication of ``const`` op: - 2. comes from same type of op + We consider a ``const`` as duplicated if there exists such a previous ``const`` that has same dtype and value - The reason why op type is considered is, there are 2 types of constants in Core ML: + (2) Deduplication of ``constexpr_*`` op: - 1. The usual constant, i.e., the output of ``const`` op - - 2. The result of const expression, i.e., the output of ``constexpr_*`` ops + We consider a ``constexpr_*`` as duplicated if there exists such a previous ``constexpr_*`` that has the same ``op_type`` and input attributes. 
""" NUMEL_THRESH = 100 - CONSTEXPR_OPS = { - "constexpr_affine_dequantize", - "constexpr_cast", - "constexpr_lut_to_dense", - "constexpr_sparse_to_dense", - } DTYPE2ATOL = { types.fp16: 6e-8, types.fp32: 1e-12, @@ -66,13 +58,9 @@ def apply(self, prog) -> None: for f in prog.functions.values(): self._constant_deduplication_block(f) - @block_context_manager - def _constant_deduplication_block(self, block: Block) -> None: - for op in list(block.operations): - for b in op.blocks: - self._constant_deduplication_block(b) - - unique2duplicates = self.find_constants(block) + def remove_duplicate_ops( + self, block: Block, unique2duplicates: Dict[Var, List[Var]], force_replace: bool + ) -> None: for unique in unique2duplicates: for duplicate in unique2duplicates[unique]: if duplicate in block.outputs: @@ -82,10 +70,54 @@ def _constant_deduplication_block(self, block: Block) -> None: anchor_op=op, old_var=duplicate, new_var=unique, - force_replace=True if op.op_type in self.CONSTEXPR_OPS else False, + force_replace=force_replace, ) block.remove_ops([op]) + @block_context_manager + def _constant_deduplication_block(self, block: Block) -> None: + for op in list(block.operations): + for b in op.blocks: + self._constant_deduplication_block(b) + + # Deduplication of ``const`` op + unique2duplicates_const = self.find_constants(block) + self.remove_duplicate_ops(block, unique2duplicates_const, force_replace=False) + + # Deduplication of ``constexpr_*`` op + # Note that, the ``find_constexpr`` must go after ``find_constants`` + ``remove_duplicate_ops`` for ``const`` ops. + # Since after the above two functions, ``const`` ops with identical values are + # deduplicated into a single ``Var`` object, which allows ``find_constexpr`` to + # directly compare the ``const`` input attr pointers instead of the actual values. + unique2duplicates_constexpr = self.find_constexprs(block) + self.remove_duplicate_ops(block, unique2duplicates_constexpr, force_replace=True) + + def find_constexprs(self, block: Block) -> Dict[Var, List[Var]]: + """ + Given a block, return all constexpr in the block in such a format: + {unique_var_0: [duplicated_var_0_0, duplicated_var_0_1, ...], + unique_var_1: [duplicated_var_1_0, duplicated_var_1_1, ...], + ... 
+ } + """ + hashkey_2_duplicates: Dict[Tuple, List[Var]] = {} + for op in list(block.operations): + if "constexpr" in op.op_type: + hash_key = [op.op_type] + for v in op.inputs.values(): + hash_key.append(v.dtype) + if np.prod(v.shape) < self.NUMEL_THRESH: + hash_key.append(str(v.val)) + else: + hash_key.append(v) + hash_key = tuple(hash_key) + if hash_key not in hashkey_2_duplicates: + hashkey_2_duplicates[hash_key] = [op.outputs[0]] + else: + hashkey_2_duplicates[hash_key].append(op.outputs[0]) + + return {v[0]: v[1:] for v in hashkey_2_duplicates.values()} + def find_constants(self, block: Block) -> Dict[Var, List[Var]]: """ Given a block, return all constants in the block in such a format: @@ -99,8 +131,7 @@ def find_constants(self, block: Block) -> Dict[Var, List[Var]]: # instead of brute-force C_N^2 comparison, use a hash map to be O(N) constant_dict: Dict[Tuple[str, types.type, Tuple[int], str], List[Var]] = {} for op in list(block.operations): - op_type = op.op_type - if op_type == "const" or op_type in self.CONSTEXPR_OPS: + if op.op_type == "const": constant_var = op.outputs[0] if isinstance(constant_var, ListVar): continue @@ -115,7 +146,7 @@ def find_constants(self, block: Block) -> Dict[Var, List[Var]]: hash = hashlib.sha1( np.ascontiguousarray(value.reshape(-1)[: self.NUMEL_THRESH]) ).hexdigest() - key = (op_type, dtype, shape, hash) + key = (dtype, shape, hash) if key not in constant_dict: constant_dict[key] = [constant_var] diff --git a/coremltools/converters/mil/mil/passes/defs/cleanup/remove_redundant_ops.py b/coremltools/converters/mil/mil/passes/defs/cleanup/remove_redundant_ops.py index 19f9cf339..a8e316fab 100644 --- a/coremltools/converters/mil/mil/passes/defs/cleanup/remove_redundant_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/cleanup/remove_redundant_ops.py @@ -5,8 +5,11 @@ import collections +import numpy as np + +from coremltools.converters.mil.mil import Var from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass -from coremltools.converters.mil.mil.passes.helper import _are_ops_identical, block_context_manager +from coremltools.converters.mil.mil.passes.helper import block_context_manager from coremltools.converters.mil.mil.passes.pass_registry import register_pass @@ -106,6 +109,69 @@ def _get_candidate_ops_lists_from_var(var): return candidate_ops_lists + @staticmethod + def _are_ops_identical(op1, op2): + """ + Return True, if all inputs of op1 and op2 are identical. + non-constant inputs must refer to the same object. + + For constant inputs, we only compare arrays with small size. + Large size const ops are already deduplicated in the const_deduplication pass so we + can compare the pointers. 
+ """ + + def _are_values_identical(val1, val2): + if not isinstance(val1, np.ndarray) or not isinstance(val2, np.ndarray): + return np.array_equal(np.array(val1), np.array(val2)) + if val1.size != val2.size: + return False + if val1.size < 100: + return np.array_equal(val1, val2) + return False + + def _are_vars_identical(var1, var2): + if var1 is var2: + return True + if var1.val is None and var2.val is None: + if var1 != var2: + return False + elif var1.val is not None and var2.val is not None: + if var1.dtype != var2.dtype: + return False + if not _are_values_identical(var1.val, var2.val): + return False + else: + return False + return True + + if op1 == op2: + return True + if op1.op_type != op2.op_type: + return False + if len(op1.inputs) != len(op2.inputs): + return False + + for key, value1 in op1.inputs.items(): + if key not in op2.inputs: + return False + value2 = op2.inputs[key] + if isinstance(value1, Var) and isinstance(value2, Var): + if not _are_vars_identical(value1, value2): + return False + elif isinstance(value1, (list, tuple)) and isinstance(value2, (list, tuple)): + if len(value1) != len(value2): + return False + else: + for i, v in enumerate(value1): + if not _are_vars_identical(v, value2[i]): + return False + else: + return False + + assert len(op1.blocks) == 0, "this method does not handle ops that have blocks in it" + assert len(op2.blocks) == 0, "this method does not handle ops that have blocks in it" + return True + @staticmethod def _try_to_remove_ops(candidate_ops_list): # candidate_ops_list contains ops in topological order. @@ -126,7 +192,7 @@ def _try_to_remove_ops(candidate_ops_list): ops_to_remove = [] for op in candidate_ops_list[1:]: if op.outputs[0] not in block.outputs: # to make sure we don't remove an output op - if _are_ops_identical(first_op, op): + if remove_redundant_ops._are_ops_identical(first_op, op): ops_to_remove.append(op) if len(ops_to_remove) == 0: diff --git a/coremltools/converters/mil/mil/passes/defs/optimize_conv.py b/coremltools/converters/mil/mil/passes/defs/optimize_conv.py index f31360f12..0dbbd0236 100644 --- a/coremltools/converters/mil/mil/passes/defs/optimize_conv.py +++ b/coremltools/converters/mil/mil/passes/defs/optimize_conv.py @@ -394,6 +394,10 @@ def _try_to_transform(conv_op, bn_op): conv_bias = np.zeros(Cout) else: conv_bias = conv_bias.val + + if conv_bias is None: + return False + conv_bias = conv_bias.astype(conv_weight_type) # get the original shape of weight and bias diff --git a/coremltools/converters/mil/mil/passes/defs/optimize_linear.py b/coremltools/converters/mil/mil/passes/defs/optimize_linear.py index b72f30f78..e59103a5d 100644 --- a/coremltools/converters/mil/mil/passes/defs/optimize_linear.py +++ b/coremltools/converters/mil/mil/passes/defs/optimize_linear.py @@ -58,6 +58,10 @@ def _try_to_transform(linear_op, add_or_sub_op, block): is_sub = add_or_sub_op.op_type == "sub" is_first_input = add_or_sub_op.x == linear_op.outputs[0] + # Return if weight or bias are missing values + if linear_op.weight.val is None or linear_op.bias.val is None: + return False + # compute the new bias linear_bias = linear_op.bias.val bias = add_or_sub_op.y.val if is_first_input else add_or_sub_op.x.val @@ -184,7 +188,7 @@ def _find_candidate_op(op): def _transpose(v, before_op, name=None): """ Transpose the last 2 dims. - + - ``v``: (Var, must be a tensor). - ``before_op``: (Operation) The op right before the newly added ``transpose`` op. - ``name``: Name for the ``transpose`` op if provided. 
diff --git a/coremltools/converters/mil/mil/passes/helper.py b/coremltools/converters/mil/mil/passes/helper.py index 2ce4d2f12..1bf1e70c3 100644 --- a/coremltools/converters/mil/mil/passes/helper.py +++ b/coremltools/converters/mil/mil/passes/helper.py @@ -7,7 +7,7 @@ import numpy as np -from coremltools.converters.mil.mil import Block, Operation, Var +from coremltools.converters.mil.mil import Block, Operation from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass class classproperty(property): @@ -140,55 +140,3 @@ def _check_var_scalar_value(x, val, tol=1e-3): if abs(x_val - val) < tol: return True return False - -def _are_ops_identical(op1, op2): - ''' - Return True, if all inputs of op1 and op2 are identical. - non-constant inputs must refer to the same object, and constant inputs must have the same value - ''' - - def _are_values_identical(val1, val2): - np_arr1 = np.array(val1) - np_arr2 = np.array(val2) - return np.array_equal(np_arr1, np_arr2) - - def _are_vars_identical(var1, var2): - if var1.val is None and var2.val is None: - if var1 != var2: - return False - elif var1.val is not None and var2.val is not None: - if var1.dtype != var2.dtype: - return False - if not _are_values_identical(var1.val, var2.val): - return False - else: - return False - return True - - if op1 == op2: - return True - if op1.op_type != op2.op_type: - return False - if len(op1.inputs) != len(op2.inputs): - return False - - for key, value1 in op1.inputs.items(): - if key not in op2.inputs: - return False - value2 = op2.inputs[key] - if isinstance(value1, Var) and isinstance(value2, Var): - if not _are_vars_identical(value1, value2): - return False - elif isinstance(value1, (list, tuple)) and isinstance(value2, (list, tuple)): - if len(value1) != len(value2): - return False - else: - for i, v in enumerate(value1): - if not _are_vars_identical(v, value2[i]): - return False - else: - return False - - assert len(op1.blocks) == 0, "this method does not handle ops that have blocks in it" - assert len(op2.blocks) == 0, "this method does not handle ops that have blocks in it" - return True diff --git a/coremltools/converters/mil/mil/passes/tests/test_passes.py b/coremltools/converters/mil/mil/passes/tests/test_passes.py index 3debe4e30..f1aa8598d 100644 --- a/coremltools/converters/mil/mil/passes/tests/test_passes.py +++ b/coremltools/converters/mil/mil/passes/tests/test_passes.py @@ -42,6 +42,77 @@ _VALIDATE_MODEL = True +def _get_constexpr_cast(shape, seed=None): + if seed is not None: + np.random.seed(seed) + val = np.random.rand(*shape).astype(np.float16) + return mb.constexpr_cast(source_val=val, output_dtype="fp32") + + +def _get_constexpr_sparse_to_dense(shape, seed=None): + if seed is not None: + np.random.seed(seed) + val = np.random.rand(*shape) + sparse_params = cto.coreml._quantization_passes.prune_weights.compress_by_magnitude( + val=val, target_sparsity=0.4 + ) + return mb.constexpr_sparse_to_dense( + nonzero_data=sparse_params.nonzero_data, + mask=sparse_params.mask, + shape=np.uint32(sparse_params.shape), + ) + + +def _get_constexpr_lut_to_dense(shape, seed=None): + if seed is not None: + np.random.seed(seed) + val = np.random.rand(*shape) + lut_params = cto.coreml._quantization_passes.palettize_weights.compress( + val=val, nbits=4, mode="UNIFORM" + ) + return mb.constexpr_lut_to_dense( + indices=lut_params.indices, + lut=lut_params.lut, + shape=np.uint32(lut_params.shape), + ) + + +def _get_constexpr_affine_dequantize(shape, seed=None): + if seed is not None: + 
np.random.seed(seed) + val = np.random.rand(*shape) + quant_params = cto.coreml._quantization_passes.linear_quantize_weights.compress( + val=val, axis=0, mode="LINEAR_SYMMETRIC", dtype=types.uint8 + ) + return mb.constexpr_affine_dequantize( + quantized_data=quant_params.quantized_data, + zero_point=quant_params.zero_point, + scale=quant_params.scale, + axis=quant_params.axis, + ) + + +def _get_constexpr_val(constexpr_var): + assert "constexpr" in constexpr_var.op.op_type + if constexpr_var.val is not None: + return constexpr_var.val + return constexpr_var.op.materialized_val_inference() + + +CONSTEXPR_FUNCS = { + "constexpr_cast": _get_constexpr_cast, + "constexpr_sparse_to_dense": _get_constexpr_sparse_to_dense, + "constexpr_lut_to_dense": _get_constexpr_lut_to_dense, + "constexpr_affine_dequantize": _get_constexpr_affine_dequantize, +} + +CONSTEXPR_OPS = [ + "constexpr_cast", + "constexpr_sparse_to_dense", + "constexpr_lut_to_dense", + "constexpr_affine_dequantize", +] + class TestConstDeduplication: def test_const_deduplication(self): BATCH_DIM = 5 @@ -67,16 +138,15 @@ def prog(q, k): assert_op_count_match(prev_prog, expect=6, op="const") assert_op_count_match(prog, expect=4, op="const") - def test_constexpr_deduplication(self): + @pytest.mark.parametrize( + "constexpr_op", + CONSTEXPR_OPS, + ) + def test_constexpr_deduplication(self, constexpr_op): BATCH_DIM = 5 SEQUENCE_LENGTH = 4 ENCODING_DIM = 256 EMBEDDING_DIM = 128 - quantized_weight = np.random.randint( - -128, 128, size=(EMBEDDING_DIM, ENCODING_DIM), dtype=np.int8 - ) - quantized_bias = np.random.randint(-128, 128, size=EMBEDDING_DIM, dtype=np.int8) - @mb.program( input_specs=[ mb.TensorSpec(shape=(BATCH_DIM, SEQUENCE_LENGTH, ENCODING_DIM)), @@ -84,38 +154,18 @@ def test_constexpr_deduplication(self): ] ) def prog(q, k): - weight_q = mb.constexpr_affine_dequantize( - quantized_data=quantized_weight, - zero_point=np.int8(0), - scale=np.float32(1.0), - axis=0, - ) - weight_k = mb.constexpr_affine_dequantize( - quantized_data=quantized_weight, - zero_point=np.int8(0), - scale=np.float32(1.0), - axis=0, - ) - bias_q = mb.constexpr_affine_dequantize( - quantized_data=quantized_bias, - zero_point=np.int8(0), - scale=np.float32(1.0), - axis=0, - ) - bias_k = mb.constexpr_affine_dequantize( - quantized_data=quantized_bias, - zero_point=np.int8(0), - scale=np.float32(1.0), - axis=0, - ) + weight_q = CONSTEXPR_FUNCS[constexpr_op]((EMBEDDING_DIM, ENCODING_DIM), seed=19) + weight_k = CONSTEXPR_FUNCS[constexpr_op]((EMBEDDING_DIM, ENCODING_DIM), seed=19) + bias_q = CONSTEXPR_FUNCS[constexpr_op]((EMBEDDING_DIM,), seed=29) + bias_k = CONSTEXPR_FUNCS[constexpr_op]((EMBEDDING_DIM,), seed=29) q_e = mb.linear(x=q, weight=weight_q, bias=bias_q) k_e = mb.linear(x=k, weight=weight_k, bias=bias_k) attention = mb.matmul(x=q_e, y=k_e, transpose_y=True) return attention prev_prog, _, _ = apply_pass_and_basic_check(prog, "common::const_deduplication") - assert_op_count_match(prev_prog, expect=4, op="constexpr_affine_dequantize") - assert_op_count_match(prog, expect=2, op="constexpr_affine_dequantize") + assert_op_count_match(prev_prog, expect=4, op=constexpr_op) + assert_op_count_match(prog, expect=2, op=constexpr_op) def test_const_deduplication_as_outputs(self): """ @@ -1384,20 +1434,24 @@ def prog(x): ) @staticmethod - def _make_repeated_conv_prog(redundant_conv=True): + def _make_repeated_conv_prog(redundant_conv=True, out_channel=2): prog = Program() func_inputs = {"x": mb.placeholder(shape=[1, 4, 5, 5])} with Function(func_inputs) as ssa_fun: x = 
ssa_fun.inputs["x"] x = mb.relu(x=x) - W = np.random.rand(8, 4, 3, 3) + W = np.random.rand(out_channel, 4, 3, 3) if redundant_conv: - bias = np.random.rand(8) + bias = np.random.rand(out_channel) x1 = mb.conv(x=x, weight=W, bias=bias, pad_type="same", strides=[1, 1]) x2 = mb.conv(x=x, weight=W, bias=bias, pad_type="same", strides=[1, 1]) else: - x1 = mb.conv(x=x, weight=W, bias=np.random.rand(8), pad_type="same", strides=[1, 1]) - x2 = mb.conv(x=x, weight=W, bias=np.random.rand(8), pad_type="same", strides=[1, 1]) + x1 = mb.conv( + x=x, weight=W, bias=np.random.rand(out_channel), pad_type="same", strides=[1, 1] + ) + x2 = mb.conv( + x=x, weight=W, bias=np.random.rand(out_channel), pad_type="same", strides=[1, 1] + ) x1 = mb.relu(x=x1) x2 = mb.relu(x=x2) x1 = mb.avg_pool(x=x1, kernel_sizes=[2, 2], strides=[1, 1], pad_type="same") @@ -1436,7 +1490,52 @@ def test_redundant_ops_inside_graph_valid_pattern(self): assert_model_is_valid( prog, {"x": (1, 4, 5, 5)}, - expected_output_shapes={block.outputs[0].name: (1, 16, 5, 5)}, + expected_output_shapes={block.outputs[0].name: (1, 4, 5, 5)}, + ) + + def test_redundant_ops_inside_graph_with_large_const(self): + """ + For the large constants, they need to be deduplicated by the const_deduplication first. + This test is making sure the converter is not doing any "brutal force" comparision. + + Input graph: + input--> relu--------->conv------>relu----> pool ---> concat ---> out + | ^ + | | + |---->conv---->relu---------------------------- + + Output graph: + input-> relu--->conv------>relu----> pool ---> concat ---> out + | ^ + | | + |------------------- + """ + # The remove_redundant_ops is not doing brutal force array comparison + prog = self._make_repeated_conv_prog(redundant_conv=True, out_channel=10) + prev_prog, _, block = apply_pass_and_basic_check(prog, "common::remove_redundant_ops") + ops_in_prev_prog = [ + "relu", + "conv", + "conv", + "relu", + "relu", + "avg_pool", + "concat", + ] + assert get_op_types_in_program(prev_prog) == ops_in_prev_prog + assert get_op_types_in_program(prog) == ops_in_prev_prog + + # We need to first run the const_deduplication pass. 
+ prog = self._make_repeated_conv_prog(redundant_conv=True, out_channel=10) + _, _, block = apply_pass_and_basic_check(prog, "common::const_deduplication") + _, _, block = apply_pass_and_basic_check(prog, "common::dead_code_elimination") + _, _, block = apply_pass_and_basic_check(prog, "common::remove_redundant_ops") + + assert get_op_types_in_program(prog) == ["relu", "conv", "relu", "avg_pool", "concat"] + assert_model_is_valid( + prog, + {"x": (1, 4, 5, 5)}, + expected_output_shapes={block.outputs[0].name: (1, 20, 5, 5)}, ) def test_redundant_ops_inside_graph_invalid_pattern(self): @@ -1470,7 +1569,7 @@ def test_redundant_ops_inside_graph_invalid_pattern(self): assert_model_is_valid( prog, {"x": (1, 4, 5, 5)}, - expected_output_shapes={block.outputs[0].name: (1, 16, 5, 5)}, + expected_output_shapes={block.outputs[0].name: (1, 4, 5, 5)}, ) def test_redundant_op_as_output_valid_pattern_1(self): @@ -2627,61 +2726,6 @@ def prog(x): class TestSkipConstexprOps: - @staticmethod - def _get_constexpr_cast(shape): - val = np.random.rand(*shape).astype(np.float16) - return mb.constexpr_cast(source_val=val, output_dtype="fp32") - - @staticmethod - def _get_constexpr_sparse_to_dense(shape): - val = np.random.rand(*shape) - sparse_params = cto.coreml._quantization_passes.prune_weights.compress_by_magnitude( - val=val, target_sparsity=0.4 - ) - return mb.constexpr_sparse_to_dense( - nonzero_data=sparse_params.nonzero_data, - mask=sparse_params.mask, - shape=np.uint32(sparse_params.shape), - ) - - @staticmethod - def _get_constexpr_lut_to_dense(shape): - val = np.random.rand(*shape) - lut_params = cto.coreml._quantization_passes.palettize_weights.compress(val=val, nbits=4, mode="UNIFORM") - return mb.constexpr_lut_to_dense( - indices=lut_params.indices, - lut=lut_params.lut, - shape=np.uint32(lut_params.shape), - ) - - @staticmethod - def _get_constexpr_affine_dequantize(shape): - val = np.random.rand(*shape) - quant_params = cto.coreml._quantization_passes.linear_quantize_weights.compress( - val=val, axis=0, mode="LINEAR_SYMMETRIC", dtype=types.uint8 - ) - return mb.constexpr_affine_dequantize( - quantized_data=quant_params.quantized_data, - zero_point=quant_params.zero_point, - scale=quant_params.scale, - axis=quant_params.axis, - ) - - # Static method cannot be stored as a function without attribute access. 
- CONSTEXPR_FUNCS = { - "constexpr_cast": _get_constexpr_cast.__func__, - "constexpr_sparse_to_dense": _get_constexpr_sparse_to_dense.__func__, - "constexpr_lut_to_dense": _get_constexpr_lut_to_dense.__func__, - "constexpr_affine_dequantize": _get_constexpr_affine_dequantize.__func__, - } - - CONSTEXPR_OPS = [ - "constexpr_cast", - "constexpr_sparse_to_dense", - "constexpr_lut_to_dense", - "constexpr_affine_dequantize", - ] - @staticmethod @pytest.mark.parametrize( "constexpr_op", @@ -2707,7 +2751,7 @@ def prog(x): a = np.random.rand( 2, ) - constexpr = TestSkipConstexprOps.CONSTEXPR_FUNCS[constexpr_op]((4, 2)) + constexpr = CONSTEXPR_FUNCS[constexpr_op]((4, 2)) linear = mb.linear(x=a, weight=constexpr) return mb.add(x=x, y=linear) @@ -2734,15 +2778,15 @@ def test_skip_fuse_matmul_weight_bias(constexpr_op, weight_constexpr, bias_const """ def get_matmul(x, weight_constexpr): - weight = TestSkipConstexprOps.CONSTEXPR_FUNCS[constexpr_op]((3, 2)) + weight = CONSTEXPR_FUNCS[constexpr_op]((3, 2)) if not weight_constexpr: - weight = weight.val + weight = _get_constexpr_val(weight) return mb.matmul(x=x, y=weight) def get_add(x, bias_constexpr): - bias = TestSkipConstexprOps.CONSTEXPR_FUNCS[constexpr_op]((2,)) + bias = CONSTEXPR_FUNCS[constexpr_op]((2,)) if not bias_constexpr: - bias = bias.val + bias = _get_constexpr_val(bias) return mb.add(x=x, y=bias) @mb.program(input_specs=[mb.TensorSpec(shape=(1, 3))]) @@ -2793,13 +2837,13 @@ def test_skip_fuse_conv(constexpr_op, op, weight_constexpr, const_constexpr): @mb.program(input_specs=[mb.TensorSpec(shape=input_shape)]) def prog(x): - conv_weight = TestSkipConstexprOps.CONSTEXPR_FUNCS[constexpr_op]((Cout, Cin, 2, 2)) + conv_weight = CONSTEXPR_FUNCS[constexpr_op]((Cout, Cin, 2, 2)) if not weight_constexpr: - conv_weight = conv_weight.val + conv_weight = _get_constexpr_val(conv_weight) x = mb.conv(x=x, weight=conv_weight) - const = TestSkipConstexprOps.CONSTEXPR_FUNCS[constexpr_op]((Cout, 1, 1)) + const = CONSTEXPR_FUNCS[constexpr_op]((Cout, 1, 1)) if not const_constexpr: - const = const.val + const = _get_constexpr_val(const) return getattr(mb, op)(x=x, y=const) apply_pass_and_basic_check(prog, "common::fuse_conv_scale") @@ -2842,13 +2886,13 @@ def test_skip_fuse_linear_bias(constexpr_op, weight_constexpr, bias_constexpr): @mb.program(input_specs=[mb.TensorSpec(shape=(2,))]) def prog(x): - weight = TestSkipConstexprOps.CONSTEXPR_FUNCS[constexpr_op]((4, 2)) + weight = CONSTEXPR_FUNCS[constexpr_op]((4, 2)) if not weight_constexpr: - weight = weight.val + weight = _get_constexpr_val(weight) linear = mb.linear(x=x, weight=weight) - bias = TestSkipConstexprOps.CONSTEXPR_FUNCS[constexpr_op]((4,)) + bias = CONSTEXPR_FUNCS[constexpr_op]((4,)) if not bias_constexpr: - bias = bias.val + bias = _get_constexpr_val(bias) return mb.add(x=linear, y=bias) apply_pass_and_basic_check(prog, "common::fuse_linear_bias") @@ -2894,12 +2938,12 @@ def test_skip_fuse_conv_batchnorm(constexpr_op, weight_constexpr, bias_constexpr @mb.program(input_specs=[mb.TensorSpec(shape=input_shape)]) def prog(x): # conv layer - weight = TestSkipConstexprOps.CONSTEXPR_FUNCS[constexpr_op]((Cout, Cin, 2, 2)) + weight = CONSTEXPR_FUNCS[constexpr_op]((Cout, Cin, 2, 2)) if not weight_constexpr: - weight = weight.val - bias = TestSkipConstexprOps.CONSTEXPR_FUNCS[constexpr_op]((Cout,)) + weight = _get_constexpr_val(weight) + bias = CONSTEXPR_FUNCS[constexpr_op]((Cout,)) if not bias_constexpr: - bias = bias.val + bias = _get_constexpr_val(bias) x = mb.conv( x=x, @@ -3217,7 +3261,7 @@ def prog(x): 
return x prev_prog, _, block = apply_pass_and_basic_check(prog, "common::expand_high_rank_reshape_and_transpose") - prog._check_invalid_tensor_rank() + prog._check_invalid_program() assert get_op_types_in_program(prog) == ["reshape", "transpose", "reshape"] TestExpandHighRankReshapeAndTranspose._test_numerical(prev_prog, input_shape, reshape_shape, perm, output_shape) @@ -3235,7 +3279,7 @@ def prog(x): return x prev_prog, _, block = apply_pass_and_basic_check(prog, "common::expand_high_rank_reshape_and_transpose") - prog._check_invalid_tensor_rank() + prog._check_invalid_program() assert get_op_types_in_program(prog) == ["reshape", "transpose", "reshape"] TestExpandHighRankReshapeAndTranspose._test_numerical(prev_prog, input_shape, reshape_shape, perm, output_shape) @@ -3254,7 +3298,7 @@ def prog(x): prev_prog, _, block = apply_pass_and_basic_check(prog, "common::expand_high_rank_reshape_and_transpose") - prog._check_invalid_tensor_rank() + prog._check_invalid_program() assert get_op_types_in_program(prog) == ["reshape", "transpose"] * 16 + ["reshape"] TestExpandHighRankReshapeAndTranspose._test_numerical(prev_prog, input_shape, reshape_shape, perm, output_shape) @@ -3274,7 +3318,7 @@ def prog(x): prev_prog, _, block = apply_pass_and_basic_check(prog, "common::expand_high_rank_reshape_and_transpose") with pytest.raises(ValueError, match="Core ML only supports tensors with rank <= 5"): - prog._check_invalid_tensor_rank() + prog._check_invalid_program() class TestMergeConsecutiveRelus: diff --git a/coremltools/converters/mil/mil/program.py b/coremltools/converters/mil/mil/program.py index 1468a6100..fe103d20b 100644 --- a/coremltools/converters/mil/mil/program.py +++ b/coremltools/converters/mil/mil/program.py @@ -7,11 +7,11 @@ import sympy as _sm from coremltools import _logger as logger -from coremltools.converters.mil._deployment_compatibility import \ - AvailableTarget as _target +from coremltools.converters.mil._deployment_compatibility import AvailableTarget as _target from coremltools.converters.mil.input_types import InputType -from coremltools.converters.mil.mil.var import ListVar +from coremltools.converters.mil.mil.input_type import InternalInputType from coremltools.converters.mil.mil.ops.helper import _get_version_of_op +from coremltools.converters.mil.mil.var import ListVar from . import types from .block import Function @@ -95,10 +95,13 @@ def _check_program_opset_version(self): self._check_ops_version_compatibility(max_opset_version) self._check_or_set_functions_opset_version(max_opset_version) - def _check_invalid_tensor_rank(self): - ''' - Early error out for tensor with rank >= 6 - ''' + def _check_invalid_program(self): + """ + Early error out for + 1. tensor with rank >= 6 + 2. non const tensor feed in const input + """ + def _check_invalid_tensor_rank_block(block): for op in block.operations: for b in op.blocks: @@ -109,9 +112,30 @@ def _check_invalid_tensor_rank_block(block): f'Core ML only supports tensors with rank <= 5. Layer "{op.name}", ' f'with type "{op.op_type}", outputs a rank {o.rank} tensor. ' ) + + def _check_invalid_const_tensor_input_block(block): + for op in block.operations: + for b in op.blocks: + _check_invalid_const_tensor_input_block(b) + + for k, v in op.inputs.items(): + input_type = op.input_spec.input_types[k] + + if ( + input_type.const + and not isinstance(input_type, InternalInputType) + and not (v.op.op_type.startswith("constexpr_") or v.val is not None) + ): + raise ValueError( + f"In op {op.name}. 
Input {k} ({v.name}) must be const or constexpr ops." + ) + for f in self.functions.values(): _check_invalid_tensor_rank_block(f) + for f in self.functions.values(): + _check_invalid_const_tensor_input_block(f) + def add_function(self, name, ssa_func): if not isinstance(ssa_func, Function): raise ValueError("Only Function can be added to Program.") diff --git a/coremltools/converters/mil/mil/tests/test_block.py b/coremltools/converters/mil/mil/tests/test_block.py index a4ccfe275..d0674920d 100644 --- a/coremltools/converters/mil/mil/tests/test_block.py +++ b/coremltools/converters/mil/mil/tests/test_block.py @@ -9,7 +9,7 @@ import pytest from coremltools.converters.mil.mil import Builder as mb -from coremltools.converters.mil.mil.passes.tests.test_passes import TestSkipConstexprOps +from coremltools.converters.mil.mil.passes.tests.test_passes import CONSTEXPR_FUNCS from coremltools.converters.mil.testing_utils import ( assert_same_output_names, assert_same_output_shapes, @@ -271,7 +271,7 @@ def test_replace_nonreplaceable_vars(): constexpr_op = "constexpr_sparse_to_dense" @mb.program(input_specs=[mb.TensorSpec(shape=(4, 2))]) def prog(x): - constexpr = TestSkipConstexprOps.CONSTEXPR_FUNCS[constexpr_op]((4, 2)) + constexpr = CONSTEXPR_FUNCS[constexpr_op]((4, 2)) return mb.add(x=x, y=constexpr) block = prog.functions["main"] @@ -297,7 +297,7 @@ def test_replace_nonreplaceable_vars_force(): constexpr_op = "constexpr_sparse_to_dense" @mb.program(input_specs=[mb.TensorSpec(shape=(4, 2))]) def prog(x): - constexpr = TestSkipConstexprOps.CONSTEXPR_FUNCS[constexpr_op]((4, 2)) + constexpr = CONSTEXPR_FUNCS[constexpr_op]((4, 2)) return mb.add(x=x, y=constexpr) block = prog.functions["main"] diff --git a/coremltools/converters/mil/mil/tests/test_programs.py b/coremltools/converters/mil/mil/tests/test_programs.py index fb0d2ea41..4fbbd69a6 100644 --- a/coremltools/converters/mil/mil/tests/test_programs.py +++ b/coremltools/converters/mil/mil/tests/test_programs.py @@ -10,6 +10,7 @@ from coremltools import _logger as logger from coremltools.converters.mil.mil import Builder as mb from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.passes.tests.test_passes import CONSTEXPR_FUNCS np.random.seed(0) @@ -324,15 +325,15 @@ def test_rank6_tensor_early_error_out(): ''' The builder should error out early when detecting a rank 6 (or higher) tensor which cannot be eliminated by graph passes ''' + @mb.program(input_specs=[mb.TensorSpec(shape=(1,), dtype=types.fp32)]) + def prog(x): + res = mb.reshape(x=x, shape=(1, 1, 1, 1, 1, 1), name="reshape_0") + return res + expected_err_str = ( "Core ML only supports tensors with rank <= 5. Layer \"reshape_0\", with type \"reshape\", outputs a rank 6 tensor" ) with pytest.raises(ValueError, match=expected_err_str): - @mb.program(input_specs=[mb.TensorSpec(shape=(1,), dtype=types.fp32)]) - def prog(x): - res = mb.reshape(x=x, shape=(1, 1, 1, 1, 1, 1), name="reshape_0") - return res - ct.convert( prog, source="milinternal", @@ -359,3 +360,65 @@ def prog(x): name="list_0", ) return ls + + @staticmethod + def test_invalid_const_input_early_error_out(): + """ + The following program: + + constexpr -> transpose -> linear + + will not error out during the front end conversion, even though the weight of + linear op needs to be const / constexpr directly. + + It is going to error out after all the optimization graph passes are finished, + and transpose remains. + + However, if transpose can be removed, the conversion goes through. 
+ """ + # Test a simple constexpr op fed into linear + @mb.program(input_specs=[mb.TensorSpec(shape=(2, 3))]) + def prog(x): + constexpr = CONSTEXPR_FUNCS["constexpr_affine_dequantize"]((4, 3)) + return mb.linear(x=x, weight=constexpr) + + for compute_precision in [ct.precision.FLOAT32, ct.precision.FLOAT16]: + mlmodel = ct.convert( + prog, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS16, + compute_units=ct.ComputeUnit.CPU_ONLY, + compute_precision=compute_precision, + ) + + # Additional pattern (transpose) after constexpr will cause an early error out + @mb.program(input_specs=[mb.TensorSpec(shape=(2, 3))]) + def prog(x): + constexpr = CONSTEXPR_FUNCS["constexpr_affine_dequantize"]((3, 4)) + constexpr = mb.transpose(x=constexpr, perm=[1, 0]) + return mb.linear(x=x, weight=constexpr) + + for compute_precision in [ct.precision.FLOAT32, ct.precision.FLOAT16]: + with pytest.raises(ValueError, match="must be const or constexpr ops"): + mlmodel = ct.convert( + prog, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS16, + compute_units=ct.ComputeUnit.CPU_ONLY, + compute_precision=compute_precision, + ) + + # If the transpose is removed by optimization passes, the conversion goes through + @mb.program(input_specs=[mb.TensorSpec(shape=(2, 3))]) + def prog(x): + constexpr = CONSTEXPR_FUNCS["constexpr_affine_dequantize"]((4, 3)) + constexpr = mb.transpose(x=constexpr, perm=[0, 1]) + return mb.linear(x=x, weight=constexpr) + + mlmodel = ct.convert( + prog, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS16, + compute_units=ct.ComputeUnit.CPU_ONLY, + compute_precision=compute_precision, + ) diff --git a/coremltools/converters/mil/mil/types/type_int.py b/coremltools/converters/mil/mil/types/type_int.py index 61b0149ac..2080d5b45 100644 --- a/coremltools/converters/mil/mil/types/type_int.py +++ b/coremltools/converters/mil/mil/types/type_int.py @@ -165,9 +165,8 @@ def __neg__(self): uint64 = make_int(64, "u") uint = uint64 +_INT_TYPES = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) + def is_int(t): - return any( - t is i or isinstance(t, i) - for i in [int8, int16, int32, int64, uint8, uint16, uint32, uint64] - ) + return any(t is i or isinstance(t, i) for i in _INT_TYPES) diff --git a/coremltools/converters/mil/mil/types/type_tensor.py b/coremltools/converters/mil/mil/types/type_tensor.py index 45ac510a9..71f400f18 100644 --- a/coremltools/converters/mil/mil/types/type_tensor.py +++ b/coremltools/converters/mil/mil/types/type_tensor.py @@ -94,18 +94,19 @@ def val(self, v): v_type = numpy_type_to_builtin_type(v.dtype) promoted_type = promote_types(v_type, primitive) - if v_type == primitive or v.dtype == np.dtype("O"): + primitive_np_type = nptype_from_builtin(primitive) + if v_type == primitive or v.dtype == np.dtype("O") or v.dtype == primitive_np_type: # np.array of symbolic has object type. Don't cast type. 
self._val = v elif promoted_type == primitive: - self._val = v.astype(nptype_from_builtin(primitive)) + self._val = v.astype(primitive_np_type) else: logger.warning( "Saving value type of {} into a builtin type of {}, might lose precision!".format( v.dtype, builtin_to_string(primitive) ) ) - self._val = v.astype(nptype_from_builtin(primitive)) + self._val = v.astype(primitive_np_type) tensor.__template_name__ = ( "tensor[" + primitive.__name__ + "," + ",".join(str(s) for s in shape) + "]" diff --git a/coremltools/converters/mil/mil/var.py b/coremltools/converters/mil/mil/var.py index ac4be0d12..dea8146dd 100644 --- a/coremltools/converters/mil/mil/var.py +++ b/coremltools/converters/mil/mil/var.py @@ -86,6 +86,7 @@ class Var: "_child_ops", "consuming_blocks", "_nonreplaceable_vars_upstream", + "is_descendant_of_const", ] def __init__( @@ -120,6 +121,17 @@ def __init__( self._nonreplaceable_vars_upstream = set() self._set_nonreplaceable_vars_upstream() + self._adjust_sym_val() + + # Track vars constness, which requires a var to satisfy one of the following: + # 1. var.val is not None, whichs mean the converter already has its compile time value through value inference. + # 2. Is a descendant of ``constexpr_`` ops. We don't compute the value inference of those ``constexpr_`` ops, + # due to the fact it can potentially results in memory issue. + self.is_descendant_of_const = Var._propagate_constness_upstream(self) + + def _adjust_sym_val(self): + pass + @property def nonreplaceable_vars_upstream(self): return self._nonreplaceable_vars_upstream @@ -136,6 +148,16 @@ def _is_nonreplaceable_var(var): return False return op.op_type.startswith("constexpr_") + @staticmethod + def _propagate_constness_upstream(var): + op = var.op + if op is None: + return False + if op.op_type.startswith("constexpr_") or var.val is not None: + return True + flattened_inputs = op.get_flattened_inputs() + return all([x.is_descendant_of_const for x in flattened_inputs]) + def _set_nonreplaceable_vars_upstream(self): """ A utility function to set the value of the "nonreplaceable_vars_upstream" property. 
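The ``is_descendant_of_const`` bookkeeping introduced on ``Var`` above reduces to a simple recursive rule: a var counts as a const descendant if its producing op is a ``constexpr_*`` op, if it already has a compile-time value, or if every input of its producing op is itself a const descendant. A toy, self-contained sketch of that rule (the ``Node`` class and its fields are hypothetical stand-ins, not coremltools types):

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    op_type: str                         # e.g. "constexpr_cast", "transpose", "placeholder"
    val: Optional[object] = None         # compile-time value, if value inference produced one
    inputs: List["Node"] = field(default_factory=list)


def is_descendant_of_const(node: Node) -> bool:
    # Mirrors the propagation rule sketched in the Var changes above.
    if node.op_type.startswith("constexpr_") or node.val is not None:
        return True
    if not node.inputs:                  # function inputs / placeholders
        return False
    return all(is_descendant_of_const(i) for i in node.inputs)


# A constexpr weight routed through a transpose is still a const descendant,
# even though no materialized value is available for it.
weight = Node("constexpr_affine_dequantize")
transposed = Node("transpose", inputs=[weight])
assert is_descendant_of_const(transposed)
assert not is_descendant_of_const(Node("placeholder"))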
diff --git a/coremltools/optimize/coreml/_post_training_quantization.py b/coremltools/optimize/coreml/_post_training_quantization.py index c61bf3c4e..64bb0616c 100644 --- a/coremltools/optimize/coreml/_post_training_quantization.py +++ b/coremltools/optimize/coreml/_post_training_quantization.py @@ -4,7 +4,7 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause from collections import OrderedDict -from typing import Dict, List, Optional +from typing import Callable, Dict, List, Optional import numpy as np from attrs import define, field, validators @@ -12,7 +12,7 @@ from coremltools import _SPECIFICATION_VERSION_IOS_16 from coremltools.converters.mil.converter import mil_convert as _mil_convert -from coremltools.converters.mil.frontend.milproto.load import load as _milproto_to_pymil +from coremltools.converters.mil.frontend.milproto import load as _milproto_to_pymil from coremltools.converters.mil.mil.passes.defs.quantization import ( AbstractQuantizationPass as _AbstractQuantizationPass, ) @@ -25,9 +25,10 @@ from ._quantization_passes import palettize_weights as _palettize_weights from ._quantization_passes import prune_weights as _prune_weights -_DEFAULT_SPECIFICATION_VERSION_FOR_COMPRESSION = _SPECIFICATION_VERSION_IOS_16 -def _convert_model_spec_to_pymil_prog(mlmodel: _MLModel, specification_version: int): +def _convert_model_spec_to_pymil_prog( + mlmodel: _MLModel, specification_version: int, pymil_load_func: Callable +): """ An utility that converts a ml program model into PyMIL program. """ @@ -43,7 +44,7 @@ def _convert_model_spec_to_pymil_prog(mlmodel: _MLModel, specification_version: else: raise TypeError("weight compression not applicable for model type {}".format(model_type)) - prog = _milproto_to_pymil( + prog = pymil_load_func( model_spec=model_spec, specification_version=specification_version, file_weights_dir=mlmodel.weights_dir, @@ -51,14 +52,18 @@ def _convert_model_spec_to_pymil_prog(mlmodel: _MLModel, specification_version: return prog -def _apply_graph_pass(mlmodel: _MLModel, graph_pass: _AbstractQuantizationPass): +def _apply_graph_pass( + mlmodel: _MLModel, + graph_pass: _AbstractQuantizationPass, + spec_version: int = _SPECIFICATION_VERSION_IOS_16, + skip_model_load: bool = False, + pymil_load_func: Callable = _milproto_to_pymil.load, +): # Utility function which compresses a Core ML model # converts the full precision mlmodel into a pymil program model_spec = mlmodel.get_spec() - specification_version = max( - model_spec.specificationVersion, _DEFAULT_SPECIFICATION_VERSION_FOR_COMPRESSION - ) - prog = _convert_model_spec_to_pymil_prog(mlmodel, specification_version) + specification_version = max(model_spec.specificationVersion, spec_version) + prog = _convert_model_spec_to_pymil_prog(mlmodel, specification_version, pymil_load_func) # apply compression graph pass assert isinstance( @@ -74,6 +79,7 @@ def _apply_graph_pass(mlmodel: _MLModel, graph_pass: _AbstractQuantizationPass): specification_version=specification_version, compute_units=mlmodel.compute_unit, model_description=model_spec.description, + skip_model_load=skip_model_load, ) return compressed_mlmodel @@ -464,7 +470,8 @@ def _get_weight_metadata(op): ) return CoreMLWeightMetaData(op.val.val, child_ops=child_ops) - prog = _convert_model_spec_to_pymil_prog(mlmodel, mlmodel.get_spec().specificationVersion) + prog = _convert_model_spec_to_pymil_prog(mlmodel, mlmodel.get_spec().specificationVersion, + _milproto_to_pymil.load) res = _MetaDataDict({}) def 
get_weights_meta_block(block): diff --git a/coremltools/optimize/coreml/_quantization_passes.py b/coremltools/optimize/coreml/_quantization_passes.py index 55bce9272..79a1cc520 100644 --- a/coremltools/optimize/coreml/_quantization_passes.py +++ b/coremltools/optimize/coreml/_quantization_passes.py @@ -8,8 +8,11 @@ from coremltools import _logger as logger from coremltools.converters.mil.backend.mil.load import should_use_weight_file +from coremltools.converters.mil._deployment_compatibility import AvailableTarget from coremltools.converters.mil.mil import Builder as mb from coremltools.converters.mil.mil import Operation, Program, types +from coremltools.converters.mil.mil.block import is_current_opset_version_compatible_with +from coremltools.converters.mil.mil.ops.defs._utils import pack_elements_into_bits from coremltools.converters.mil.mil.ops.defs.iOS16 import ( constexpr_affine_dequantize, constexpr_lut_to_dense, @@ -61,6 +64,8 @@ class AbstractCompressionPass(AbstractQuantizationPass): """ The abstract class for the compression graph passes. """ + _MINIMUM_OPSET_VERSION = AvailableTarget.iOS16 + def __init__(self, config: OptimizationConfig = None, fake_compression: bool = False): if not isinstance(config, (OptimizationConfig, type(None))): raise ValueError(f"config must be of type OptimizationConfig. Got {type(config)}.") @@ -80,6 +85,12 @@ def apply(self, prog): @block_context_manager def apply_block(block): + if not is_current_opset_version_compatible_with(self._MINIMUM_OPSET_VERSION): + logger.warning( + f"The program's opset is not compatible with {self._MINIMUM_OPSET_VERSION}. " + f"Skipped the compression pass {self.__class__}.") + return + valid_consts = [] for op in list(block.operations): for b in op.blocks: @@ -159,12 +170,12 @@ def get_supported_types_as_str(supported_type): class prune_weights(AbstractCompressionPass): """ This transform works for each ``const`` op if: - + - ``_is_deprecated=True`` and the ``op_selector`` returns ``True``. - ``_is_deprecated=False`` and the ``const`` value size ``> weight_threshold``. The transform performs the following: - + - The fraction of values with the least absolute value are zeroed out (self.sparsity). - If ``fake_compression=False``, the zeroed-out value is encoded using the ``constexpr_sparse_to_dense`` op. - If ``fake_compression=True``, the zeroed-out value is encoded using the ``const`` op. @@ -207,7 +218,7 @@ def _apply_block_sparsity(val, block_size, dim): assert rank in [2, 3, 4, 5], "block sparsity only supports weights of rank [2, 3, 4, 5]" """ Block sparsity follows these steps: - + 1. Input tensor with shape of ``[C_out, Cin, *K]``. 2. If ``dim = 1``, the tensor is transposed to ``[Cin, C_out, *K]``. The following example assumes ``dim = 0``. 3. Pad ``C_out`` so that it can be divided by ``block_size``: ``[C_out_pad, Cin, *K]``. @@ -406,7 +417,7 @@ class palettize_weights(AbstractCompressionPass): - ``_is_deprecated=False`` and the ``const`` value size ``> weight_threshold``. The transform performs the following: - + - A linear look-up table (LUT) with 2\ :sup:`nbits` entries is created with values represented by indexing into this LUT. - If ``fake_compression=False``, compressed value is encoded using the ``constexpr_lut_to_dense`` op. - If ``fake_compression=True``, compressed value is decompressed and then encoded using the ``const`` op. 
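The LUT palettization described above stores a table of at most 2**nbits entries plus per-element indices that are bit-packed with the same little-endian scheme exercised by the pack/unpack tests earlier in this patch (and by the ``pack_indices_into_bytes_array`` helper removed just below). A rough numpy-only sketch of the round trip, with toy helper names that are not part of the coremltools API:

import numpy as np


def toy_palettize(val: np.ndarray, nbits: int):
    # Build a LUT of at most 2**nbits entries and bit-pack the per-element indices.
    lut, indices = np.unique(val.flatten(), return_inverse=True)
    assert lut.size <= 2 ** nbits, "too many unique values for this nbits"
    lut = np.pad(lut, (0, 2 ** nbits - lut.size))
    bits = np.unpackbits(indices.astype(np.uint8).reshape(-1, 1), bitorder="little", axis=-1)
    packed = np.packbits(bits[:, :nbits].flatten(), bitorder="little")
    return lut, packed


def toy_depalettize(lut, packed, nbits, count, shape):
    bits = np.unpackbits(packed, bitorder="little")[: nbits * count].reshape(count, nbits)
    full = np.zeros((count, 8), dtype=np.uint8)
    full[:, :nbits] = bits
    indices = np.packbits(full, bitorder="little", axis=-1).reshape(count)
    return lut[indices].reshape(shape)


w = np.array([[0.0, 2.5, 7.5], [2.5, 0.0, 7.5]], dtype=np.float32)
lut, packed = toy_palettize(w, nbits=2)
np.testing.assert_array_equal(toy_depalettize(lut, packed, 2, w.size, w.shape), w)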
@@ -472,10 +483,6 @@ def compress_unique(val, nbits): indices = indices.astype(np.uint8) return lut, indices - def pack_indices_into_bytes_array(indices, nbits): - bitarray = np.unpackbits(indices.reshape(-1, 1), bitorder="little", axis=-1)[:, :nbits] - return np.packbits(bitarray.flatten(), bitorder="little") - def check_lut_parameters_are_valid(val, lut, indices): if not isinstance(lut, np.ndarray) or not isinstance(indices, np.ndarray): raise ValueError("LUT and indices must be type of numpy array.") @@ -518,7 +525,7 @@ def check_lut_parameters_are_valid(val, lut, indices): params = LutParams() params.lut = lut params.shape = val.shape - params.indices = pack_indices_into_bytes_array(indices, int(np.log2(lut.shape[0]))) + params.indices = pack_elements_into_bits(indices, int(np.log2(lut.shape[0]))) return params @staticmethod @@ -725,14 +732,10 @@ def __init__(self, op_selector): super().__init__(op_selector=op_selector) def is_valid_op(self, op): - return op.op_type in ( - "constexpr_affine_dequantize", - "constexpr_lut_to_dense", - "constexpr_sparse_to_dense", - ) + return op.op_type is not None and op.op_type.startswith("constexpr_") def transform_op(self, op): - decompressed_val = op.value_inference() + decompressed_val = op.materialized_val_inference() new_var = mb.const( val=decompressed_val, before_op=op, diff --git a/coremltools/test/ml_program/test_compression.py b/coremltools/test/ml_program/test_compression.py index 10ed10382..c08899458 100644 --- a/coremltools/test/ml_program/test_compression.py +++ b/coremltools/test/ml_program/test_compression.py @@ -4,7 +4,6 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause import numpy as np -import pytest import torch import coremltools as ct @@ -16,6 +15,7 @@ ) from coremltools.converters.mil.testing_utils import get_op_types_in_program + def get_test_model_and_data(multi_layer=False): inputs = [ct.TensorType(name="data", shape=(1, 64, 10, 10))] torch_input_values = [torch.rand(*i.shape.to_list()) for i in inputs] @@ -53,18 +53,18 @@ def test_op_selector(): mlmodel_no_quantized = affine_quantize_weights(mlmodel, mode="linear", op_selector=lambda const_op: const_op.val.val.size > 1e7) expected_ops = ['cast', 'conv', 'cast'] assert get_op_types_in_program(mlmodel_no_quantized._mil_program) == expected_ops - + @staticmethod def test_affine_quantize_weights_smoke(): model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data() torchmodel = torch.jit.trace(model, torch_input_values) mlmodel = ct.convert(torchmodel, inputs=inputs, convert_to="mlprogram") mlmodel_quantized = affine_quantize_weights(mlmodel, mode="linear_symmetric", dtype=np.int8) - + # validate parameters expected_ops = ['constexpr_affine_dequantize', 'cast', 'conv', 'cast'] assert get_op_types_in_program(mlmodel_quantized._mil_program) == expected_ops - + @staticmethod def test_palettize_weights_smoke(): model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data() @@ -75,7 +75,7 @@ def test_palettize_weights_smoke(): # validate parameters expected_ops = ['constexpr_lut_to_dense', 'cast', 'conv', 'cast'] assert get_op_types_in_program(mlmodel_palettized._mil_program) == expected_ops - + @staticmethod def test_sparsify_weights_threshold_smoke(): model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data() @@ -88,7 +88,7 @@ def test_sparsify_weights_threshold_smoke(): # validate parameters expected_ops = ['constexpr_sparse_to_dense', 'cast', 'conv', 'cast'] assert 
get_op_types_in_program(mlmodel_sparsified._mil_program) == expected_ops - + @staticmethod def test_sparsify_weights_percentile_smoke(): model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data() @@ -101,13 +101,13 @@ def test_sparsify_weights_percentile_smoke(): # validate parameters expected_ops = ['constexpr_sparse_to_dense', 'cast', 'conv', 'cast'] assert get_op_types_in_program(mlmodel_sparsified._mil_program) == expected_ops - + @staticmethod def test_weight_decompression_smoke(): model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data(multi_layer=True) torchmodel = torch.jit.trace(model, torch_input_values) mlmodel = ct.convert(torchmodel, inputs=inputs, convert_to="mlprogram") - + # we first compress the model mlmodel = palettize_weights(mlmodel, mode="kmeans", nbits=4, op_selector=lambda const_op: const_op.name == "conv_1_weight_to_fp16") mlmodel = affine_quantize_weights(mlmodel, mode="linear", op_selector=lambda const_op: const_op.name == "conv_2_weight_to_fp16") diff --git a/coremltools/test/optimize/coreml/test_passes.py b/coremltools/test/optimize/coreml/test_passes.py index 096866713..1c78cd453 100644 --- a/coremltools/test/optimize/coreml/test_passes.py +++ b/coremltools/test/optimize/coreml/test_passes.py @@ -17,7 +17,7 @@ import coremltools.optimize.coreml._quantization_passes as quantization from coremltools.converters.mil.mil import Builder as mb from coremltools.converters.mil.mil import types -from coremltools.converters.mil.mil.passes.tests.test_passes import TestSkipConstexprOps +from coremltools.converters.mil.mil.passes.tests.test_passes import CONSTEXPR_FUNCS, CONSTEXPR_OPS from coremltools.converters.mil.testing_utils import get_op_types_in_program @@ -774,7 +774,7 @@ def prog(x): @staticmethod @pytest.mark.parametrize( "constexpr_op", - TestSkipConstexprOps.CONSTEXPR_OPS, + CONSTEXPR_OPS, ) def test_constexpr_const_not_compressed(constexpr_op): """ @@ -782,7 +782,7 @@ def test_constexpr_const_not_compressed(constexpr_op): """ @mb.program(input_specs=[mb.TensorSpec(shape=(2, 3, 4, 5))]) def prog(x): - constexpr = TestSkipConstexprOps.CONSTEXPR_FUNCS[constexpr_op]((2, 3, 4, 5)) + constexpr = CONSTEXPR_FUNCS[constexpr_op]((2, 3, 4, 5)) return mb.add(x=x, y=constexpr) compressor = quantization.palettize_weights( diff --git a/coremltools/test/optimize/coreml/test_post_training_quantization.py b/coremltools/test/optimize/coreml/test_post_training_quantization.py index 259b78b92..7fb842bfb 100644 --- a/coremltools/test/optimize/coreml/test_post_training_quantization.py +++ b/coremltools/test/optimize/coreml/test_post_training_quantization.py @@ -227,7 +227,7 @@ def test_palettization(): config.set_global(global_config) config.set_op_type("conv", conv_config) config.set_op_name("conv_2_1", conv_2_config) - config.set_op_name("input_5", linear_1_config) + config.set_op_name("linear_0", linear_1_config) mlmodel = cto.coreml.palettize_weights(mlmodel, config) expected_ops = [ @@ -493,7 +493,7 @@ def test_pruning(): config.set_global(global_config) config.set_op_type("lstm", None) - config.set_op_name("input_5", None) + config.set_op_name("linear_0", None) mlmodel = cto.coreml.prune_weights(mlmodel, config) expected_ops = [ diff --git a/coremltools/version.py b/coremltools/version.py index c00da273a..27e3a7666 100644 --- a/coremltools/version.py +++ b/coremltools/version.py @@ -4,4 +4,4 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause -__version__ = "7.0b2" # VERSION_STRING +__version__ = 
"7.0" # VERSION_STRING diff --git a/docs-guides/make.bat b/docs-guides/make.bat index 32bb24529..954237b9b 100644 --- a/docs-guides/make.bat +++ b/docs-guides/make.bat @@ -1,35 +1,35 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set SOURCEDIR=. -set BUILDDIR=_build - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.https://www.sphinx-doc.org/ - exit /b 1 -) - -if "%1" == "" goto help - -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% -goto end - -:help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% - -:end -popd +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/milstoragepython/MilStoragePython.cpp b/milstoragepython/MilStoragePython.cpp index 62d4b18d8..3d70e5c4b 100644 --- a/milstoragepython/MilStoragePython.cpp +++ b/milstoragepython/MilStoragePython.cpp @@ -31,7 +31,7 @@ using namespace CoreML::MilStoragePython; PYBIND11_PLUGIN(libmilstoragepython) { py::module m("libmilstoragepython", "Library to create, access and edit CoreML blob files."); - py::class_ blobStorageWriter(m, "_BlobStorageWriter"); + py::class_ blobStorageWriter(m, "_BlobStorageWriter", py::module_local()); blobStorageWriter.def(py::init(), py::arg("file_name"), py::arg("truncate_file") = true) .def("write_int8_data", &MilStoragePythonWriter::write_int8_data) .def("write_uint8_data", &MilStoragePythonWriter::write_uint8_data) @@ -40,7 +40,7 @@ PYBIND11_PLUGIN(libmilstoragepython) { .def("write_fp16_data", &MilStoragePythonWriter::write_fp16_data) .def("write_float_data", &MilStoragePythonWriter::write_float_data); - py::class_ blobStorageReader(m, "_BlobStorageReader"); + py::class_ blobStorageReader(m, "_BlobStorageReader", py::module_local()); blobStorageReader.def(py::init()) .def("read_int8_data", &MilStoragePythonReader::read_int8_data) .def("read_uint8_data", &MilStoragePythonReader::read_uint8_data) diff --git a/mlmodel/docs/Format/ItemSimilarityRecommender.rst b/mlmodel/docs/Format/ItemSimilarityRecommender.rst index d213bc4c7..7a63d4323 100644 --- a/mlmodel/docs/Format/ItemSimilarityRecommender.rst +++ b/mlmodel/docs/Format/ItemSimilarityRecommender.rst @@ -8,40 +8,40 @@ items but not part of that item set. 
The predicted score for a given item k is as follows: - | sum_(i in observed items) + | sum_(i in observed items) | sim_(k,i) * (score_i - shift_k) Because only the most similar scores for each item i are stored, ``sim_(k,i)`` is often zero. -For many models, the score adjustment parameter ``shift_j`` is zero -- +For many models, the score adjustment parameter ``shift_j`` is zero -- it's occasionally used to counteract global biases for popular items. .. code-block:: proto message ItemSimilarityRecommender { - + message ConnectedItem { uint64 itemId = 1; double similarityScore = 2; } - + message SimilarItems { uint64 itemId = 1; repeated ConnectedItem similarItemList = 2; double itemScoreAdjustment = 3; } - + repeated SimilarItems itemItemSimilarities = 1; - + StringVector itemStringIds = 2; Int64Vector itemInt64Ids = 3; - - + + string recommendedItemListOutputFeatureName = 20; string recommendedItemScoreOutputFeatureName = 21; - + }
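A minimal illustrative sketch (not part of the patch above): the scoring rule documented in
ItemSimilarityRecommender.rst can be reproduced in plain Python. The helper name
``predict_score`` and the dictionary layout below are hypothetical; they only mirror the
``SimilarItems``/``ConnectedItem`` proto messages for illustration. Pairs that are not
stored are treated as similarity zero, and ``itemScoreAdjustment`` plays the role of the
``shift_k`` term in the formula.

.. code-block:: python

    def predict_score(k, observed_scores, similarities, adjustments):
        """Score candidate item ``k`` from observed {item_id: score} pairs.

        ``similarities[k]`` maps a similar item_id to its similarityScore
        (ConnectedItem); ``adjustments[k]`` is itemScoreAdjustment for k.
        """
        sim_k = similarities.get(k, {})
        shift_k = adjustments.get(k, 0.0)
        # Only the stored (most similar) pairs contribute; missing pairs act as 0.
        return sum(
            sim_k.get(i, 0.0) * (score_i - shift_k)
            for i, score_i in observed_scores.items()
        )


    # Items 1 and 2 were observed with scores; item 3 is the candidate.
    observed = {1: 5.0, 2: 3.0}
    sims = {3: {1: 0.8, 2: 0.1}}   # sparse similarity table for item 3
    adjustments = {3: 0.5}         # itemScoreAdjustment (often zero)
    print(predict_score(3, observed, sims, adjustments))  # 0.8*(5-0.5) + 0.1*(3-0.5) = 3.85

This sketch only restates the documented equation; the actual Core ML runtime evaluates the
serialized ``ItemSimilarityRecommender`` proto shown above rather than any Python helper.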