From 87b40fbb6a0e91de3d6142f7ca74d8c6fb2c4893 Mon Sep 17 00:00:00 2001 From: "bogunowicz@arrival.com" Date: Mon, 19 Jun 2023 13:32:13 +0200 Subject: [PATCH 1/5] working implementation --- setup.py | 8 +- src/sparseml/exporters/kv_cache_injector.py | 156 ++++ src/sparseml/exporters/transforms/__init__.py | 1 + .../exporters/transforms/kv_cache/__init__.py | 25 + .../kv_cache/cache_keys_and_values.py | 744 ++++++++++++++++++ .../exporters/transforms/kv_cache/configs.py | 167 ++++ .../kv_cache/positions_adjustment_base.py | 48 ++ .../kv_cache/positions_adjustment_codegen.py | 89 +++ .../kv_cache/positions_adjustment_opt.py | 133 ++++ ...atmul_add_to_matmulinteger_add_cast_mul.py | 32 +- .../matmul_to_matmulinteger_cast_mul.py | 5 +- .../add_quantized_conv_matmul_add_ops.py | 105 ++- src/sparseml/onnx/utils/graph_editor.py | 62 +- src/sparseml/pytorch/base.py | 2 +- src/sparseml/transformers/__init__.py | 2 +- ...atmul_add_to_matmulinteger_add_cast_mul.py | 17 +- tests/sparseml/pytorch/model.onnx | Bin 0 -> 23435 bytes 17 files changed, 1533 insertions(+), 63 deletions(-) create mode 100644 src/sparseml/exporters/kv_cache_injector.py create mode 100644 src/sparseml/exporters/transforms/kv_cache/__init__.py create mode 100644 src/sparseml/exporters/transforms/kv_cache/cache_keys_and_values.py create mode 100644 src/sparseml/exporters/transforms/kv_cache/configs.py create mode 100644 src/sparseml/exporters/transforms/kv_cache/positions_adjustment_base.py create mode 100644 src/sparseml/exporters/transforms/kv_cache/positions_adjustment_codegen.py create mode 100644 src/sparseml/exporters/transforms/kv_cache/positions_adjustment_opt.py create mode 100644 tests/sparseml/pytorch/model.onnx diff --git a/setup.py b/setup.py index ed689439f9a..cf38c8b3d0e 100644 --- a/setup.py +++ b/setup.py @@ -62,17 +62,17 @@ _deepsparse_ent_deps = [f"deepsparse-ent~={version_nm_deps}"] _onnxruntime_deps = ["onnxruntime>=1.0.0"] -supported_torch_version = "torch>=1.7.0,<1.14" +supported_torch_version = "torch>=1.7.0,<=2.0" _pytorch_deps = [ supported_torch_version, "gputils", ] _pytorch_all_deps = _pytorch_deps + [ - "torchvision>=0.3.0,<0.15", - "torchaudio<=0.13", + "torchvision>=0.3.0,<=0.15.1", + "torchaudio<=2.0.1", ] _pytorch_vision_deps = _pytorch_deps + [ - "torchvision>=0.3.0,<0.15", + "torchvision>=0.3.0,<=0.15.1", "opencv-python<=4.6.0.66", ] _transformers_deps = _pytorch_deps + [ diff --git a/src/sparseml/exporters/kv_cache_injector.py b/src/sparseml/exporters/kv_cache_injector.py new file mode 100644 index 00000000000..e088aeb0e7c --- /dev/null +++ b/src/sparseml/exporters/kv_cache_injector.py @@ -0,0 +1,156 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
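+"""
+Exporter that injects a key/value (KV) cache into an autoregressive ONNX
+model, so that keys and values computed for previous tokens are stored and
+reused across generation steps. (Module docstring added for orientation;
+see the class docstring below for usage.)
+"""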
+
+import logging
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import onnx
+
+from sparseml.exporters.base_exporter import BaseExporter
+from sparseml.exporters.transforms.kv_cache import (
+    CacheKeysAndValues,
+    get_kv_cache_config,
+)
+from sparsezoo.utils import save_onnx
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class KeyValueCacheInjector(BaseExporter):
+    def __init__(
+        self,
+        model_path: Optional[str] = None,
+        inplace: bool = True,
+        **kwargs: Any,
+    ):
+        """
+        A transformation that injects Key Value cache support into the model.
+        This means that an autoregressive model that:
+            * takes input_ids and attention_mask as INPUT
+            * returns logits as OUTPUT
+        is transformed into a model that:
+            * takes input_ids, attention_mask, and kv_cache as INPUT
+            * returns logits and an updated kv_cache as OUTPUT
+
+        The goal of the KV cache injection is to speed up the autoregressive
+        generation process (reduce the compute of key/value pairs by storing
+        the results of previous computations in memory).
+
+        The exporter will look for a `config.json` file in the `model_path`
+        directory to determine the parameters for KV cache injection.
+        If `model_path` is not provided, the requested parameters can be
+        provided in the `kwargs`.
+
+        This transformation does not solely inject the kv cache
+        inputs/outputs; it also adjusts the original ONNX graph to
+        account for the necessary changes. This involves e.g. adding
+        the `positions` input to the model, so that the positional
+        embeddings of the new model are compatible with the past kv
+        cache information.
+
+        Usage:
+        ```python
+        onnx_model: onnx.ModelProto = ...
+        exporter = KeyValueCacheInjector(model_path="path/to/model")
+        exporter.export(onnx_model, "model.onnx")
+        ```
+
+        Alternatively, the parameters can be passed directly:
+        ```python
+        onnx_model: onnx.ModelProto = ...
+        exporter = KeyValueCacheInjector(
+            num_attention_heads=16,
+            hidden_size_kv_cache_dim=64,
+        )
+        exporter.export(onnx_model, "model.onnx")
+        ```
+
+        You can also just optimize the model directly without saving to disk:
+        ```python
+        onnx_model: onnx.ModelProto = ...
+        exporter = KeyValueCacheInjector(model_path="path/to/model")
+        optimized_model = exporter.apply(onnx_model)
+        ```
+
+        :param model_path: The path to the directory containing the model.
+        :param inplace: If True, the model will be modified in place.
+            If False, a copy of the model will be made and modified.
+        :param kwargs: (optionally) the parameters for the KV cache injection
+            if no `model_path` is provided.
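+
+        For intuition, with a hypothetical `config.json` containing
+        `"num_attention_heads": 16` and `"hidden_size": 1024`, the derived
+        per-head cache size is `hidden_size_kv_cache = 1024 // 16 = 64`
+        (values illustrative, not defaults).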
+ """ + self.inplace = inplace + self.config = get_kv_cache_config(model_path) + + if model_path is not None: + # get the parameters from the config + self.config = get_kv_cache_config(model_path) + + num_attention_heads = self.config.num_attention_heads + hidden_size_kv_cache_dim = self.config.hidden_size_kv_cache + multiply_batch_by_num_att_heads = ( + self.config.multiply_batch_by_num_att_heads + ) + transpose_value_input = self.config.transpose_value_input + transpose_key_input = self.config.transpose_key_input + positions_adjustment = self.config.positions_adjustment_transform + + elif kwargs: + # get the parameters from the kwargs + num_attention_heads = kwargs.get("num_attention_heads") + hidden_size_kv_cache_dim = kwargs.get("hidden_size_kv_cache_dim") + multiply_batch_by_num_att_heads = kwargs.get( + "multiply_batch_by_num_att_heads", False + ) + transpose_value_input = kwargs.get("transpose_value_input") + transpose_key_input = kwargs.get("transpose_key_input") + positions_adjustment = None + + else: + raise ValueError( + "Either `model_path` or kwargs must be provided to " + "KeyValueCacheInjector" + ) + + transforms = [ + CacheKeysAndValues( + num_attention_heads=num_attention_heads, + hidden_size_kv_cache=hidden_size_kv_cache_dim, + multiply_batch_by_num_att_heads=multiply_batch_by_num_att_heads, + transpose_value_input=transpose_value_input, + transpose_key_input=transpose_key_input, + ) + ] + if positions_adjustment is not None: + transforms += [positions_adjustment()] + super().__init__(transforms) + + def pre_validate(self, model: Union[onnx.ModelProto, str, Path]) -> onnx.ModelProto: + if isinstance(model, (str, Path)): + model = onnx.load(str(model)) + + if not isinstance(model, onnx.ModelProto): + raise TypeError(f"Expected onnx.ModelProto, found {type(model)}") + return model if self.inplace else deepcopy(model) + + def post_validate(self, model: onnx.ModelProto) -> onnx.ModelProto: + if not isinstance(model, onnx.ModelProto): + raise TypeError(f"Expected onnx.ModelProto, found {type(model)}") + return model + + def export(self, pre_transforms_model: onnx.ModelProto, file_path: str): + post_transforms_model: onnx.ModelProto = self.apply(pre_transforms_model) + save_onnx(post_transforms_model, file_path) diff --git a/src/sparseml/exporters/transforms/__init__.py b/src/sparseml/exporters/transforms/__init__.py index 5b0e877df9f..459c86083e4 100644 --- a/src/sparseml/exporters/transforms/__init__.py +++ b/src/sparseml/exporters/transforms/__init__.py @@ -46,3 +46,4 @@ from .remove_duplicate_qconv_weights import RemoveDuplicateQConvWeights from .remove_duplicate_quantize_ops import RemoveDuplicateQuantizeOps from .skip_input_quantize import SkipInputQuantize +from .kv_cache import * diff --git a/src/sparseml/exporters/transforms/kv_cache/__init__.py b/src/sparseml/exporters/transforms/kv_cache/__init__.py new file mode 100644 index 00000000000..8ccdc3e6086 --- /dev/null +++ b/src/sparseml/exporters/transforms/kv_cache/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Transforms for adding a KV caching mechanism into language models
+"""
+
+# flake8: noqa
+# isort:skip_file
+
+from .cache_keys_and_values import *
+from .positions_adjustment_base import *
+from .positions_adjustment_opt import *
+from .positions_adjustment_codegen import *
+from .configs import *
diff --git a/src/sparseml/exporters/transforms/kv_cache/cache_keys_and_values.py b/src/sparseml/exporters/transforms/kv_cache/cache_keys_and_values.py
new file mode 100644
index 00000000000..d7a9c92b562
--- /dev/null
+++ b/src/sparseml/exporters/transforms/kv_cache/cache_keys_and_values.py
@@ -0,0 +1,744 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import logging
+from typing import List, Optional, Set, Tuple, Union
+
+import numpy
+import onnx
+from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto, numpy_helper
+
+from sparseml.exporters.transforms.onnx_transform import OnnxTransform
+from sparseml.onnx.utils import ONNXGraph
+
+
+__all__ = ["CacheKeysAndValues"]
+
+
+_LOGGER = logging.getLogger(__name__)
+
+ALLOWED_NODES_BEFORE_SOFTMAX = {"Cast"}
+OUTPUT_CACHE_NAME = """present.{attention_layer_idx}.{cache_type}"""
+INPUT_CACHE_NAME = """past_key_values.{attention_layer_idx}.{cache_type}"""
+
+
+class CacheKeysAndValues(OnnxTransform):
+    """
+    Inject the key and value caches into the graph for the attention layers.
+    The logic for pattern matching is as follows:
+
+    1. Find all the MatMuls that are preceded by a Softmax operation.
+       Those are the MatMuls that perform the V x Softmax(Q x K^T) operation
+       (the "value" MatMuls).
+    2. Given the "value" MatMuls found in step 1, perform a Breadth First Search
+       to find the "key" MatMuls that perform the Q x K^T operation.
+    3. Before each pair of "key" and "value" MatMuls, inject a cache node that
+       concatenates the current keys/values with the cached keys/values.
+    4. For the key or value cache, if there is a Transpose node present before
+       the MatMul, the concatenation will be performed before the Transpose node.
+    5. (Optional) To account for the variance in the operations in the vicinity
+       of the "key" and "value" MatMuls, the user can specify whether to
+       additionally inject a Reshape or Transpose node, so that the dimensions
+       of the cache inputs/outputs are compatible with the values they are
+       concatenated with.
+
+    This transform also sets the subset of kv cache inputs/outputs dimensions
+    (num_attention_heads and hidden_size_kv_cache) to the appropriate static values.
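+
+    The injected tensors follow the naming templates defined above the class
+    (`INPUT_CACHE_NAME` / `OUTPUT_CACHE_NAME`): the first attention layer, for
+    example, reads `past_key_values.0.key` / `past_key_values.0.value` and
+    emits `present.0.key` / `present.0.value`.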
+
+    :param num_attention_heads: number of attention heads of the model
+    :param hidden_size_kv_cache: hidden size of the key and value cache
+    :param multiply_batch_by_num_att_heads: every model created by
+        this transformation will have kv_cache inputs/outputs with dimensions:
+        [`batch_size`,`num_attention_heads`,`past_sequence_len`,`hidden_size_kv_cache`]
+
+        However, internally, there may be a need to reshape the kv_cache
+        inputs/outputs ("merging" the `batch_size` and `num_attention_heads`
+        dimensions) so that they are compatible with the values they are
+        concatenated with.
+        If True, the batch size will be multiplied by the number of attention
+        heads just before the appropriate concatenation node (as reflected by
+        the "Reshape" nodes in the diagram below).
+    :param transpose_value_input: if not None, transpose the kv cache value
+        input before the "value" MatMul. The argument needs to be a tuple of
+        4 integers that represent the permutation of the input dimensions.
+        This will insert a Transpose node before the "value" MatMul. If
+        multiply_batch_by_num_att_heads is True, the Transpose node will be
+        inserted before the Reshape node.
+    :param transpose_key_input: works analogously to transpose_value_input,
+        but for the key input.
+
+
+    Transforms
+    ```
+    |
+    |       Key
+    |        |    Query
+    |   Transpose   |
+    |        |      |
+    |        |      |
+    |         |    |
+    |      "key" MatMul
+    |           |
+    |          ...     Value
+    |           |        |
+    |       Softmax      |
+    |           |        |
+    |          ...      ...
+    |           |        |
+    |            |      |
+    |        "value" MatMul
+    |             |
+    |            ...
+    ```
+    to
+
+    ```
+    |
+    |     KeyCache
+    |        |
+    |    Transpose
+    |    (optional)
+    |        |
+    |        |      Key
+    |     Reshape    |
+    |    (optional)  |
+    |        |       |
+    |        |       |
+    |        |       |
+    |         |     |
+    |        Concat ------------> OutputKeyCache
+    |           |
+    |           |     Query
+    |          ...      |
+    |           |       |
+    |           |       |
+    |           |       |         ValueCache
+    |           |       |             |
+    |           |       |         Transpose
+    |           |       |         (optional)
+    |           |       |             |
+    |           |       |          Reshape
+    |           |       |         (optional)
+    |           |       |   Value     |
+    |           |       |     |       |
+    |           |       |      |     |
+    |        "key" MatMul      |     |
+    |             |            |     |
+    |            ...          Concat --> OutputValueCache
+    |             |             |
+    |         Softmax           |
+    |             |            ...
+    |            ...            |
+    |             |             |
+    |              |           |
+    |           "value" MatMul
+    |                 |
+    |                ...
+ ``` + + """ + + def __init__( + self, + num_attention_heads: int, + hidden_size_kv_cache: int, + multiply_batch_by_num_att_heads: bool, + transpose_value_input: Optional[Tuple[int, int, int, int]] = None, + transpose_key_input: Optional[Tuple[int, int, int, int]] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.hidden_size_kv_cache = hidden_size_kv_cache + self.multiply_batch_by_num_att_heads = multiply_batch_by_num_att_heads + self.transpose_value_input = transpose_value_input + self.transpose_key_input = transpose_key_input + + def transform(self, model: ModelProto) -> ModelProto: + graph = ONNXGraph(model) + + key_value_matmul_pairs = _find_key_value_matmul_pairs(graph) + + # Inject kv cache to the graph as the model input, + # Inject kv cache concatenated with the current keys/values as the output + inputs_to_add = [] + outputs_to_add = [] + + # get default int8 type to use if graph is quantized + use_uint8_if_quantized = _use_uint8_if_quantized(graph) + + for idx, (key_matmul, value_matmul) in enumerate(key_value_matmul_pairs): + + value_input_idx = _value_input_idx(value_matmul, model) + + key_concat_node, key_input_tensor, key_output_tensor = create_cache( + model=model, + node=key_matmul, + cache_input_idx=1, + cache_input_name=INPUT_CACHE_NAME.format( + attention_layer_idx=idx, cache_type="key" + ), + cache_output_name=OUTPUT_CACHE_NAME.format( + attention_layer_idx=idx, cache_type="key" + ), + use_uint8_if_quantized=use_uint8_if_quantized, + num_attention_heads=self.num_attention_heads, + hidden_size_kv_cache=self.hidden_size_kv_cache, + transpose_input=self.transpose_key_input, + multiply_batch_by_num_att_heads=self.multiply_batch_by_num_att_heads, # noqa E501 + ) + value_concat_node, value_input_tensor, value_output_tensor = create_cache( + model=model, + node=value_matmul, + cache_input_idx=value_input_idx, + cache_input_name=INPUT_CACHE_NAME.format( + attention_layer_idx=idx, cache_type="value" + ), + cache_output_name=OUTPUT_CACHE_NAME.format( + attention_layer_idx=idx, cache_type="value" + ), + use_uint8_if_quantized=use_uint8_if_quantized, + num_attention_heads=self.num_attention_heads, + hidden_size_kv_cache=self.hidden_size_kv_cache, + transpose_input=self.transpose_value_input, + multiply_batch_by_num_att_heads=self.multiply_batch_by_num_att_heads, # noqa E501 + ) + + inputs_to_add.extend([key_input_tensor, value_input_tensor]) + outputs_to_add.extend([key_output_tensor, value_output_tensor]) + + self.log_match(key_matmul) + self.log_match(value_matmul) + + # update model with cache inputs, and outputs + model.graph.input.extend(inputs_to_add) + model.graph.output.extend(outputs_to_add) + + _set_attention_mask_to_dynamic(model) + + return model + + +def create_cache( + model: ModelProto, + node: NodeProto, + cache_input_idx: int, + cache_input_name: str, + cache_output_name: str, + num_attention_heads: int, + hidden_size_kv_cache: int, + use_uint8_if_quantized: bool = True, + batch_size: int = 1, + multiply_batch_by_num_att_heads: bool = True, + transpose_input: Optional[Tuple[int, int, int, int]] = None, +) -> Tuple[NodeProto, ValueInfoProto, ValueInfoProto]: + """ + Injects a cache (value or key) into the graph for a given Matmul node. 
+ + :param model: Model to update + :param node: MatMul node that follows the cache injection point + :param cache_input_idx: Index of the input + (where the cache will be injected) to the MatMul + :param cache_input_name: Name of cache input + :param cache_output_name: Name of cache output + :param num_attention_heads: number of attention heads of the model + :param hidden_size_kv_cache: hidden size of the key/value cache + :param use_uint8_if_quantized: True if quantized MatMuls should have uint8 + inputs, if False, uses int8 + :param batch_size: batch size of the kv cache. By default, this is 1. + :param multiply_batch_by_num_att_heads: If True, the batch size of the + kv cache is multiplied by the number of attention heads before the + concat node. + :param transpose_input: If not None, transpose the input to the cache + before the concat node. If `multiply_batch_by_num_att_heads` is True, + the transpose is applied after the batch size is multiplied by the + number of attention heads. + :return: tuple of concat node to add, cache input to add, and cache output to add, + updates existing nodes in-place + """ + CACHE_INPUT_DIMS = [ + batch_size, + num_attention_heads, + "past_sequence_len", + hidden_size_kv_cache, + ] + CACHE_OUTPUT_DIMS = [ + batch_size, + num_attention_heads, + "past_sequence_len + 1", + hidden_size_kv_cache, + ] + + graph = ONNXGraph(model) + + cache_data_type = ( + TensorProto.FLOAT + if node.op_type not in ["MatMulInteger", "QLinearMatMul"] + else TensorProto.UINT8 + if use_uint8_if_quantized + else TensorProto.INT8 + ) + + # create graph input info proto + cache_input_info = onnx.helper.make_tensor_value_info( + cache_input_name, + cache_data_type, + CACHE_INPUT_DIMS, + ) + # create graph output info proto + cache_output_info = onnx.helper.make_tensor_value_info( + cache_output_name, + cache_data_type, + CACHE_OUTPUT_DIMS, + ) + + if node.op_type == "QLinearMatMul" and cache_input_idx == 1: + cache_input_idx = 3 # QLinearMatMul B matrix is at idx 3, not 1 + + cache_parent = graph.get_node_single_parent(node, index=cache_input_idx) + if isinstance(cache_parent, NodeProto) and cache_parent.op_type == "Transpose": + # move cache to before a transpose if applicable + # this is due to pytorch operations potentially extracting shape values + # from the key tensor before the transpose is applied + pre_cache_input_id = cache_parent.input[0] + # update concat axis + node = cache_parent + else: + pre_cache_input_id = node.input[cache_input_idx] + + cache_input_name_concat = cache_input_name + cache_output_name_concat = cache_output_name + cache_input_dims_concat = CACHE_INPUT_DIMS + + if transpose_input: + ( + graph, + cache_input_dims_concat, + cache_input_name_concat, + cache_output_name_concat, + ) = transpose_kv_cache_inputs_outputs( + graph=graph, + cache_input_name=cache_input_name_concat, + cache_output_name=cache_output_name_concat, + cache_input_dims=cache_input_dims_concat, + transpose_input=transpose_input, + ) + + if multiply_batch_by_num_att_heads: + ( + model, + cache_input_dims_concat, + cache_input_name_concat, + cache_output_name_concat, + ) = reshape_kv_cache_inputs_outputs( + model=model, + cache_input_name=cache_input_name_concat, + cache_output_name=cache_output_name_concat, + cache_input_dims=cache_input_dims_concat, + batch_size=batch_size, + num_attention_heads=num_attention_heads, + ) + + concat_axis = [ + idx + for (idx, dim) in enumerate(cache_input_dims_concat) + if dim == "past_sequence_len" + ][0] + + concat_node = onnx.helper.make_node( + 
op_type="Concat", + inputs=[cache_input_name_concat, pre_cache_input_id], + outputs=[cache_output_name_concat], + axis=concat_axis, + name=f"concat.{cache_input_name_concat}", + ) + + for node in model.graph.node: + for input_idx, input_id in enumerate(node.input): + if input_id == pre_cache_input_id and node.name != concat_node.name: + node.input[input_idx] = cache_output_name_concat + + graph.add_node(concat_node) + + return concat_node, cache_input_info, cache_output_info + + +def reshape_kv_cache_inputs_outputs( + model: ModelProto, + cache_input_name: str, + cache_output_name: str, + cache_input_dims: List[Union[int, str]], + batch_size: int, + num_attention_heads: int, +) -> Tuple[ModelProto, List[Union[int, str]], str, str]: + """ + Reshapes the input and output of a kv cache in the model, so that the dimensions + `batch_size` and `num_attention_heads` are multiplied together. + + Transform: + ``` + | cache_input_name + | | + | ... + | | + | cache_output_name + ``` + to: + ``` + | cache_input_name + | | + | cache_input_name_reshaped + | | + | ... + | | + | cache_output_name_reshaped + | | + | cache_output_name + + + :param model: The model to update + :param cache_input_name: The name of the input to the submodel + :param cache_output_name: The name of the output from the submodel + :param cache_input_dims: The dimensions of the input to the submodel + :param batch_size: The batch size of the model + :param num_attention_heads: The number of attention heads in the model + :return: The updated model, the updated input dimensions, + the updated input name, and the updated output name + """ + + cache_input_name_reshaped = f"{cache_input_name}_reshaped" + cache_output_name_reshaped = f"{cache_output_name}_reshaped" + + reshape_in_initializer_name = f"reshape_in.{cache_input_name}" + reshape_out_initializer_name = f"reshape_out.{cache_output_name}" + + reshape_input_dims_in = copy.deepcopy(cache_input_dims) + reshape_input_dims_out = copy.deepcopy(cache_input_dims) + + # "squash" the batch_size and num_attention_heads dimensions together + reshape_input_dims_in[0] = batch_size * num_attention_heads + reshape_input_dims_in.remove(num_attention_heads) + + reshape_in_array = numpy.array( + [dim if isinstance(dim, int) else -1 for dim in reshape_input_dims_in], + dtype=numpy.int64, + ) + reshape_out_array = numpy.array( + [dim if isinstance(dim, int) else -1 for dim in reshape_input_dims_out], + dtype=numpy.int64, + ) + + reshape_in_initializer = numpy_helper.from_array( + numpy.array( + reshape_in_array, + dtype=numpy.int64, + ), + reshape_in_initializer_name, + ) + + reshape_out_initializer = numpy_helper.from_array( + numpy.array( + reshape_out_array, + dtype=numpy.int64, + ), + reshape_out_initializer_name, + ) + + reshape_node_in = onnx.helper.make_node( + op_type="Reshape", + inputs=[cache_input_name, reshape_in_initializer_name], + outputs=[cache_input_name_reshaped], + name=f"reshape.{cache_input_name}", + ) + + reshape_node_out = onnx.helper.make_node( + op_type="Reshape", + inputs=[cache_output_name_reshaped, reshape_out_initializer_name], + outputs=[cache_output_name], + name=f"reshape.{cache_output_name}", + ) + graph = ONNXGraph(model) + + graph.add_node(reshape_node_in) + graph.add_node(reshape_node_out) + + model.graph.initializer.extend([reshape_in_initializer, reshape_out_initializer]) + + return ( + model, + reshape_input_dims_in, + cache_input_name_reshaped, + cache_output_name_reshaped, + ) + + +def transpose_kv_cache_inputs_outputs( + graph: ONNXGraph, + 
+    cache_input_name: str,
+    cache_output_name: str,
+    cache_input_dims: List[Union[int, str]],
+    transpose_input: Tuple[int, int, int, int],
+) -> Tuple[ONNXGraph, List[Union[int, str]], str, str]:
+    """
+    Transposes the input and output of a kv cache in the model
+    according to the transpose_input sequence
+
+    Transform:
+    ```
+    |     cache_input_name
+    |            |
+    |           ...
+    |            |
+    |     cache_output_name
+    ```
+    to:
+    ```
+    |     cache_input_name
+    |            |
+    |     cache_input_name_transposed
+    |            |
+    |           ...
+    |            |
+    |     cache_output_name_transposed
+    |            |
+    |     cache_output_name
+    ```
+
+    :param graph: The graph to update
+    :param cache_input_name: The name of the input to the submodel
+    :param cache_output_name: The name of the output from the submodel
+    :param transpose_input: The permutation of the input dimensions
+    :param cache_input_dims: The dimensions of the input to the submodel
+    :return: The updated graph, the updated input dimensions,
+        the updated input name, and the updated output name
+    """
+
+    cache_input_name_transposed = f"{cache_input_name}_transposed"
+    cache_output_name_transposed = f"{cache_output_name}_transposed"
+
+    transpose_node_in = onnx.helper.make_node(
+        op_type="Transpose",
+        inputs=[cache_input_name],
+        outputs=[cache_input_name_transposed],
+        name=f"transpose.{cache_input_name}",
+        perm=transpose_input,
+    )
+    transpose_node_out = onnx.helper.make_node(
+        op_type="Transpose",
+        inputs=[cache_output_name_transposed],
+        outputs=[cache_output_name],
+        name=f"transpose.{cache_output_name}",
+        perm=transpose_input,
+    )
+    transposed_input_dims = [cache_input_dims[i] for i in transpose_input]
+
+    graph.add_node(transpose_node_in)
+    graph.add_node(transpose_node_out)
+
+    return (
+        graph,
+        transposed_input_dims,
+        cache_input_name_transposed,
+        cache_output_name_transposed,
+    )
+
+
+def _find_key_value_matmul_pairs(
+    graph: ONNXGraph,
+) -> List[Tuple[NodeProto, NodeProto]]:
+    # Find pairs of "key" and "value" MatMuls.
+    # Each attention block contains a pair of MatMuls:
+    #    - key MatMul that computes Q x K^T
+    #    - value MatMul that computes Softmax(Q x K^T) x V
+    # The function returns:
+    #    [(key_matmul_0, value_matmul_0), (key_matmul_1, value_matmul_1), ...]
+
+    key_value_matmul_pairs = []
+    value_matmuls = [node for node in graph.nodes if is_value_matmul(node, graph)]
+    value_matmul_names = {node.name for node in value_matmuls}
+
+    # for every value matmul, find the corresponding key matmul
+    for value_matmul in value_matmuls:
+        key_matmul = _find_key_matmul_from_value_matmul(
+            value_matmul, graph, value_matmul_names
+        )
+        if key_matmul is not None:
+            key_value_matmul_pairs.append((key_matmul, value_matmul))
+        else:
+            raise RuntimeError(
+                f"Could not find key matmul for value matmul {value_matmul.name}"
+            )
+
+    return key_value_matmul_pairs
+
+
+def is_value_matmul(
+    node: NodeProto,
+    graph: ONNXGraph,
+    allowed_nodes_before_softmax: Set[str] = ALLOWED_NODES_BEFORE_SOFTMAX,
+) -> bool:
+    """
+    Returns True if the node is a "value" MatMul, i.e.
a MatMul that meets
+    the following criteria:
+        - is_matmul(node) is True
+        - node has no parameters
+        - node has a single `Softmax` parent node,
+          or the parent node `Softmax` is separated from the node
+          only by nodes whose op types are specified in the
+          `allowed_nodes_before_softmax` set
+
+    :param node: node to check
+    :param graph: graph containing the node
+    :param allowed_nodes_before_softmax: set of node types that are allowed
+        to be located between the node in question and a Softmax node, so that
+        the node can still be considered a "value" MatMul
+    """
+
+    if not is_matmul(node) or _is_parameterized_matmul(node, graph):
+        # not a matmul or MatMul op has a parameter
+        return False
+
+    parent = graph.get_node_single_parent(node, index=0)
+    while isinstance(parent, NodeProto) and (
+        parent.op_type in allowed_nodes_before_softmax
+    ):
+        parent = graph.get_node_single_parent(parent, index=0)
+        if parent is None:
+            raise ValueError(
+                "While traversing the graph to find a Softmax that precedes "
+                f"the candidate for a `value` MatMul: {node.name}, found a node "
+                "without a single parent. "
+                "It is assumed that the graph that connects the Softmax "
+                "node and the `value` MatMul node is a linear chain of nodes "
+                "and thus none of the encountered nodes should have multiple "
+                "parents"
+            )
+
+    if isinstance(parent, NodeProto) and parent.op_type == "Softmax":
+        # the parent is a Softmax node, assume this is a "value" MatMul
+        return True
+
+    # no parent is a Softmax node
+    return False
+
+
+def _find_key_matmul_from_value_matmul(
+    value_matmul: NodeProto,
+    graph: ONNXGraph,
+    value_matmul_names: Set[str],
+) -> Optional[NodeProto]:
+    # Perform a BFS up the model DAG from the "value" MatMul until
+    # we find the corresponding "key" MatMul.
+    # The "key" MatMul is assumed to be the first non-parameterized
+    # MatMul we reach during the search.
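+    # (Illustratively: starting from the V x Softmax(Q x K^T) MatMul, the
+    # search walks up through the Softmax and any surrounding elementwise
+    # ops until it reaches the non-parameterized MatMul computing Q x K^T.)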
+    # We return None if no such matmul is found, or there is an indication that
+    # we have traversed outside the self attention module (found another
+    # "value" MatMul)
+
+    seen_node_names = {value_matmul.name}
+    node_queue = [value_matmul]
+
+    while node_queue:
+        current_node = node_queue.pop(0)
+        node_parents = graph.get_node_parents(current_node)
+
+        if (
+            is_matmul(current_node)
+            and (current_node.name != value_matmul.name)
+            and not _is_parameterized_matmul(current_node, graph)
+        ):
+            # treat root node as regular, non MatMul node
+            if current_node.name in value_matmul_names:
+                _LOGGER.info(
+                    f"First MatMul node found for value matmul {value_matmul.name} "
+                    f"was another value matmul {current_node.name}",
+                )
+                return None
+            else:
+                # Success case -
+                # the first MatMul found is non-parameterized
+                return current_node
+
+        for parent in node_parents:
+            if not isinstance(parent, NodeProto):
+                continue
+            if parent.name not in seen_node_names:
+                seen_node_names.add(parent.name)
+                node_queue.append(parent)
+
+    # No MatMul matched before bottoming out
+    _LOGGER.info(
+        f"No key matmul found for value matmul {value_matmul.name}",
+    )
+    return None
+
+
+def _value_input_idx(value_matmul: NodeProto, model: ModelProto) -> int:
+    graph = ONNXGraph(model)
+    # get the index of the value MatMul input that is not fed by the Softmax
+    if len(value_matmul.input) != 2:
+        raise ValueError(
+            f"Expected value matmul to have 2 inputs, got {len(value_matmul.input)}"
+        )
+    softmax_input_idx = 0  # default to softmax being on left hand side
+    for idx, parent in enumerate(graph.get_node_parents(value_matmul)):
+        if isinstance(parent, NodeProto) and parent.op_type == "Softmax":
+            softmax_input_idx = idx
+            break
+    return 1 - softmax_input_idx  # return the index that isn't the softmax
+
+
+def is_matmul(node: NodeProto) -> bool:
+    # matches against FP32 or INT8 MatMul types
+    return node.op_type in ["MatMul", "MatMulInteger", "Gemm", "QLinearMatMul"]
+
+
+def _is_parameterized_matmul(node: NodeProto, graph: ONNXGraph) -> bool:
+    # returns True if any matrix input to the node is a parameter
+    # (initializer) of the graph
+
+    # QLinearMatMul has the A,B matrices in different indices
+    matrix_indices = (0, 1) if node.op_type != "QLinearMatMul" else (0, 3)
+
+    for idx in matrix_indices:
+        if graph.get_init_by_name(node.input[idx]):
+            return True  # matrix input is a model weight
+    return False
+
+
+def _use_uint8_if_quantized(graph: ONNXGraph) -> bool:
+    use_uint8_if_quantized = True  # default to True
+    quantize_nodes = [node for node in graph.nodes if node.op_type == "QuantizeLinear"]
+    if quantize_nodes:
+        zero_point_example = graph.get_init_by_name(quantize_nodes[0].input[2])
+        if zero_point_example and zero_point_example.data_type == TensorProto.INT8:
+            # quantize node exists and has an INT8 zero point
+            use_uint8_if_quantized = False
+    return use_uint8_if_quantized
+
+
+def _set_attention_mask_to_dynamic(model: ModelProto) -> ModelProto:
+    # set the attention mask to be of the dynamic shape
+    attention_mask_inputs = [
+        (idx, graph_input)
+        for idx, graph_input in enumerate(model.graph.input)
+        if graph_input.name == "attention_mask"
+    ]
+    if not attention_mask_inputs:
+        raise ValueError("Could not find `attention_mask` input in model")
+    if len(attention_mask_inputs) > 1:
+        raise ValueError(
+            "Found multiple `attention_mask` inputs in model, expected only one"
+        )
+
+    input_idx, _ = attention_mask_inputs[0]
+    model.graph.input[input_idx].type.tensor_type.shape.dim[
+        1
+    ].dim_param = "past_sequence_len + 1"
+    return model
diff --git a/src/sparseml/exporters/transforms/kv_cache/configs.py
b/src/sparseml/exporters/transforms/kv_cache/configs.py new file mode 100644 index 00000000000..9d3643041c3 --- /dev/null +++ b/src/sparseml/exporters/transforms/kv_cache/configs.py @@ -0,0 +1,167 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +from pydantic import BaseModel, Field + +from sparseml.exporters.transforms.kv_cache.positions_adjustment_codegen import ( + PositionsAdjustmentCodeGen, +) +from sparseml.exporters.transforms.kv_cache.positions_adjustment_opt import ( + PositionsAdjustmentOPT, +) + + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["get_kv_cache_config"] + + +class KeyValueCacheConfig(BaseModel): + model_name: str = Field( + description="The name of the model type. This should correspond to " + "the `model_type` field in the transformer's `config.json` file." + ) + positions_adjustment_transform: Any = Field( + description="The class to use to transform the positional embeddings. " + "This should be a subclass of `PositionsAdjustmentBase`. Note: In the " + "future, when we encounter models that are more complex than just " + "editing the positions in the model, we can make this transformation more " + "general." + ) + key_num_attention_heads: str = Field( + description="The key to use to get the number of attention heads from the " + "transformer's `config.json` file." + ) + key_num_embedding_hidden_size: str = Field( + description="The key to use to get the hidden size " + "from the transformer's `config.json` file." + ) + num_attention_heads: Optional[int] = Field( + description="The number of attention heads." + ) + hidden_size_kv_cache: Optional[int] = Field( + description="The hidden size of the key/value cache. " + ) + multiply_batch_by_num_att_heads: bool = Field( + default=False, + description="Whether or not to internally multiply " + "the batch size by the number of attention heads. " + "This is used to reduce the number of dimensions in " + "the key/value cache.", + ) + transpose_value_input: Optional[Tuple[int, int, int, int]] = Field( + default=None, + description="The transpose indices to apply to the value of " + "the kv cache. If this is not provided, no transpose will " + "be applied.", + ) + transpose_key_input: Optional[Tuple[int, int, int, int]] = Field( + default=None, + description="The transpose indices to apply to the key of " + "the kv cache. 
If this is not provided, no transpose will "
+        "be applied.",
+    )
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+OPT_CONFIG = KeyValueCacheConfig(
+    model_name="opt",
+    positions_adjustment_transform=PositionsAdjustmentOPT,
+    key_num_attention_heads="num_attention_heads",
+    key_num_embedding_hidden_size="hidden_size",
+    transpose_value_input=None,
+    transpose_key_input=None,
+    multiply_batch_by_num_att_heads=True,
+)
+
+CODEGEN_CONFIG = KeyValueCacheConfig(
+    model_name="codegen",
+    positions_adjustment_transform=PositionsAdjustmentCodeGen,
+    key_num_attention_heads="n_head",
+    key_num_embedding_hidden_size="n_embd",
+    transpose_value_input=(0, 2, 1, 3),
+    transpose_key_input=None,
+    multiply_batch_by_num_att_heads=False,
+)
+
+
+def get_kv_cache_config(
+    model_path: str, supported_configs: List[BaseModel] = [OPT_CONFIG, CODEGEN_CONFIG]
+) -> Optional[KeyValueCacheConfig]:
+    """
+    Get the kv cache config for the model at the given path.
+
+    :param model_path: The path to the directory containing
+        the transformers model. It is assumed that
+        the `config.json` file (as supplied by the
+        transformers models) is in this directory.
+    :param supported_configs: The list of supported configs.
+        If the model type is not in this list,
+        a warning will be logged and None will be returned.
+    :return: The kv cache config for the model, or None if the
+        model type is not supported.
+    """
+    transformers_config = _get_transformers_config(model_path)
+    model_name = transformers_config["model_type"]
+
+    kv_cache_config = [
+        kv_cache_config
+        for kv_cache_config in supported_configs
+        if kv_cache_config.model_name == model_name
+    ]
+    if len(kv_cache_config) == 0:
+        _LOGGER.warning(
+            f"Could not find a kv cache config for model type: {model_name}."
+        )
+        return None
+
+    kv_cache_config = kv_cache_config[0]
+
+    # set the number of attention heads and the hidden size of the kv cache
+    num_attention_heads = transformers_config.get(
+        kv_cache_config.key_num_attention_heads
+    )
+    hidden_size_kv_cache = (
+        transformers_config.get(kv_cache_config.key_num_embedding_hidden_size)
+        // num_attention_heads
+    )
+    kv_cache_config.num_attention_heads = num_attention_heads
+    kv_cache_config.hidden_size_kv_cache = hidden_size_kv_cache
+
+    _LOGGER.info("Properly configured arguments for KV Cache Transformation")
+    return kv_cache_config
+
+
+def _get_transformers_config(model_path: Union[str, Path]) -> Dict[str, Any]:
+    # from the model path, get the config.json file and return it as a dict.
+    model_path = Path(model_path) if isinstance(model_path, str) else model_path
+
+    if not model_path.is_dir():
+        raise ValueError(
+            f"`model_path` is expected to be a directory, found {model_path}"
+        )
+    config_files = [
+        file for file in model_path.iterdir() if file.name == "config.json"
+    ]
+    if not config_files:
+        raise ValueError(f"Could not find a `config.json` file in {model_path}")
+    config_file = config_files[0]
+
+    with open(config_file) as f:
+        config = json.load(f)
+    _LOGGER.info(f"Loaded config file {config_file} for model: {config['model_type']}")
+    return config
diff --git a/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_base.py b/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_base.py
new file mode 100644
index 00000000000..77072fdac2d
--- /dev/null
+++ b/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_base.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+
+from onnx import ModelProto
+
+from sparseml.exporters.transforms.onnx_transform import OnnxTransform
+
+
+class PositionsAdjustmentBase(OnnxTransform):
+
+    POSITIONS_NAME = "positions"  # matches intermediate var name in torch
+
+    @classmethod
+    def add_positions_input(cls, model: ModelProto) -> ModelProto:
+        """
+        Adds positions as an input to the model
+
+        :param model: model to update
+        :return: updated model
+        """
+        # positions tensor should have shape equal to input_ids
+        input_ids_infos = [
+            input_info
+            for input_info in model.graph.input
+            if input_info.name == "input_ids"
+        ]
+        if not input_ids_infos:
+            raise RuntimeError(
+                f"{cls.__name__} - unable to find 'input_ids' in model input"
+            )
+        input_ids_info = input_ids_infos[0]
+
+        positions_input_info = deepcopy(input_ids_info)
+        positions_input_info.name = cls.POSITIONS_NAME
+        model.graph.input.append(positions_input_info)
+        return model
diff --git a/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_codegen.py b/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_codegen.py
new file mode 100644
index 00000000000..6027b3dd270
--- /dev/null
+++ b/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_codegen.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from onnx import ModelProto, NodeProto
+
+from sparseml.exporters.transforms.kv_cache.positions_adjustment_base import (
+    PositionsAdjustmentBase,
+)
+from sparseml.exporters.transforms.utils.matching import get_structural_matches
+from sparseml.onnx.utils import ONNXGraph, find_orphaned_nodes
+
+
+__all__ = ["PositionsAdjustmentCodeGen"]
+
+
+class PositionsAdjustmentCodeGen(PositionsAdjustmentBase):
+
+    # The pattern that matches the node that creates
+    # the `position_ids` tensor
+    POSITION_IDS_MATCHING_PATTERN = dict(op_type="Range")
+
+    def transform(self, model: ModelProto) -> ModelProto:
+        """
+        1. Adds `positions` as an input to the model
+        2. Finds the node that initially creates the `position_ids` tensor
+        3.
Updates the node to use the positions input instead of + computing it from the Range op + + :param model: model to update + :return: updated model + """ + model = self.add_positions_input(model) + position_ids_node = self.find_position_ids_range_node(model) + model = self._update_position_embeddings_for_graph_input( + model, position_ids_node + ) + return model + + def find_position_ids_range_node(self, model: ModelProto) -> NodeProto: + """ + Find the node that creates the `position_ids` tensor + :param model: the ONNX model + :return: the node that creates the `position_ids` tensor + """ + graph = ONNXGraph(model) + position_ids_node = get_structural_matches( + graph, **self.POSITION_IDS_MATCHING_PATTERN + ) + if len(position_ids_node) != 1: + raise ValueError( + f"Expected to find 1 position node, found {len(position_ids_node)}" + ) + return position_ids_node[0].node + + def _update_position_embeddings_for_graph_input( + self, model: ModelProto, position_embeddings_ids_node: NodeProto + ) -> ModelProto: + + graph = ONNXGraph(model) + node = position_embeddings_ids_node + child_node = graph.get_node_children(node)[0] + # child_node is the `Unsqueeze` node + assert ( + child_node.op_type == "Unsqueeze" + ), f"Expected to find `Unsqueeze` node, found {child_node.op_type}" + output_to_replace = node.output[0] + self.log_match(child_node) + for idx, input_name in enumerate(child_node.input): + if input_name == output_to_replace: + graph.update_node_input(child_node, self.POSITIONS_NAME, idx) + + orphaned_nodes = find_orphaned_nodes(model, node) + [self.log_match(node) for node in orphaned_nodes] + graph.delete_nodes(orphaned_nodes) + graph.update() + graph.delete_unused_initializers() + + return model diff --git a/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_opt.py b/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_opt.py new file mode 100644 index 00000000000..c35fb7efca8 --- /dev/null +++ b/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_opt.py @@ -0,0 +1,133 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from onnx import ModelProto, NodeProto + +from sparseml.exporters.transforms.kv_cache.positions_adjustment_base import ( + PositionsAdjustmentBase, +) +from sparseml.onnx.utils import ONNXGraph, find_orphaned_nodes + + +__all__ = ["PositionsAdjustmentOPT"] + + +# name position embeddings weights +_EMBED_POSITIONS_ID = "model.decoder.embed_positions.weight" + + +class PositionsAdjustmentOPT(PositionsAdjustmentBase): + """ + Base class for model architecture specific transforms to adjust graph + to take input_id positions as an argument rather than computing them + based on input. This provides a better source of truth rather than + computing in a static graph where factors such as number of tokens, + cache size, and padding may affect accurate, efficient static + computation of the position indices. 
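+
+    (For the positions input defined in the next paragraph, a hypothetical
+    illustration: input_ids of shape [1, 4] whose first two tokens are
+    padding would take the positions input [[0, 0, 0, 1]].)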
+
+    Positions should be the same shape as input_ids where each value
+    is the corresponding integer position of the token id in the overall
+    sequence. Padding tokens are not counted towards positions and
+    should be inputted as 0.
+
+    When running a model for a single input id with `n` previously
+    processed tokens (prompt seq len + number of tokens generated already),
+    the positions input should contain the single value `n`.
+
+    This transform will replace the input to the position embeddings gather
+    by an explicit onnx graph input. It will also delete any operations
+    used to compute the positions that are now no longer used in the
+    graph, and optionally keeps an offset `Add` node that is unique to the
+    OPT graph.
+
+    Transforms
+    ```
+    |  Graph computed positions
+    |            |
+    |      Add (Optional)
+    |            |
+    |  Gather(model.decoder.embed_positions.weight)
+    ```
+    Into
+    ```
+    |  Explicit Graph input  (deletes now orphaned nodes to compute positions)
+    |            |
+    |      Add (Optional)
+    |            |
+    |  Gather(model.decoder.embed_positions.weight)
+    ```
+
+    """
+
+    def transform(self, model: ModelProto) -> ModelProto:
+        model = self.add_positions_input(model)
+        position_embeddings_node = self.find_embed_positions_gather_node(model)
+        model = self._update_position_embeddings_for_graph_input(
+            model, position_embeddings_node
+        )
+        return model
+
+    @classmethod
+    def find_embed_positions_gather_node(cls, model: ModelProto) -> NodeProto:
+        for node in model.graph.node:
+            if node.op_type != "Gather":
+                continue
+            if node.input[0] == _EMBED_POSITIONS_ID:
+                # found the embed_positions_gather_node
+                return node
+        raise RuntimeError(
+            f"Unable to find position embeddings gather node with id "
+            f"{_EMBED_POSITIONS_ID} in {cls.__name__}"
+        )
+
+    def _update_position_embeddings_for_graph_input(
+        self, model: ModelProto, position_embeddings_node: NodeProto
+    ) -> ModelProto:
+        graph = ONNXGraph(model)
+
+        # select target node to update as positions input
+        position_embeddings_parent = graph.get_node_single_parent(
+            position_embeddings_node, index=1
+        )
+
+        if not isinstance(position_embeddings_parent, NodeProto):
+            raise RuntimeError(
+                f"Unable to find input to position embeddings node: "
+                f"{position_embeddings_node.name} as a node in the given model"
+            )
+
+        if position_embeddings_parent.op_type == "Add":
+            # OPT has a special Add offset for position ids, allow this
+            # to be where positions are fed instead
+            target_update_node = position_embeddings_parent
+            target_input_idx = 0  # assume positions are first input to the Add
+        else:
+            target_update_node = position_embeddings_node
+            target_input_idx = 1  # gather idxs
+
+        # reroute target node input to the positions graph input
+        old_positions_input = target_update_node.input[target_input_idx]
+        target_update_node.input[target_input_idx] = self.POSITIONS_NAME
+        graph.update()
+        self.log_match(target_update_node)
+
+        nodes_to_delete = find_orphaned_nodes(
+            model, graph.get_node_by_output_id(old_positions_input)
+        )
+        [self.log_match(node) for node in nodes_to_delete]
+
+        graph.delete_nodes(nodes_to_delete)
+        graph.update()
+        graph.delete_unused_initializers()
+
+        return model
diff --git a/src/sparseml/exporters/transforms/matmul_add_to_matmulinteger_add_cast_mul.py b/src/sparseml/exporters/transforms/matmul_add_to_matmulinteger_add_cast_mul.py
index 06835336c55..2cef33a6226 100644
--- a/src/sparseml/exporters/transforms/matmul_add_to_matmulinteger_add_cast_mul.py
+++ b/src/sparseml/exporters/transforms/matmul_add_to_matmulinteger_add_cast_mul.py
@@ -35,6 +35,8 @@ class MatMulAddToMatMulIntegerAddCastMul(OnnxTransform):
     A transform for converting
a MatMul with kernel and bias into a quantized representation + If add or bias initializer does not exist, the bias is skipped + ``` | weight (initializer) | | @@ -44,9 +46,9 @@ class MatMulAddToMatMulIntegerAddCastMul(OnnxTransform): | | | | Q/Dq optional Transpose | | | - | MatMul bias (initializer) + | MatMul bias (initializer) (optional) | | | - | Add + | Add (optional) ``` (where `Q` is QuantizeLinear, and `Dq` is DequantizeLinear) into @@ -55,7 +57,7 @@ class MatMulAddToMatMulIntegerAddCastMul(OnnxTransform): | | | MatMulInteger (with constant uint8 kernel) | | - | Add (constant bias + zero point correction) + | Add (constant bias + zero point correction) (optional) | | | Cast (INT32 -> FP32) | | @@ -78,16 +80,18 @@ def transform(self, model: ModelProto) -> ModelProto: optional_node("Transpose"), ], ], - children_ops=[["Add"]], + children_ops=[[optional_node("Add")]], ) for match in matches: - # NOTE: bias could be either input 0 or 1 of add node - bias_init = graph.get_init_by_name(match.children[0][0].input[1]) - if bias_init is None: - bias_init = graph.get_init_by_name(match.children[0][0].input[0]) + add_node = match.children[0][0] + bias_init = None + if add_node: + # NOTE: bias could be either input 0 or 1 of add node + # if add does not have a bias initializer, + # still do conversion, but do not fold the bias add to rescale + bias_init = graph.get_init_by_name(match.children[0][0].input[1]) if bias_init is None: - # bias initializer for add not present - continue + bias_init = graph.get_init_by_name(match.children[0][0].input[0]) self.log_match(match) self._transform_match(graph, model, match, bias_init) return model @@ -121,8 +125,8 @@ def _transform_match( input_quantize_params=input_quantize_params, weight_quantize_params=weight_quantize_params, bias_initializer=bias_init, - bias_add_name=add.name, - target_output=add.output[0], + bias_add_name=add.name if add else None, + target_output=add.output[0] if add and bias_init else None, transpose_weight=opt_transpose is not None, ) @@ -134,4 +138,6 @@ def _transform_match( if len(graph.get_node_children(input_quant)) == 1: self.delete_node_deferred(input_quant) self.delete_node_deferred(matmul) - self.delete_node_deferred(add) + if bias_init is not None: + # add converted to quantized - delete previous add node + self.delete_node_deferred(add) diff --git a/src/sparseml/exporters/transforms/matmul_to_matmulinteger_cast_mul.py b/src/sparseml/exporters/transforms/matmul_to_matmulinteger_cast_mul.py index 1b6c8983360..73684cf0105 100644 --- a/src/sparseml/exporters/transforms/matmul_to_matmulinteger_cast_mul.py +++ b/src/sparseml/exporters/transforms/matmul_to_matmulinteger_cast_mul.py @@ -76,9 +76,12 @@ def transform(self, model: ModelProto) -> ModelProto: op_type="MatMul", ) for match in matches: + is_parameterized = False for quantize_linear_parent in [match.parents[0][0], match.parents[1][0]]: if graph.get_init_by_name(quantize_linear_parent.input[0]): - continue + is_parameterized = True + if is_parameterized: + continue self.log_match(match) self._do_transform(model, match) return model diff --git a/src/sparseml/exporters/transforms/utils/add_quantized_conv_matmul_add_ops.py b/src/sparseml/exporters/transforms/utils/add_quantized_conv_matmul_add_ops.py index aafa5ddbfb4..a626db04ddb 100644 --- a/src/sparseml/exporters/transforms/utils/add_quantized_conv_matmul_add_ops.py +++ b/src/sparseml/exporters/transforms/utils/add_quantized_conv_matmul_add_ops.py @@ -35,7 +35,7 @@ def add_quantized_conv_matmul_add_ops( 
weight_quantize_node: NodeProto, input_quantize_params: QuantizationParams, weight_quantize_params: QuantizationParams, - bias_initializer: TensorProto, + bias_initializer: Optional[TensorProto], bias_add_name: str, target_output: str, transpose_weight: bool, @@ -49,6 +49,13 @@ def add_quantized_conv_matmul_add_ops( Adds new quantized ops to graph, does not perform any checks or deletions (should be called by the operator main conversion function) """ + node_output_orig = node.output[0] + if not target_output and ( + any(output.name == node_output_orig for output in model.graph.output) + ): + # original node output is a graph output, make that the quant block + # output target id + target_output = node_output_orig # Quantize weights and add to graph quantized_weight_initializer = _quantize_weight_initializer( @@ -65,30 +72,43 @@ def add_quantized_conv_matmul_add_ops( ) model.graph.node.append(integer_op_node) - # Add bias + zero point correction; quantize bias and add it to graph - ( - quantized_bias_initializer, - quantized_bias_scale, - quantize_bias_zero_point, - ) = _quantize_bias( - node, - bias_initializer, - input_quantize_params, - weight_quantize_params, - bias_add_name, - ) - model.graph.initializer.append(quantized_bias_initializer) - model.graph.initializer.append(quantized_bias_scale) - model.graph.initializer.append(quantize_bias_zero_point) + if bias_initializer is not None: + # Add bias + zero point correction; quantize bias and add it to graph + ( + quantized_bias_initializer, + quantized_bias_scale, + quantize_bias_zero_point, + ) = _quantize_bias( + node, + bias_initializer, + input_quantize_params, + weight_quantize_params, + bias_add_name, + ) + model.graph.initializer.append(quantized_bias_initializer) + model.graph.initializer.append(quantized_bias_scale) + model.graph.initializer.append(quantize_bias_zero_point) + + # Create Quantized Bias Add node and add it to graph + qadd_node = _create_qadd_node( + node, + integer_op_output="{}_quant".format(node.output[0]), + quantized_bias_name=quantized_bias_initializer.name, + output_quantize_node=output_quantize_node, + ) + model.graph.node.append(qadd_node) - # Create Quantized Bias Add node and add it to graph - qadd_node = _create_qadd_node( - node, - integer_op_output="{}_quant".format(node.output[0]), - quantized_bias_name=quantized_bias_initializer.name, - output_quantize_node=output_quantize_node, - ) - model.graph.node.append(qadd_node) + # bias has same scale as future rescale op + rescale_scale = quantized_bias_scale + mul_input_node_name = qadd_node.name + else: + rescale_scale = _create_rescale_init( + node, input_quantize_params, weight_quantize_params + ) + model.graph.initializer.append(rescale_scale) + # cast node should come directly after quantize op output instead of add + output_quantize_node = output_quantize_node or integer_op_node + mul_input_node_name = output_quantize_node.name # create Cast node and add it to graph cast_node = _create_cast_node( @@ -100,9 +120,11 @@ def add_quantized_conv_matmul_add_ops( # create Mul node for rescale mul_node = _create_mul_node( cast_node_output=cast_node.output[0], - quantized_bias_scale_name=quantized_bias_scale.name, - quant_add_name=qadd_node.name, + rescale_scale_name=rescale_scale.name, + input_node_name=mul_input_node_name, target_output=target_output, + model=model, + node_output_orig=node_output_orig, ) model.graph.node.append(mul_node) @@ -111,15 +133,22 @@ def add_quantized_conv_matmul_add_ops( def _create_mul_node( cast_node_output: str, - 
quantized_bias_scale_name: str, - quant_add_name: str, + rescale_scale_name: str, + input_node_name: str, target_output: str, + model: ModelProto, + node_output_orig: str, ) -> NodeProto: mul_node_inputs = [ cast_node_output, # a - quantized_bias_scale_name, # b -> rescale factor + rescale_scale_name, # b -> rescale factor ] - mul_node_name = "{}_rescale_mul".format(quant_add_name) + mul_node_name = "{}_rescale_mul".format(input_node_name) + if target_output is None: + target_output = mul_node_name + # since we skip the add conversion, + # update model to point all outputs of matmul/conv to the rescale mul + _update_model_input_id(model, node_output_orig, target_output) mul_node = onnx.helper.make_node( "Mul", mul_node_inputs, @@ -129,6 +158,13 @@ def _create_mul_node( return mul_node +def _update_model_input_id(model: ModelProto, old_id: str, new_id: str): + for node in model.graph.node: + for idx, input_name in enumerate(node.input): + if input_name == old_id: + node.input[idx] = new_id + + def _create_cast_node( quant_add_name: str, output_quantize_node: Optional[NodeProto] = None ) -> NodeProto: @@ -253,6 +289,15 @@ def _quantize_bias( ) +def _create_rescale_init( + node, input_quantize_params, weight_quantize_params +) -> TensorProto: + output_scale = input_quantize_params.scale * weight_quantize_params.scale + return numpy_helper.from_array( + numpy.asarray(output_scale), name=f"{node.name}_quant.rescale.scale" + ) + + def _quantize_weight_initializer( node: NodeProto, weight_quantize_params: QuantizationParams, diff --git a/src/sparseml/onnx/utils/graph_editor.py b/src/sparseml/onnx/utils/graph_editor.py index e2887f17bc4..72cd832b5b7 100644 --- a/src/sparseml/onnx/utils/graph_editor.py +++ b/src/sparseml/onnx/utils/graph_editor.py @@ -24,7 +24,7 @@ from onnx import ModelProto, NodeProto, TensorProto, numpy_helper from toposort import toposort_flatten -from sparseml.onnx.utils.helpers import get_node_params +from sparseml.onnx.utils.helpers import get_node_output_nodes, get_node_params __all__ = [ @@ -36,6 +36,7 @@ "prune_unstructured", "prune_model_one_shot", "prune_model_one_shot_iter", + "find_orphaned_nodes", ] @@ -195,10 +196,10 @@ def delete_nodes(self, nodes: List[NodeProto]): deletes the given nodes from the graph :param nodes: list of nodes to delete """ - node_ouptut_ids_to_delete = {node.output[0] for node in nodes} + node_output_ids_to_delete = {node.output[0] for node in nodes} nodes_to_keep = [] for node in self._model.graph.node: - if node.output[0] in node_ouptut_ids_to_delete: + if node.output[0] in node_output_ids_to_delete: self._delete_node_edges(node) else: nodes_to_keep.append(node) @@ -450,3 +451,58 @@ def prune_model_one_shot_iter( pruned_weight_val = prune_unstructured(weight.val, sparsity) update_model_param(model, weight.name, pruned_weight_val) yield (index + 1) / len(nodes) + + +def find_orphaned_nodes(model: ModelProto, node: NodeProto) -> List[NodeProto]: + """ + Given a node, that is to be removed from the graph, find all nodes that + will be orphaned as a result of the removal. Orphaned nodes are nodes + that will have no inputs after the removal of the given node. + The method traverses the graph upwards from the given node until + a node with multiple outputs is found. All nodes that are traversed + are considered orphaned and will be removed. 
+
+    :param model: The model that the node belongs to
+    :param node: The node to remove
+    :return: A list of the orphaned nodes
+    """
+    graph = ONNXGraph(model)
+    nodes_to_delete = [node]
+    # start the queue with the node that is being removed
+    queue = [node]
+    while queue:
+        current_node = queue.pop(0)
+        if not isinstance(current_node, NodeProto):
+            continue
+        node_parents = graph.get_node_parents(current_node)
+        # if node parent has only one output (current child)
+        # then it is orphaned and will be removed.
+        # continue traversing the graph upwards until
+        # a node with output that is not current child is found
+        for parent in node_parents:
+
+            if not isinstance(parent, NodeProto):
+                # if parent is not a node, it is a graph input
+                # and should not be removed
+                continue
+            elif parent.op_type == "Constant":
+                # if constant node is found,
+                # automatically remove it and continue traversing
+                nodes_to_delete.append(parent)
+
+            parent_output_node_names = set(
+                n.name for n in get_node_output_nodes(model=model, node=parent)
+            )
+            if len(parent_output_node_names) == 1:
+                # if parent has only one output, it is orphaned
+                queue.append(parent)
+                nodes_to_delete.append(parent)
+            elif not parent_output_node_names.difference(
+                set(n.name for n in nodes_to_delete)
+            ):
+                # if parent has multiple outputs, but they are all already in the
+                # nodes_to_delete list, it is orphaned
+                queue.append(parent)
+                nodes_to_delete.append(parent)
+
+    return nodes_to_delete
diff --git a/src/sparseml/pytorch/base.py b/src/sparseml/pytorch/base.py
index bfacaaab6f4..a2dafde673a 100644
--- a/src/sparseml/pytorch/base.py
+++ b/src/sparseml/pytorch/base.py
@@ -49,7 +49,7 @@
 
 
 _TORCH_MIN_VERSION = "1.0.0"
-_TORCH_MAX_VERSION = "1.13.100"  # set bug to 100 to support all future 1.9.X versions
+_TORCH_MAX_VERSION = "2.0.100"
 
 
 def check_torch_install(
diff --git a/src/sparseml/transformers/__init__.py b/src/sparseml/transformers/__init__.py
index 83cb82e9245..d96053fc1e3 100644
--- a/src/sparseml/transformers/__init__.py
+++ b/src/sparseml/transformers/__init__.py
@@ -48,6 +48,6 @@ def _check_transformers_install():
         )
 
 
-_check_transformers_install()
+# _check_transformers_install()
 
 from .export import *
diff --git a/tests/sparseml/exporters/transforms/test_matmul_add_to_matmulinteger_add_cast_mul.py b/tests/sparseml/exporters/transforms/test_matmul_add_to_matmulinteger_add_cast_mul.py
index 103626a4bca..f8f32e5d3b8 100644
--- a/tests/sparseml/exporters/transforms/test_matmul_add_to_matmulinteger_add_cast_mul.py
+++ b/tests/sparseml/exporters/transforms/test_matmul_add_to_matmulinteger_add_cast_mul.py
@@ -146,7 +146,7 @@ def test_without_transpose(onnx_model: onnx.ModelProto):
     ]
 
 
-def test_no_bias_changes_nothing(onnx_model: onnx.ModelProto):
+def test_matmul_no_bias_converts(onnx_model: onnx.ModelProto):
     # remove "bias" initializer and "add" node
     assert onnx_model.graph.initializer.pop().name == "bias"
     assert onnx_model.graph.node.pop().name == "add"
@@ -154,17 +154,14 @@ def test_no_bias_changes_nothing(onnx_model: onnx.ModelProto):
     onnx_model = MatMulAddToMatMulIntegerAddCastMul().apply(onnx_model)
     validate_onnx(onnx_model)
 
-    # NOTE: nothing changes
+    # converted model should have matmulinteger + rescale mul without bias add
     assert [i.name for i in onnx_model.graph.initializer] == [
-        "x_scale",
-        "y_scale",
         "zero_point",
-        "weight",
+        "matmul.weight_quantized",
+        "matmul_quant.rescale.scale",
     ]
     assert [n.name for n in onnx_model.graph.node] == [
-        "input_dequant",
-        "weight_quant",
-        "weight_dequant",
-        "transpose",
"transpose", - "matmul", + "matmul_quant", + "matmul_bias_add_quant_cast", + "matmul_quant_rescale_mul", ] diff --git a/tests/sparseml/pytorch/model.onnx b/tests/sparseml/pytorch/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9d2051df3e4979da6e97c02b06350df34507749c GIT binary patch literal 23435 zcmce-c_5Wt*FS#9l&F&;Ns>xJrb;?zuhWT=kc5OvBodM$Dbz7!o-$8y%+jPua`rm+ zxf>)&QfV$tn%$b*D%JOTruTln&-=c=?|;AJAJ?9*b?v>@XRYbuRzXKQe{snJf~&71tg6i5G6vSG7#sEWEV zQ(*k}CdO7w!}VKtgfTtWn`n&rj~X4Bzt!ujD308?ImkO~^VZ-{z7$_dC?&72q$s;- z`_>&H>dce0zN(^9pzroz-=OuO{@x+J{1N;Se`_>WQ6|KDv(2@CDwPL%J;{Vp;FS(fgH@TP@ z{fAu4{#UvDqvrp2x%|h1{*sH?|0b7zRQyveoBk5Ze{@jMHV)akHOP2r(AEv!L5sqD zw|j5$b!1BZ67_`trG?Je|JF&-Qt|&ep2`?w=fCZdzb26TM}dy)-^&%PR8;;Z{?IV* z;IO|Mjr^}N9fg0@>Z>cthI z|22>NKTCAv{!ybpN>Ro;VsofEi`nOj)+#DngM%YDoK?Z0+jjW+?)F`8YObO9uLZ5Q z__xJr{A+Qe6{P-ic}|M|ZDz&)VPutY#(z!pU#C>@uR@)X|5dNJOlACk$!NXlUvD(( z-%9^KmiOzQzrRUujpF|_vEu(YJ|jcs^?khm?OVzJvqVSkA2t8C3_bqedxVkBe~hUz zmf7fk*DZg&n8Lq`bmae8rseV*E^NDZaH!?qm7*d? z@qZp%MeZM;QiJ{XNkTVoGIrj)Y0K8l8+GLW`H#xTzpME50h`10A1f=c6-NB+k1b{S z-?Y1}MET-MDe;E-T*y5=2R!Nq!PfpaSpG|j7q`-c7ks1(R00!md;Ce5@4txG_4zQ) zzcho_uUiI2OOt?AcbK!M`xjJ2wczmj#b7@#8q62QLcqyWV12p=>dT(upAQsyA&C>Q z86)2PMhqd*f8dw&nLJ0AwV=(J zC*G3J7MrqX^ZdIJb4^FVz>P8DlK3BZzIBFJ^{qbdWzHg=?dCgh-b9+GwYCdIkymjI zdoC}zRD;*+8jF`o@=>#2Pu#WUJ9Sya2buZacxlyqvFh_#ypY(rAQn9ZO}{AQzhVPx zyQ=8Vn~QmpuTx=BIiI)O%7Q1;yA0PT%J2$|*gV;eQ^+~{17fOVd5`yx;4PS`!TTkc z%{v_C3H5d_N%Q)RY$aP=`(aAxwXpIrv8+;3co?KYA0$aZMI zgIB%Bh#xuEL-iVEn0v+ty3H-c`cZAr(3J}pr%dGOFze_TwGGxDR>xQIhGH*;2pk}x zRJ}w(i(h_&Z%qZLy;fVyQaVKkg4SRNdoK05lO^h|xq%19zrrb*Q+ch^lffi!G3<{q z607CbfY5geK8}5j+4m2_O1TQO36r4SK~wDh@fcW}X5z^VOCH;7Gz`8Hz#m{L?%io6 z)=3nDG=H*qy`q%(#G=XKH_~QeL1%eOcWx}Ku{9L4%uk38&*9;=B^5Z=aRL^*bpY$S z6W+Ob2^-b+;OkEu-rVQPxZ&<)INYEr<~66lbiGlyPJcfdu383-+xF0X8V^xN(OCTX zJsY|Tx1r=}F43A+4|6@Hh=rl4qK}3*z#--{oaj0RNz+eK`MU$S=4v@gxQnQWHIAn= zYk~OFr*3q9YtHNHTL{g!3sAjBLW8T~;X&7Q5F}TNtdGP)X2g84&Z^h2&vyaO z)@8i-MDR~&oDxP9f<~k6)|(jaD=jwY)Z|GV&v9p^YVufWZ(0WTyvD3!eA*^I5;8_6 zVc|=F>l__kRSS<-RVODd-TV`wY7#FbdNuvhH=4JtVLs2K>i`ulO%nws%;M20T=A=$ zv6$tl2q_A0;pl=tVFt|Q-Md;1?5}4@LVrAX{dh-=3^c`x`|jhLgIU0q@`G{(5zebx z%qt7M0O3ncKw7x~M>t(X{`El;6=;C^lP8ELMqR@FlSc7OLoIoEo1S9(?pA2Ib`RA1 zY@skD9VOL2NXDh#!O`^rZ5Qf@7nrE=WCM+Oc5!ogn%Cu!r83OnO)bJRk#p7g67hjm^0<3EW zblle|yv4HHptbHguy%Y90*=bqp{xB6n>dZ5q}Mz z#EZHRiWy5{;H>>P-mp>uW;;JYg}r~$W5?9RKa5m(ruQcEl)f857n+IHi+`dt~9GO(u$;pI!{( z^ScTT z{Feh@E69@SiN8Zz4@^Ws!8h)kvq9jr#0}n!uf*<;bD@hKCX$*_$Qtu(fo9ea)%zxg z_p2gt^j&$_KGqn@4HIxKyN0ehT?x&(+Aw6|j{Lo^xvZ047aZ<-LsI7Zz-~hg2>Fta z64hkR3X5zwwl)cxjy-duxgRj>$yc zAO-SsCu5i0BU-XF8@irUie4Pb0F$B?(eTvAYa!uTGj03gf;o!O$Qofp(m!Q^YGnfC+z7@iaVe;Mokt^DX4CVBMj=P6 z0FqrDR8y{+-00Low+q&E!ymgKUp^EYqQfESts^~B5DUG>b4cTo5imcg7L?Xzp~}Zl zV2yC5Ha^Q~(7|5P-ZL63^yG97@JFWrJfxf;yfF<0vrkY#gcR3rPBgq+91CjVdK8!_bL1Ad zqtRR^BxU({#;gFz;$&oJe58WindqOl7<+Z4DeLfB8h0rSY4&Hzn)3Cmr2j`tomM7n zY-lBE8{4T~pDVe@N(B9KUnq6HPW)V-(XK;1)M7aZ!pJQAY*0tuPt1e*!;08zAwZ*?I_x!`0PRU0 zaDH1R2%KiNj5NyCEC*)aTSfQDX;fo}IrAoKn;i8rjI&J$zc+pI8ALQVOB;dh|^c)V?Wm&LpI0(jfy8uA$jNsO2BXXox6C%1tMoP4(sdfsGFMWxOJ4Znn#=;c){=t4)XRjw;RI=Fk4KjBER@b% z2}9HUX_b;SI(uxv1gRSaKyBp$NSi!F z`ETbE^ArV0C{==xtJyf^YkB>{yLdBWx9bbm0rgf7Gr>f~d$93+2WgmHZMio-uc;LfxSx`aJ zL3TnFta(=enR!JJd($2g#GByU)D&E>I~}|{&?0d*T2$Ilw{x@T@s2{sXqUw&9s8hP z%?3`Y6@%j^S&TMI0lkqMQAeN;*$M%mIV>SM^W#vE+(OkkrgZA@Yd5g zLF_dxK+gm>Odz{}s~CtMXKaR4c@1P~msqgX%5l=TeB5s14GOd8k)wb50$!{J)@Liu z(Kmj0dvZP;jnsssMn+dI?j(11O@%X!nRxhhHoojA2gSl_R1t55pRfV+HjV@Z^Fq8} zG(;q3S*RJ?OL-GUAeC)Ex=TO{9Yd+fh^wUQK?4!I^=r{P;DTxHO|(Kc9jeJ^(&A+Y 
zrIV0)bOgXn_hNE+L@8W2o`{2idodyI3awXH!M*7P;K8ZG@uxnM5ZA}tnK`+jSG1V2 zRZnx(SDB#o`MD^ud(H{04#CmOxFFY-hlLtrF*+(8Dhs#6h&L0^|A+_nhooRvO)EG1 zg98ja(ShQvveW;r^4FdL8;%}QuL15WVS?f-4 z!c=~chmQi0McTe ze3kwfBY-;(t*OMbnxn8lf!LGhr1!xi8c`cf{C+z_^?U+}gVh{#Kb8YZ4!y)eH-Knw za1-rG=L2g&32if1z?;uoh}7N;xK|beGOAVdN-+y}n|Z=S#~NsRWKQ+XN5QL{Q1bai zJP|0*;oe*tPtDb( zBFV&4bWBPXiicJ}eB@;kHY)~;KjdI`z7I$&FLQMq?Z8WZKmD?*3U+BWfYBlY98r>p za%T)cPgITD8T^~DZaxI8p998iMPSL;fX2z!Y4+avFm`S+Xde!veW|6m<9Zf&n#$q^ zU46J}UI~8J765xk3EB^(QA__CG*Lf8-!OiLglr*#^NKC9w`*XaM-AMU)Pg9}5W}qA zk)auD={(0W7@qr_@a2m*tcIPUje|qPY@+~Yd6t0!^(3xY%y+cNg!X-rU^Ghx&0cMT zQA;afta<@>w)Bavph;ss&5sYk9e2~A`11s~vOovFWqJT#X(iff zxq!*e8ECoK1D-Bg336$_Xk*$p?zB~-QTTWj-Fh___x!dFg#kX?X*Wh-3K)Up&vwzL z7kjZ~?rdn99g5e5h0wiE5A4mWY1Ef)Dp6yj`sIma8aDxUsHTJZI0O2i=qbrKWC-08 z7+vI_jFNjMRH-4GmU!~9b($4CtT_OzPqw15-C3ahrvWV*Uj%YjZ_}=E-$em|9#9(c zkqjJ36A{T9BKxQe-k#fnos-nTZ_+Q?)fA4|GRr|qvjV)%8eoS0G|1-&ahGxf^nTRD zR}NYv;d&uv?<0_EodzCT*Tda&%-I~Owc0bg6 z%|hLKD?#Aw($de-Lf6)2I?r_l&_T%n=w6Y#-s z0}cKVjqHOdqO_bd^o~dw%{nX~dVUh)J6{()?`JUN_$qGts083=zutP!!V;X|oRSpdPd`2iQl@a5U1!j!x&qo9HJ+@S zRt#@GS-_{%Ik<9mEVOpKp+loB2<`tuSuY1fcXs>Yqvf;k^5!gv@QC#MPPhV|>`WL1&m}3u#_6I>>w$g~}0 zXz%rz+FRDqo>|j@GioNNePR5Qjkaii_daoKK1pT65*u8WnoO)JD>iN@Xu{aFA-JK7tSp^)I zUG+E+I2xMkR-mg+7=zmu#Deq86qM;qK;OgF z@O)w}9bTMI_0-sKR%0K8*r$*~qcg!Jz7|JkEPzfOTkMot3~RlsaA@QyAp8er!BWL!By0R672aOX@uX4vmSYbC&MM(&_B zc@vJ7o{W0StjNu!$#e!4!lVnaz~>Fn*(rGtV3UMzXJ){w?OsHN?kC&t@FDSP7N)La zaQ5R5bSDvDyn6;}$2!w{c8ft0{DDY{vS~u|deBbtryVY*sOHyHn$V(f)9#nug^ z@2!St$5Q6IW-hQk@`Ku0+TlmW772bW6CE}xr48TKz*|j%pA;9P{y{egfmd{#mm0=B z+l?o}b-_!2BdV9Nq598c`bBp?T$xo1YjXz4;A$s$^>GG<%x&ey%@JbSoC6r5w;Z#Y z@~|WHCym#+O$>($X_dw{tY73#CAUmzn>8COzs-c)$>n5cR2j0_d0dOj%A_=k(a(2B z0Q<)ikz(jFNQ$08`KPNnA$vDr<%0winsd0#=j;eeUwVO8ZW&d1tq;nf$)K@#Axs;^ z#}Rk4F>cQR)S1cT6P`_l?9bj{X0`!(`8{;!sWRAYk%o_gQZPO@mmbZX2&tU)sMq94 z;-|LGZ|R@2T_kVB7Jd(sf+0?Qg68d zNqsFHTDg$&=B!56m5*mlp4_5`Wb9DzZk$NS;v&EIGRO11B6i+(!u?JKQ0b!w+Vj%U zBBCCzwA3?u$(-0J)q?DzT)6X2hE5cwLStbH5#||)EJj6<#qTc>J~=DOO3k4CwQJE| zYA=N>r+$lxo0wFJ$*sr zV#lKNi)~=fYNp3tS3%b-iOAL62x!w&;*l~DHvT+D+FMN^&$$95gK?bqnla$Pp910O zNcW_zg`g3oc+@!_wc6Nl>v|ou|L}(RwRy;z!4ZXA&%ozKHPp>r6$Tc{!$L~|>SY*W z$g7VuVYL~u9xXX*Q9X&#C#A&ja3y(PS^#19M5I?=TIBfgC-=(bNjSTs9+vMu0QFWn zwCaE|9-CPUt^KQ!?bsvAk~uG`uh1akrZMoUb~AaWc9b}-szdFEmuSiN8klBz08;Wz zfvwyxV)18^-6}z_<6R6~nHr1(oIn(RR{^<)nxNg}A>!o+L$fR!GY)OQ>IVl2t9cG- ze3L8c?HWz&m^gu~N(Ss={M%|$LhV>qsBXw6w{7y_;Unv)nHQop6dBIWBBVexcz_|n8>iPQ_U3u2h}jRb38i7JK_=p0elQ7#MX*V zN_|d}f!~TnERTaNg@+t*L1`}Zel$K?FB zLJYoSiE(J$P3vRS6z!%!y+0kd&rl=mif(~k~g)WVBY}>{n z4>rCdVJjPskW>et#-iJg$cD6j|?;!Q<&L|1nNBB+~sn|RKow*fg z`)Uh#p0owQvNxixL1+9pDFk|lHsj~t2Z`l4Ik@tgV6!lk_U(>D=Y10}-(xXO{g4E+ zs$${F%DJdl?tl<nrO)z*^pt z{VNdEef){z&3bO`v>U|k&Q_G{YG5Y+#-rHv!FaT z2KFSVgXF9nR_c|YVb(CMy}uC5xi{(09a8XWiXyF@V2Q2YSAg8o>F7CM4P&4XbXJsO zP}DKvJ2wudvZ_GL+Q#T87B#x+3|R(8xiLO*u;-;aWH9*EV=@OEV;Z>Dc}Hk_iU~wK z98I3)GPb<&9!--@#<0(+2jzgQF@e)*j0l=>qmp*9dle=m<(~|sz@C75j)ouIGlNuj!}tWa%V1NShbx@`jH4a zUePcwGZ{DC;Q}o7_k%~h zj7_dMLr?UCLsrjKGQ{Z7bX_3`higg0o>Yt~YNe8rog`t7GL$Vg0?u&*=p8XkY~r;^ zU!W=4c*LONTOsWovw_J+PRl8G@6<;J4# z!FuTEd`db$RA9mOeDr!7ENZk+hJkB|bi_0z9602T36qv%M7s8Syi_E_rwXDDqr zz6ID1+eDJ@cS#a|Jw35%Ht5-mfwz)%SbnJhHOy>){(Mc?eR+gGw3@{4X~6PtRd7?1 zN%`J)TN2(q@+qnSDf1d+ub@R}ADS*M{4D4Mu10`0{Aow_)s}bZ0 zJ*G&%%PZo8-SO1>Mh!?s)f(oA!@ZHz1)aJD=vgVg^rMl`t`Bqzlbhm zq|<7p9@4F+4}!1b$hQVdaOh7)BY_(zT|Y(k9A%;Xt#gFGcr0#MqyTS%XG1;TmyXMG z$0)x)NQT@vygn}u@*~@*!iL$Dw`2`|+S`DA6F*Vc$(q>XFGM~0Wkhmz9Xh5yrdnD4 zm^S+|ReE1V_-UuO6>?0BHRGG;%t-+%dE{fr;2~nNsZrD(=LW?s^1u$M5Z${H3WaJb 
z$ko6y<_xR@#dpCV$M{9O`EnrRH9)?t@`P+@H#n~;i(UiwDeGNR^XU)e@Nj(z2#1v??$U6e{`Ka1=#9iz*EBp7yeO-ciwEGB)W+BZ=8(& zHu9)6<0E-~A%|)X$5O?|MsRs|8uZ8(dy%J{_Y514fy2aE=$AA0LVag8jH~ z@j(#U%HgTa17wZcT2y(#gP{pLWPPYw(E7pv1Sh>(c=Af1m%fH*r0c-TrW`bSvJ~D) zXW`!8V<5F_FW#6t6;1RT=z!8#5M)O&c^d<$EOmfbS7%^d?BqDF^hdv+JF)pz2(~F` zpizeh8eH6l15rs>@*xI8zP#mzO#8&G(XWFq*HS?uxZBbfD~q~{G8prjiBqz7&^=@4 zz`*QwQJLCu9As(EO%g4){dg=RTbTz9t|nqRD?pEgavPwx5XG#mkgX;BcIvK1^2j3+HQrC9pD zo`|Ngv5~eA!LoTRQhiwvG*d(;d`pC&-wu$c-s7OKU^cz-J`LQA6+wURMCfcPM6qTy z=DB8p%~5&Ub@DA~I~qZc?MOn+hefn>NeOj)yAT6*)G&Dk6H)LBd#u_Mg&nIe(C@Pd zT5T_9e8mj#i~CMfbqaw0>M?b=HVd}a=zD~MROq@9(|h% z@H6#K51~ID6p^ZYq|QrLVU1%ROxtUO$6n+kr_TkRMiMa3wPG;lJkUPB5XIF}*xEl4 ze!4Mni5g$g*Qx^edEOW$A6?KfRB$DLTAvO;)#wtC)Q0kQ28)Be zA^h?RXqu}GJ)GM#Td@#r=Bbf+X*nQj8HMb46HIDafv+5GY441DdT35QvT}jzbzYw< zI37;=GEPz7l|{f_Uyb9G51`V40lLG2iB;BJyLdIUb zow^;%ycJMiWgC=uuff5?mJt47HS}eLf%xh^41Y2k>#gL8%4-MU@-0vTwxmRR1I8cQ zk5A;7n76$O+MmuM>P7mrd~q^N9ZrNc&w1E(DV&-bFkB|fC)7)RBZk;Fi42#%CfR|@ z!9UEOZg$LuNqUnot@bieOZ-I(XBksnXEx3;D!^*Vel#+7f{^_~+@<%jU~_Cfuy%>L zSp|P`9k+?N>-OivyTa>a*_k@n@W=*AgNn$Y(Rv(_u@k%MvdG)1X?W5m8cQNdP?#Y_ zBoCfbVf}G(t_)!D4Ke8*)j>C<<$>BBV_fW+K=@KQoTky&>4P&Lh+67J+Ull?{MJ8O zUTKdfpN!XI_;+USUr55QlT)F%P#N0Yg2A`G00IMxpndi-;IF<-@9>vX@9|mCo~6XZ z6qe9YclM$EFA+WDp^OGjd3bY09TBWuLRtRTsoY6DJSWYZzx_EVcw#4dl_Eg(awTH% zQVq2)PbPw@Ib7lFWn9;#JK-deV|;)z!uOLW_w1IyQuhQfvyOnXE^#p2n@h$=-Y0_} z@6exjccAX*4tm}w7K@!$g3Yoqbob_5v@cA?3dVQH%B-jT22(H+0GEDDfX3eU}zNZuoWb%cKm8ZRKIeWjfVpS_1-CX|m_C0VoE}2DiaA7#ynsmu51#=@v(- z8+(!3UpPj~63x)$IETTUj17BPNT+}3CallbMe{FZf#02K^r|{UPKA9Ub!0zYP>Uhf zcWppdK^uqOI}xAtJFzu;CP@50wdBsN!p}{~)Mo1{(vdBppDK2u)xu1y-pgQ}xNBSq z7IU_lmcw(G2E60qgIb4n0{?=L4tly^WR@FB9nOXfi7N)RG5(&(cuabrN({1RV9(** z&@jdnlHJq6Nz()52ea|YL_b_Q@&NL#u`p|OJRKicO8QRr)0CoZV5Z^<5q`VqLWU>j z;+KtUeanf!_9^%61v3;B$%wo+WkZDK5;}AHKKQ*XnZY+RxZS%~f!6tae3c{zJ)SF| zu00#|9&)Jr)Dj2|V1vu{0vd8N0Ovo9hb|K-hMW5Sgx-MYM;egWRVbXpyf*e0G@O`J!dmJ2Hdv;`Fi6x|+*kV$AgDvP^n2IY?lOQ2^6gDo);oflAf+gWu_)Xdjoa-{6 z(sCMXYt90RR-I^WaUD&GF$8hRE}Ynw2d}Hj;Y74CRQ?f*IS&$GUabFj<~wToM++xsfiesG$wdTh8ne>28seM;@F>;u7v zb5zaB8lPU~!l%FG8Pe9E#Zn>qn9j#nXGhYkaRr3G;|fQ2!Z>hysgK{&cH-F)abP94 z9rY<2-dSFtywxh0;58Aw^g77R_)KCVVc|xLHljU#y2x#nGKe2KqW8}tNHFw*2QTWV z(>XtAU$7JGZx@h(Y*`3TEX2|ewfI;g4_zLnQj7DOi00fF6c=frP;G!lJ(LikT@aF4 zZ|Uy+)flpIB5*n5v2Bqh-E}GtJg3YBb?QXe&R0ZwpBB-$a#i#gYX;_#V*zZkK=Z~I zn%$%iHv`)ku4_F}k@kdu1-{VufH}8!Dj>;aA+jEXw#ZoKV6`K|t5N)u*zL@~2I)yq zFpz^+mM_QY;zZ&#Est38_CnB+Oj4P@0Pas`aPyL%)Qp6JRdF7gm^`Hd=>U--{|0Rf zF<|gEPt-Gc6KpSwfD?i}z`pX8t2IpqO^35!aAgs)$6pnBJ%3LW>U^lSuPTNn&BTvU zMxeB+7Wr#qNYBKnz|XRxnT{p6=3W&_Gx7DlDQn=pVI7#SS^#f(*^u`<4_L8xS{~=c zLcD!HVI3`_jY=(C*R*g@n63()Dk1QKiG`V~kHodZ`KT^QL$CJ5v}egW(0=)d3{Nkn zA)|UlpW@<>1@8DwG!{0r1jECo7;s#eNY7m~fcL%ofSb1f{iQb%&nb)0QV<5NHv7?W zpp|;PnI>}47m#O6y!_Ge2o(O|D>Cuw;r3pgiM^iJ=#L^Mci{hw3SX|Hy|GC+T|b35 zoiT=27q#GnVkRuu+Q4v`7>q{80nbSh_@5i;o~xN?WoCto7U#he^;&4GzbBHId~9J? 
z=8L)(^wMI}%~0834ngA>-~Y)Z&Z@xza1PMH?sWms_S_WRo}1B}pnZ5=E(g@7X_Ckr zvr*k!j#}ojLCh9n)#)0z|1c50Y)OJ$uVa8LE+jUXh6C}JXuEk5ZhBG#a~>GtBj=IG z5=3#1vgA;Aw+}o^T8LgwGPJMcC!La*4ANn8aGw(gCbmu}7;}yYvc|W#1}s2fnGE`X?2pxf1QI!6L2jJt%Kbiw`sM z;8766AzVL5v~#vm^{uK@4Q|j{yP2@Xtpr$P6p2h;hlgU)@Y_Z|(A(ohGR|6JmS|;0L&EllbP}WM)i=4*%k`;&OGe(F)uz77X z_W8|(caKhx{t8XF(XtNNYcA3JJ#r}RZU6=qdMI#7r&q*fC`p(x z4O%^lgM`5KFnvNLu66FBowb!1p)7^+%DFK84~Dz2={AjJ$6&&}9Vi;G#rzYKpvS`> zbkPdbTW5m#ip5})u$Qnj*KwZ)lpss%BrU2)fxt~O(D}eSuqf-w zVXE6J0}9VK6ZX+!u7GEZy+wT_gzJbK%*|nFQ#kgQtcI5>8gRZz0<2fe0P}crd>XM6 z=$mX}Ll$G&jbEg9s{=RO`2jWSRff(YhF3Sn3m$L6@s;%6YoUUOOR_rK zMeL7PL<8#?Ojnu7aTv|T1mhgs!?K3T6S>IRszg>>*WrPfIvC1L#lqShC^AJHA7(?QU zc45Y=Ik4#k!?iz@M*R#mp=*7s=;CNEx_&QX&d7tw>gj zIxJ?#Q=a!SuvM4>{4>QZ;~0CJ@UfHfS!pdJ*XMyvS0eV-t8!VtOt@!`Rm0KE9?;Wj z36`H{L4QjKzU__2tcRz#RaeWPzEzr-cxi!7p)Wj}y#{qtW}$752jrgUAr{)}$v{`5 zNZZs4`Aalue~T^(Chs8qhYgr>G6o8xedzqng}~mOMkPvFct^vHi8-maw8Yn;dZZ26 zG^zp2d!umh%r>-%TZC>WG#I`-!v*+b8vHUVhMR{D(Ej=gY`wG-{C{qtj;`t0-sl28 zO#ZSVZ3}FO9|MDByV3unHpA(TM89`0sFlAd3aUGZT--Y39SerVVuq`8>L6I|Tndig zUyFJ+`h(>3W7_MnkV=aDiNX*DT?a1e@`SYCY1u)`kW}@#dgZl zvlk79G}DV?DxgX1joJrvxayY#xXLdbt}CTOLUS>G&b>iiDQm#?VRtD1Rf%gfC!(CG zJ$@D^Qw76Zq^(Jm|Dl0Y^Ue_NlNn4*O&W5RYof#ZsiC*rrlouKv67`@nYL`g4~kb%ev5ZzV^ zWtTm`C(Q))7e#{kB^$iC?Ka`7e{WGY3MRg`)sSSw?3wXS*tBj0nl!x^c})o?10zLT zmhcUADNbVU^<1SV`Z@6Eb|of0PK6CxMj+H>)9IbZ$$-=&GFUMm#8w&jyHyp0Jls!s z<*HDdBnu-xr6cRQ9Mvjgct;8^x!E(EA%+};$hnhH{3HiyY#EsxQV#s5UQ}4)kA{j5 z>G&}Zh)_641XFFeSBf`d=A=5fs8I$2hYwUKyoSu_;^XK+MUd<`AUZLB4w$)C;Lt!W z+G@&!4bKG6M5mzoq}4RYHwC|%7l49|JMDa#g1MbZX!m3!nrXU0R?auBYb(QP=BVMv zXhTrg$Z%|a*r4~v0?3}~2bNcz;O4z#VjUlj(xc}BKUv%&nI7MwAL0lpx8)%u^(r~z zTLuy1X2L+H9_%i02m7o_GSK%}B<ZC8#v1jb5_e zk87sw$9{1!YIQqf+Rh9*u0J$uxAqa+)63xsLw9M<^Mc;IXnOf;2^2US zz%vYIsJKTFlDrM+#L*0ITgeIkk7*c_kMWv3jq*J8;KWyNNW1=y@Um+d z9j%UWs>LArC54Vp_L1B>*9rgZbuN`ECTqw3q84(ds8%EcL(k+P>iY>IDJkY|kn#qj zbBge}yqFIBu7W+a0*qKOn?8zig7?cSpsUv%)hnEl<#&w}HRlzHAM>4DO3sDm2W!!B z$0@F8+a4Uy`$?R?D4=%h2I}#~6V%9Zu#e26@1C_2M<$1DZMqa#w{1~->1^6D_7X9E zQUL7yuVk&l09DzkWj&0a==pD(KTY@*&HIUE-bsRTM z8(l>8q+8!2Q+@)!#7d(=1+hU-n^bdOU)*@&wG9f_?gTzE9g_4bb5C_kt z*gT*Lty?Xiy5}jwt99gNS1*R%bi(lA&NjbdaKVtAH1&D55<`^;3Y$Y{{l(oRx`oO6 zUZ{kEoC6q;rVXibLFjdyPg#ewY18k&Q}bV5sM~J@&LeFx=oY0)3+hpML;?&5hDH3H zf|l{5h{VG!GUulQmV7ryt8q(_wPiW+T6Ba;q@EF0)?U%GM|r5Oxe|x;&8Zq~qY-Py zkdF2BF~F)Qk2&EUx5 zVif565$*6Bgf;0hSN>5k+T;e&gc-{Ch`Gl_=A5K{c0Z^TbEj4K)R|Pw+Xuyoi(r#q zEL8U;n-QC1~@I7<)sTQ|K8NysV4>0*2x$InPdUs!=G=0eypVHO-3$M8FcOo7fX##ihGn3R@) z$0%c8V@IG3m*JQj{vhMb>p&%BJ9HelMBBCnqU7lp(a75c5M)+OKl|iT=?SBO*Ea_5 zY|*CQep>_q5;jtUIm6g1jSgp~MAz+d->WNX;Lt3qqa+GyGGLq7@RUavs^UlvrzVu><43-DX?1Q52# z5yeP5NSLw*9Ng`p@mn_$ecXVJ^|~mm^AlC1R{-C$le{^<5KivQhW%>}!j;nrXt`<$ zXbf+KrWP*bJAR^}GZNwHF^2c%*dXc%yg^Tn$iP;Ih1fH03iKRV4u^$T=qD?E)Z0Io z#9vA#IWrt_*`+vW+&IitdRvPQ4ZhI(ES$@>ujA&=8i^;WXTh`@4U`Qz2u9x)qpaCM zsLUM)-N*F6{QVkielVAW7HeZp))Xw3XZ(qyn@O*79hIDsN9`F$Y3`iYr11VMYB#zH z&)1mzpDNBg9Ln|o<6})DHQACS6WL3`5$66pDJr3|v{U4$q?KeT)kImcBnk;l$X4N` zBbod29BDb)r9wC-rzmMt;uI~v=ePXv{rxfjJ=a`wUC&(i{aIeGH}2|Nhj%iZaOZ<3 z#PfM2S>x>to-;C0^3?-{R(jw%Do4%)=Yf{chaSK6mGJCRY1B+be9%^ewd-4{uKHDa zFaIEPP3)sq(f7z*{Y2U?E5VR&O9Ae2Tz(Q)gPr6{T@|;I`xPp%bMUEDJ7_k%xqJdY z2a3t$ciE7_lo6p&R&;-4FY-l8;N@r?eg0-YZ13>L?z(Z(@fTvCjP)nK>YspJTkP=e zwpP*vD+wLQp(M4OnjSEsUp8(f;>$|J`=AC~n06GJCzp$Q#+)MquO3R{B<}daVhk=B z7l6N&#KVh-NW1~LP(86%G~KC<7T)mWFH*fR0t9~dFTnfL^e9lAJRQ7oCua4!tj;pq&bccI~En9lX(gq@UIem`D#!3 z2c=WIcoogI%7sR9Q%?UXAVlyTlk)>G|7DA9h1{%jYKyB}K zxVJe0oT@EQyCx7EZlyv=`zzYeu#MOhkHPNN-O}Rx9J(NH4+e~8q0E{D_}Z5OWkH45 
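The `find_orphaned_nodes` helper added above is an upward breadth-first traversal: starting from the node being removed, a parent is marked for deletion once every consumer of its outputs is already marked. The same shape, reduced to plain dictionaries (a hypothetical sketch, not the SparseML API):

```python
from collections import deque

def find_orphaned(parents, children, root):
    """Upward BFS from `root`; a parent is orphaned once all of its
    consumers are already scheduled for deletion."""
    to_delete = [root]
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for parent in parents.get(node, ()):
            consumers = set(children.get(parent, ()))
            if parent not in to_delete and consumers <= set(to_delete):
                to_delete.append(parent)
                queue.append(parent)
    return to_delete

# toy DAG: a feeds b and c, both feed d; removing d orphans the whole chain
parents = {"d": ["b", "c"], "b": ["a"], "c": ["a"]}
children = {"a": ["b", "c"], "b": ["d"], "c": ["d"]}
print(find_orphaned(parents, children, "d"))  # ['d', 'b', 'c', 'a']
```

The real implementation works on `NodeProto` objects, skips graph inputs, and removes `Constant` parents eagerly, but the termination rule is the same.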
From e9c67355e3bbf6348de92a3ba7e53727c9db10f9 Mon Sep 17 00:00:00 2001
From: "bogunowicz@arrival.com"
Date: Mon, 19 Jun 2023 13:36:39 +0200
Subject: [PATCH 2/5] minor cleanup after the merge

---
 src/sparseml/transformers/__init__.py |   2 +-
 tests/sparseml/pytorch/model.onnx     | Bin 23435 -> 0 bytes
 2 files changed, 1 insertion(+), 1 deletion(-)
 delete mode 100644 tests/sparseml/pytorch/model.onnx

diff --git a/src/sparseml/transformers/__init__.py b/src/sparseml/transformers/__init__.py
index d96053fc1e3..83cb82e9245 100644
--- a/src/sparseml/transformers/__init__.py
+++ b/src/sparseml/transformers/__init__.py
@@ -48,6 +48,6 @@ def _check_transformers_install():
         )
 
 
-# _check_transformers_install()
+_check_transformers_install()
 
 from .export import *
diff --git a/tests/sparseml/pytorch/model.onnx b/tests/sparseml/pytorch/model.onnx
deleted file mode 100644
index 9d2051df3e4979da6e97c02b06350df34507749c..0000000000000000000000000000000000000000
GIT binary patch
[binary payload (23435 bytes, base85-encoded) elided]
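The transform deleted in the next patch spells out the contract for the `positions` input: it has the same shape as `input_ids`, holds the integer position of each real token, and maps padding tokens to 0. One plausible way to derive such an input from an attention mask (an illustration under those assumptions; this helper is not part of the PR):

```python
import numpy as np

def make_positions(attention_mask: np.ndarray) -> np.ndarray:
    """Integer position of every unmasked token; padding maps to 0."""
    mask = attention_mask.astype(np.int64)
    return (np.cumsum(mask, axis=-1) - 1) * mask

print(make_positions(np.array([[0, 0, 1, 1, 1]])))  # [[0 0 0 1 2]]
print(make_positions(np.array([[1, 1, 1, 0, 0]])))  # [[0 1 2 0 0]]
```

Feeding positions explicitly sidesteps the static-graph issues the docstring below calls out: token count, cache size, and padding no longer have to be inferred inside the ONNX graph.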
From 5e071e4fac2f7782483f90186bff4b790a8d4589 Mon Sep 17 00:00:00 2001
From: "bogunowicz@arrival.com"
Date: Mon, 19 Jun 2023 15:07:05 +0200
Subject: [PATCH 3/5] remove the old file for opt embeddings adjustment

---
 .../position_embeddings_adjustment.py         | 179 ------------------
 1 file changed, 179 deletions(-)
 delete mode 100644 src/sparseml/exporters/transforms/kv_cache/position_embeddings_adjustment.py

diff --git a/src/sparseml/exporters/transforms/kv_cache/position_embeddings_adjustment.py b/src/sparseml/exporters/transforms/kv_cache/position_embeddings_adjustment.py
deleted file mode 100644
index bb38462d1e2..00000000000
--- a/src/sparseml/exporters/transforms/kv_cache/position_embeddings_adjustment.py
+++ /dev/null
@@ -1,179 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from copy import deepcopy
-
-from onnx import ModelProto, NodeProto
-
-from sparseml.exporters.transforms.onnx_transform import OnnxTransform
-from sparseml.onnx.utils import ONNXGraph
-
-
-__all__ = ["PositionEmbeddingsAdjustment"]
-
-
-# name position embeddings weights
-_EMBED_POSITIONS_ID = "model.decoder.embed_positions.weight"
-
-
-class PositionEmbeddingsAdjustment(OnnxTransform):
-    """
-    Base class for model architecture specific transforms to adjust graph
-    to take input_id positions as an argument rather than computing them
-    based on input. This provides a better source of truth rather than
-    computing in a static graph where factors such as number of tokens,
-    cache size, and padding may affect accurate, efficient static
-    computation of the position indices.
-
-    Positions should be the same shape as input_ids where each value
-    is the corresponding integer position of the token id in the overall
-    sequence. Padding tokens are not counted towards positions and
-    should be inputted as 0.
-
-    When running a model for a single input id with `n` previously
-    processed tokens (prompt seq len + number of tokens generated already)
-
-    This transform will replace the input to the position embeddings gather
-    by an explicit onnx graph input. Will delete any operations
-    used to compute the positions that are now no longer used in the
-    graph. Optionally keeps an offset `Add` node that is unique to the
-    OPT graph.
- - Transforms - ``` - | Graph computed positions - | | - | Add (Optional) - | | - | Gather(model.decoder.embed_positions.weight) - ``` - Into - ``` - | Explicit Graph input (deletes now orphaned nodes to compute positions) - | | - | Add (Optional) - | | - | Gather(model.decoder.embed_positions.weight) - ``` - - """ - - POSITIONS_NAME = "positions" # matches intermediate var name in torch - - def transform(self, model: ModelProto) -> ModelProto: - model = self.add_positions_input(model) - position_embeddings_node = self.find_embed_positions_gather_node(model) - model = self._update_position_embeddings_for_graph_input( - model, position_embeddings_node - ) - return model - - @classmethod - def add_positions_input(cls, model: ModelProto) -> ModelProto: - """ - Adds positions as an input to the model - - :param model: model to update - :return: updated model - """ - # positions tensor should have shape equal to input_ids - input_ids_info = [ - input_info - for input_info in model.graph.input - if input_info.name == "input_ids" - ][0] - if not input_ids_info: - raise RuntimeError( - f"{cls.__name__} - unable to find 'input_ids' in model input" - ) - - positions_input_info = deepcopy(input_ids_info) - positions_input_info.name = cls.POSITIONS_NAME - model.graph.input.append(positions_input_info) - return model - - @classmethod - def find_embed_positions_gather_node(cls, model: ModelProto) -> NodeProto: - for node in model.graph.node: - if node.op_type != "Gather": - continue - if node.input[0] == _EMBED_POSITIONS_ID: - # found the embed_positions_gather_node - return node - raise RuntimeError( - f"Unable to find position embeddings gather node with id " - f"{_EMBED_POSITIONS_ID} in {cls.__name__}" - ) - - def _update_position_embeddings_for_graph_input( - self, model: ModelProto, position_embeddings_node: NodeProto - ) -> ModelProto: - graph = ONNXGraph(model) - - # select target node to update as positions input - position_embeddings_parent = graph.get_node_single_parent( - position_embeddings_node, index=1 - ) - - if not isinstance(position_embeddings_parent, NodeProto): - raise RuntimeError( - f"Unable to find input to position embeddings node: " - f"{position_embeddings_node.name} as a node in the given model" - ) - - if position_embeddings_parent.op_type == "Add": - # OPT has a special Add offset for position ids, allow this - # to be where positions are fed instead - target_update_node = position_embeddings_parent - target_input_idx = 0 # assume positions are first input to the Add - else: - target_update_node = position_embeddings_node - target_input_idx = 1 # gather idxs - - # reroute target node input to the positions graph input - old_positions_input = target_update_node.input[target_input_idx] - target_update_node.input[target_input_idx] = self.POSITIONS_NAME - graph.update() - self.log_match(target_update_node) - - # traverse graph upwards to delete any nodes that are now orphaned - nodes_to_delete = [] - # start queue with previous positions input node - queue = [graph.get_node_by_output_id(old_positions_input)] - seen_node_names = {queue[0].name} - while queue: - current_node = queue.pop(0) - if not isinstance(current_node, NodeProto): - continue - - node_children = graph.get_node_children(current_node) - if any(child not in nodes_to_delete for child in node_children): - # node has a child that is not on the orphaned branch - # do not remove and do not traverse further - continue - else: - nodes_to_delete.append(current_node) - self.log_match(current_node) - for parent in 
graph.get_node_parents(current_node): - if isinstance(parent, NodeProto) and ( - parent.name not in seen_node_names - ): - seen_node_names.add(parent.name) - queue.append(parent) - - graph.delete_nodes(nodes_to_delete) - graph.update() - graph.delete_unused_initializers() - - return model From b271da77b71a5f9e0b7bb514d6e3f0cfd568fda7 Mon Sep 17 00:00:00 2001 From: "bogunowicz@arrival.com" Date: Tue, 20 Jun 2023 13:39:29 +0200 Subject: [PATCH 4/5] fold into the ONNXGraph & apply fixes after testing injection with kwargs --- src/sparseml/exporters/kv_cache_injector.py | 3 +- .../kv_cache/positions_adjustment_codegen.py | 4 +- .../kv_cache/positions_adjustment_opt.py | 6 +- src/sparseml/onnx/utils/graph_editor.py | 110 +++++++++--------- 4 files changed, 59 insertions(+), 64 deletions(-) diff --git a/src/sparseml/exporters/kv_cache_injector.py b/src/sparseml/exporters/kv_cache_injector.py index e088aeb0e7c..1adf2b14276 100644 --- a/src/sparseml/exporters/kv_cache_injector.py +++ b/src/sparseml/exporters/kv_cache_injector.py @@ -93,7 +93,6 @@ def __init__( if no `model_path` is provided. """ self.inplace = inplace - self.config = get_kv_cache_config(model_path) if model_path is not None: # get the parameters from the config @@ -109,7 +108,7 @@ def __init__( positions_adjustment = self.config.positions_adjustment_transform elif kwargs: - # get the parameters from the kwargs + _LOGGER.info("Configuration for KV cache injection provided via kwargs") num_attention_heads = kwargs.get("num_attention_heads") hidden_size_kv_cache_dim = kwargs.get("hidden_size_kv_cache_dim") multiply_batch_by_num_att_heads = kwargs.get( diff --git a/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_codegen.py b/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_codegen.py index 6027b3dd270..7041a36f2c2 100644 --- a/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_codegen.py +++ b/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_codegen.py @@ -18,7 +18,7 @@ PositionsAdjustmentBase, ) from sparseml.exporters.transforms.utils.matching import get_structural_matches -from sparseml.onnx.utils import ONNXGraph, find_orphaned_nodes +from sparseml.onnx.utils import ONNXGraph __all__ = ["PositionsAdjustmentCodeGen"] @@ -80,7 +80,7 @@ def _update_position_embeddings_for_graph_input( if input_name == output_to_replace: graph.update_node_input(child_node, self.POSITIONS_NAME, idx) - orphaned_nodes = find_orphaned_nodes(model, node) + orphaned_nodes = graph.find_orphaned_nodes(node) [self.log_match(node) for node in orphaned_nodes] graph.delete_nodes(orphaned_nodes) graph.update() diff --git a/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_opt.py b/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_opt.py index c35fb7efca8..25d6ccc2c18 100644 --- a/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_opt.py +++ b/src/sparseml/exporters/transforms/kv_cache/positions_adjustment_opt.py @@ -17,7 +17,7 @@ from sparseml.exporters.transforms.kv_cache.positions_adjustment_base import ( PositionsAdjustmentBase, ) -from sparseml.onnx.utils import ONNXGraph, find_orphaned_nodes +from sparseml.onnx.utils import ONNXGraph __all__ = ["PositionsAdjustmentOPT"] @@ -121,8 +121,8 @@ def _update_position_embeddings_for_graph_input( graph.update() self.log_match(target_update_node) - nodes_to_delete = find_orphaned_nodes( - model, graph.get_node_by_output_id(old_positions_input) + nodes_to_delete = graph.find_orphaned_nodes( + 
graph.get_node_by_output_id(old_positions_input)
         )
 
         [self.log_match(node) for node in nodes_to_delete]
diff --git a/src/sparseml/onnx/utils/graph_editor.py b/src/sparseml/onnx/utils/graph_editor.py
index 72cd832b5b7..1b0c2d23a12 100644
--- a/src/sparseml/onnx/utils/graph_editor.py
+++ b/src/sparseml/onnx/utils/graph_editor.py
@@ -24,7 +24,7 @@
 from onnx import ModelProto, NodeProto, TensorProto, numpy_helper
 from toposort import toposort_flatten
 
-from sparseml.onnx.utils.helpers import get_node_output_nodes, get_node_params
+from sparseml.onnx.utils.helpers import get_node_params
 
 
 __all__ = [
@@ -36,7 +36,6 @@
     "prune_unstructured",
     "prune_model_one_shot",
     "prune_model_one_shot_iter",
-    "find_orphaned_nodes",
 ]
 
 
@@ -243,6 +242,58 @@ def delete_unused_initializers(self):
             ]
         )  # delete inits that have no edge
 
+    def find_orphaned_nodes(self, node: NodeProto) -> List[NodeProto]:
+        """
+        Given a node that is to be removed from the graph, find all nodes that
+        will be orphaned as a result of the removal. Orphaned nodes are nodes
+        that will have no inputs after the removal of the given node.
+        The method traverses the graph upwards from the given node until
+        a node with multiple outputs is found. All nodes that are traversed
+        are considered orphaned and will be removed.
+
+        :param node: The node to remove
+        :return: A list of the orphaned nodes
+        """
+        nodes_to_delete = [node]
+        # start the queue with the node that is being removed
+        queue = [node]
+        while queue:
+            current_node = queue.pop(0)
+            if not isinstance(current_node, NodeProto):
+                continue
+            node_parents = self.get_node_parents(current_node)
+            # if a node's parent has only one consumer (the current child)
+            # then it is orphaned and will be removed.
+            # continue traversing the graph upwards until
+            # a node with a consumer that is not the current child is found
+            for parent in node_parents:
+
+                if not isinstance(parent, NodeProto):
+                    # if parent is not a node, it is a graph input
+                    # and should not be removed
+                    continue
+                elif parent.op_type == "Constant":
+                    # if a constant node is found,
+                    # automatically remove it and continue traversing
+                    nodes_to_delete.append(parent)
+
+                parent_output_node_names = set(
+                    n.name for n in self.get_node_children(parent)
+                )
+                if len(parent_output_node_names) == 1:
+                    # if the parent has only one consumer, it is orphaned
+                    queue.append(parent)
+                    nodes_to_delete.append(parent)
+                elif not parent_output_node_names.difference(
+                    set(n.name for n in nodes_to_delete)
+                ):
+                    # if the parent has multiple consumers, but they are all
+                    # already in the nodes_to_delete list, it is orphaned
+                    queue.append(parent)
+                    nodes_to_delete.append(parent)
+
+        return nodes_to_delete
+
     def sort_nodes_topologically(self):
         """
         Sorts the order of the graph Node repeated field in place in topological
@@ -451,58 +502,3 @@ def prune_model_one_shot_iter(
     pruned_weight_val = prune_unstructured(weight.val, sparsity)
     update_model_param(model, weight.name, pruned_weight_val)
     yield (index + 1) / len(nodes)
-
-
-def find_orphaned_nodes(model: ModelProto, node: NodeProto) -> List[NodeProto]:
-    """
-    Given a node, that is to be removed from the graph, find all nodes that
-    will be orphaned as a result of the removal. Orphaned nodes are nodes
-    that will have no inputs after the removal of the given node.
-    The method traverses the graph upwards from the given node until
-    a node with multiple outputs is found. All nodes that are traversed
-    are considered orphaned and will be removed.
-
-    :param model: The model that the node belongs to
-    :param node: The node to remove
-    :return: A list of the orphaned nodes
-    """
-    graph = ONNXGraph(model)
-    nodes_to_delete = [node]
-    # start the queue with the node that is being removed
-    queue = [node]
-    while queue:
-        current_node = queue.pop(0)
-        if not isinstance(current_node, NodeProto):
-            continue
-        node_parents = graph.get_node_parents(current_node)
-        # if node parent has only one output (current child)
-        # then it is orphaned and will be removed.
-        # continue traversing the graph upwards until
-        # a node with output that is not current child is found
-        for parent in node_parents:
-
-            if not isinstance(parent, NodeProto):
-                # if parent is not a node, it is a graph input
-                # and should not be removed
-                continue
-            elif parent.op_type == "Constant":
-                # if constant node is found,
-                # automatically remove it and continue traversing
-                nodes_to_delete.append(parent)
-
-            parent_output_node_names = set(
-                n.name for n in get_node_output_nodes(model=model, node=parent)
-            )
-            if len(parent_output_node_names) == 1:
-                # if parent has only one output, it is orphaned
-                queue.append(parent)
-                nodes_to_delete.append(parent)
-            elif not parent_output_node_names.difference(
-                set(n.name for n in nodes_to_delete)
-            ):
-                # if parent has multiple outputs, but they are all already in the
-                # nodes_to_delete list, it is orphaned
-                queue.append(parent)
-                nodes_to_delete.append(parent)
-
-    return nodes_to_delete

From 836ee7ebb89caa450146f477c20fbaecdf713f55 Mon Sep 17 00:00:00 2001
From: Damian
Date: Fri, 23 Jun 2023 11:37:25 +0000
Subject: [PATCH 5/5] add _get_transforms_from_kwargs and _get_transforms_from_config

---
 src/sparseml/exporters/kv_cache_injector.py   | 73 ++++++++++---------
 .../exporters/transforms/kv_cache/configs.py  |  2 +-
 2 files changed, 41 insertions(+), 34 deletions(-)

diff --git a/src/sparseml/exporters/kv_cache_injector.py b/src/sparseml/exporters/kv_cache_injector.py
index e088aeb0e7c..7b4a4bb16b1 100644
--- a/src/sparseml/exporters/kv_cache_injector.py
+++ b/src/sparseml/exporters/kv_cache_injector.py
@@ -15,13 +15,15 @@
 import logging
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import onnx
 
 from sparseml.exporters.base_exporter import BaseExporter
+from sparseml.exporters.transforms import OnnxTransform
 from sparseml.exporters.transforms.kv_cache import (
     CacheKeysAndValues,
+    KeyValueCacheConfig,
     get_kv_cache_config,
 )
 from sparsezoo.utils import save_onnx
@@ -93,31 +95,14 @@ def __init__(
         if no `model_path` is provided.
""" self.inplace = inplace - self.config = get_kv_cache_config(model_path) - if model_path is not None: - # get the parameters from the config - self.config = get_kv_cache_config(model_path) + config = get_kv_cache_config(model_path) - num_attention_heads = self.config.num_attention_heads - hidden_size_kv_cache_dim = self.config.hidden_size_kv_cache - multiply_batch_by_num_att_heads = ( - self.config.multiply_batch_by_num_att_heads - ) - transpose_value_input = self.config.transpose_value_input - transpose_key_input = self.config.transpose_key_input - positions_adjustment = self.config.positions_adjustment_transform + if config is not None: + transforms = self._get_transforms_from_config(config) elif kwargs: - # get the parameters from the kwargs - num_attention_heads = kwargs.get("num_attention_heads") - hidden_size_kv_cache_dim = kwargs.get("hidden_size_kv_cache_dim") - multiply_batch_by_num_att_heads = kwargs.get( - "multiply_batch_by_num_att_heads", False - ) - transpose_value_input = kwargs.get("transpose_value_input") - transpose_key_input = kwargs.get("transpose_key_input") - positions_adjustment = None + transforms = self._get_transforms_from_kwargs(kwargs) else: raise ValueError( @@ -125,17 +110,6 @@ def __init__( "KeyValueCacheInjector" ) - transforms = [ - CacheKeysAndValues( - num_attention_heads=num_attention_heads, - hidden_size_kv_cache=hidden_size_kv_cache_dim, - multiply_batch_by_num_att_heads=multiply_batch_by_num_att_heads, - transpose_value_input=transpose_value_input, - transpose_key_input=transpose_key_input, - ) - ] - if positions_adjustment is not None: - transforms += [positions_adjustment()] super().__init__(transforms) def pre_validate(self, model: Union[onnx.ModelProto, str, Path]) -> onnx.ModelProto: @@ -154,3 +128,36 @@ def post_validate(self, model: onnx.ModelProto) -> onnx.ModelProto: def export(self, pre_transforms_model: onnx.ModelProto, file_path: str): post_transforms_model: onnx.ModelProto = self.apply(pre_transforms_model) save_onnx(post_transforms_model, file_path) + + @staticmethod + def _get_transforms_from_config(config: KeyValueCacheConfig) -> List[OnnxTransform]: + positions_adjustment = config.positions_adjustment_transform + + transforms = [ + CacheKeysAndValues( + num_attention_heads=config.num_attention_heads, + hidden_size_kv_cache=config.hidden_size_kv_cache, + multiply_batch_by_num_att_heads=config.multiply_batch_by_num_att_heads, + transpose_value_input=config.transpose_value_input, + transpose_key_input=config.transpose_key_input, + ) + ] + if positions_adjustment is not None: + transforms += [positions_adjustment()] + + return transforms + + @staticmethod + def _get_transforms_from_kwargs(kwargs: Dict[str, Any]) -> List[OnnxTransform]: + transforms = [ + CacheKeysAndValues( + num_attention_heads=kwargs.get("num_attention_heads"), + hidden_size_kv_cache=kwargs.get("hidden_size_kv_cache"), + multiply_batch_by_num_att_heads=kwargs.get( + "multiply_batch_by_num_att_heads", False + ), + transpose_value_input=kwargs.get("transpose_value_input", None), + transpose_key_input=kwargs.get("transpose_key_input", None), + ) + ] + return transforms diff --git a/src/sparseml/exporters/transforms/kv_cache/configs.py b/src/sparseml/exporters/transforms/kv_cache/configs.py index 9d3643041c3..6473d76afe2 100644 --- a/src/sparseml/exporters/transforms/kv_cache/configs.py +++ b/src/sparseml/exporters/transforms/kv_cache/configs.py @@ -29,7 +29,7 @@ _LOGGER = logging.getLogger(__name__) -__all__ = ["get_kv_cache_config"] +__all__ = ["get_kv_cache_config", 
"KeyValueCacheConfig"] class KeyValueCacheConfig(BaseModel):