From 0ae6aa7a61cbca11e2b602c9aac5f4ff967d6406 Mon Sep 17 00:00:00 2001
From: alexm-nm <59768536+alexm-nm@users.noreply.github.com>
Date: Fri, 21 Jul 2023 18:58:22 -0400
Subject: [PATCH] Add kvcache support to debug_analysis.py and engine.py (#1132)

* add kvcache support to debug_analysis.py and engine.py

* remove print

* make defaults for kvcache run

* review comments

* Update src/deepsparse/utils/onnx.py

---------

Co-authored-by: Michael Goin
---
 src/deepsparse/debug_analysis.py |  70 ++++++++++++--
 src/deepsparse/engine.py         | 159 ++++++++++++++++++++-----------
 src/deepsparse/utils/onnx.py     |  20 ++++
 3 files changed, 186 insertions(+), 63 deletions(-)

diff --git a/src/deepsparse/debug_analysis.py b/src/deepsparse/debug_analysis.py
index 062e27ea50..e2786b4f21 100644
--- a/src/deepsparse/debug_analysis.py
+++ b/src/deepsparse/debug_analysis.py
@@ -62,8 +62,9 @@
 import json
 import os
 
-from deepsparse import model_debug_analysis
+from deepsparse import KVCacheParams, model_debug_analysis
 from deepsparse.utils import (
+    default_cached_outputs,
     generate_random_inputs,
     model_to_path,
     override_onnx_input_shapes,
@@ -140,6 +141,29 @@ def parse_args():
         type=str,
         default="",
     )
+    parser.add_argument(
+        "--disable-batch-override",
+        help="Ignores the batch_size parameter",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--use-kvcache", help="Enable KVCache", action="store_true", default=False
+    )
+    parser.add_argument(
+        "--kv-cache-prev-num-tokens",
+        help="KVCache: The amount of previous tokens that will be read"
+        " from the external KV cache on the first inference",
+        type=int,
+        default=None,
+    )
+    parser.add_argument(
+        "--kv-cache-num-frozen-tokens",
+        help="KVCache: The amount of first tokens that we want to keep"
+        " permanently in the KV cache",
+        type=int,
+        default=None,
+    )
     parser.add_argument(
         "-q",
         "--quiet",
@@ -186,14 +210,10 @@ def construct_layer_table(result):
         "{: >#08.4f} | {: >#08.4f} | {: >#08.4f} | {:12}"
     )
     for li in result["layer_info"]:
-        table_str += layer_info_to_string(
-            li,
-            "{:28}| " + info_format_base + "\n",
-        )
+        table_str += layer_info_to_string(li, "{:28}| " + info_format_base + "\n")
         for sub_li in li["sub_layer_info"]:
             table_str += layer_info_to_string(
-                sub_li,
-                " {:26}| " + info_format_base + "\n",
+                sub_li, " {:26}| " + info_format_base + "\n"
             )
 
     table_str += "Total Time(MS): {:05f}\n".format(result["average_total_time"])
@@ -295,11 +315,39 @@ def main():
 
     print("Analyzing model: {}".format(orig_model_path))
 
+    batch_size = args.batch_size
+    if args.disable_batch_override:
+        batch_size = None
+        os.environ["NM_DISABLE_BATCH_OVERRIDE"] = "1"
+        print("Disable batch override: ON")
+
     if input_shapes:
         with override_onnx_input_shapes(model_path, input_shapes) as tmp_path:
-            input_list = generate_random_inputs(tmp_path, args.batch_size)
+            input_list = generate_random_inputs(tmp_path, batch_size)
     else:
-        input_list = generate_random_inputs(model_path, args.batch_size)
+        input_list = generate_random_inputs(model_path, batch_size)
+
+    kv_cache_params = None
+    if args.use_kvcache:
+        kv_cache_prev_num_tokens = 0
+        if args.kv_cache_prev_num_tokens is not None:
+            kv_cache_prev_num_tokens = args.kv_cache_prev_num_tokens
+
+        kv_cache_num_frozen_tokens = 0
+        if args.kv_cache_num_frozen_tokens is not None:
+            kv_cache_num_frozen_tokens = args.kv_cache_num_frozen_tokens
+
+        kv_cache_params = KVCacheParams(
+            default_cached_outputs(model_path),
+            kv_cache_prev_num_tokens,
+            kv_cache_num_frozen_tokens,
+        )
+
+        print(
+            "Enable KVCache: prev_num_tokens = {}, num_frozen_tokens = {}".format(
+                kv_cache_params.prev_num_tokens, kv_cache_params.num_frozen_tokens
+            )
+        )
 
     result = model_debug_analysis(
         model_path,
@@ -308,9 +356,11 @@ def main():
         num_cores=args.num_cores,
         num_iterations=args.num_iterations,
         num_warmup_iterations=args.num_warmup_iterations,
-        optimization_level=args.optimization,
+        optimization_level=int(args.optimization),
+        disable_batch_override=args.disable_batch_override,
         imposed_ks=imposed_kernel_sparsity,
         input_shapes=input_shapes,
+        kv_cache_params=kv_cache_params,
     )
 
     if not args.quiet:
diff --git a/src/deepsparse/engine.py b/src/deepsparse/engine.py
index 0f5160299d..624b6f5cf5 100644
--- a/src/deepsparse/engine.py
+++ b/src/deepsparse/engine.py
@@ -22,6 +22,7 @@
 from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy
+import onnx
 from tqdm.auto import tqdm
 
 from deepsparse.analytics import deepsparse_analytics as _analytics
@@ -56,6 +57,7 @@
     "Context",
     "MultiModelEngine",
     "BaseEngine",
+    "KVCacheParams",
 ]
 
 _LOGGER = logging.getLogger(__name__)
@@ -169,11 +171,7 @@ class Context(object):
         concurrently.
     """
 
-    def __init__(
-        self,
-        num_cores: int = None,
-        num_streams: int = None,
-    ):
+    def __init__(self, num_cores: int = None, num_streams: int = None):
         self._num_cores = _validate_num_cores(num_cores)
         self._scheduler = Scheduler.from_str("elastic")
         self._deepsparse_context = LIB.deepsparse_context(
@@ -205,6 +203,24 @@ def __repr__(self) -> str:
         return f"Context(num_cores={self.num_cores}, num_streams={self.num_streams}, scheduler={self.scheduler})"
 
 
+class KVCacheParams:
+    """
+    :param cached_outputs: A list of bools that indicates for each output
+        whether it is cached or not
+    :param prev_num_tokens: The amount of previous tokens that will be read
+        from the external KV cache on the first inference
+    :param num_frozen_tokens: The amount of first tokens that we want to keep
+        permanently in the KV cache.
+    """
+
+    def __init__(
+        self, cached_outputs: List[bool], prev_num_tokens: int, num_frozen_tokens: int
+    ):
+        self.cached_outputs = cached_outputs
+        self.prev_num_tokens = prev_num_tokens
+        self.num_frozen_tokens = num_frozen_tokens
+
+
 class BaseEngine(object):
     def construct(
         self,
@@ -214,6 +230,8 @@ def construct(
         num_streams: int = None,
         scheduler: Scheduler = None,
         input_shapes: List[List[int]] = None,
+        disable_batch_override: bool = False,
+        kv_cache_params: Optional[KVCacheParams] = None,
     ):
         _analytics.send_event("python__engine__init")
         self._model_path = model_to_path(model)
@@ -222,6 +240,8 @@ def construct(
         self._num_streams = _validate_num_streams(num_streams, self._num_cores)
         self._scheduler = _validate_scheduler(scheduler)
         self._input_shapes = input_shapes
+        self._disable_batch_override = disable_batch_override
+        self._kv_cache_params = kv_cache_params
         self._cpu_avx_type = AVX_TYPE
         self._cpu_vnni = VNNI
 
@@ -231,6 +251,8 @@ def construct_with_context(
         batch_size: int,
         context: Context,
         input_shapes: List[List[int]] = None,
+        disable_batch_override: bool = False,
+        kv_cache_params: Optional[KVCacheParams] = None,
     ):
         _analytics.send_event("python__engine__init")
         self._model_path = model_to_path(model)
@@ -239,6 +261,8 @@ def construct_with_context(
         self._num_streams = context.num_streams
         self._scheduler = _validate_scheduler(context.scheduler)
         self._input_shapes = input_shapes
+        self._disable_batch_override = disable_batch_override
+        self._kv_cache_params = kv_cache_params
         self._cpu_avx_type = AVX_TYPE
         self._cpu_vnni = VNNI
 
@@ -304,9 +328,7 @@ def __init__(
         )
 
     def __call__(
-        self,
-        inp: List[numpy.ndarray],
-        val_inp: bool = True,
+        self, inp: List[numpy.ndarray], val_inp: bool = True
     ) -> List[numpy.ndarray]:
         """
         Convenience function for Engine.run(), see @run for more details
@@ -461,9 +483,7 @@ def batched_run(
         return join_engine_outputs(batch_outputs, orig_batch_size)
 
     def run(
-        self,
-        inp: List[numpy.ndarray],
-        val_inp: bool = True,
+        self, inp: List[numpy.ndarray], val_inp: bool = True
     ) -> List[numpy.ndarray]:
         """
         Run given inputs through the model for inference.
@@ -530,9 +550,7 @@ def timed_run(
         return out, end - start
 
     def mapped_run(
-        self,
-        inp: List[numpy.ndarray],
-        val_inp: bool = True,
+        self, inp: List[numpy.ndarray], val_inp: bool = True
     ) -> Dict[str, numpy.ndarray]:
         """
         Run given inputs through the model for inference.
@@ -681,13 +699,14 @@ def _validate_inputs(self, inp: List[numpy.ndarray]):
             raise ValueError("inp must be a list, given {}".format(type(inp)))
 
         for arr in inp:
-            if arr.shape[0] != self._batch_size:
-                raise ValueError(
-                    (
-                        "array batch size of {} must match the batch size "
-                        "the model was instantiated with {}"
-                    ).format(arr.shape[0], self._batch_size)
-                )
+            if not self._disable_batch_override:
+                if arr.shape[0] != self._batch_size:
+                    raise ValueError(
+                        (
+                            "array batch size of {} must match the batch size "
+                            "the model was instantiated with {}"
+                        ).format(arr.shape[0], self._batch_size)
+                    )
 
             if not arr.flags["C_CONTIGUOUS"]:
                 raise ValueError(
@@ -719,8 +738,6 @@ class DebugAnalysisEngine(Engine):
     :param num_cores: The number of physical cores to run the model on.
         If more cores are requested than are available on a single socket, the
        engine will try to distribute them evenly across as few sockets as possible.
-    :param num_streams: The max number of requests the model can handle
-        concurrently.
     :param scheduler: The kind of scheduler to execute with. Pass None for the default.
     :param input_shapes: The list of shapes to set the inputs to. Pass None to use model as-is.
     :param num_iterations: The number of iterations to run benchmarking for.
@@ -729,12 +746,17 @@ class DebugAnalysisEngine(Engine):
         benchmarking. These executions will not be counted in the benchmark
         results that are returned. Useful and recommended to bring the system
         to a steady state. Default is 5
-    :param include_inputs: If True, inputs from forward passes during benchmarking
-        will be added to the results. Default is False
-    :param include_outputs: If True, outputs from forward passes during benchmarking
-        will be added to the results. Default is False
-    :param show_progress: If True, will display a progress bar. Default is False
-    :param scheduler: The kind of scheduler to execute with. Pass None for the default.
+    :param optimization_level: The amount of graph optimizations to perform.
+        The current choices are either 0 (minimal) or 1 (all), default is 1
+    :param imposed_as: Imposed activation sparsity, defaults to None.
+        Will force the activation sparsity from all ReLu layers in the graph
+        to match this desired sparsity level (percentage of 0's in the tensor).
+        Beneficial for seeing how AS affects the performance of the model.
+    :param imposed_ks: Imposed kernel sparsity, defaults to None.
+        Will force all prunable layers in the graph to have weights with
+        this desired sparsity level (percentage of 0's in the tensor).
+        Beneficial for seeing how pruning affects the performance of the model.
+    :param kv_cache_params: KV cache execution params, defaults to None.
     """
 
     def __init__(
@@ -747,17 +769,49 @@ def __init__(
         num_iterations: int = 20,
         num_warmup_iterations: int = 5,
         optimization_level: int = 1,
+        disable_batch_override: bool = False,
         imposed_as: Optional[float] = None,
         imposed_ks: Optional[float] = None,
+        kv_cache_params: Optional[KVCacheParams] = None,
     ):
         BaseEngine.construct(
-            self, model, batch_size, num_cores, None, scheduler, input_shapes
+            self,
+            model,
+            batch_size,
+            num_cores,
+            None,
+            scheduler,
+            input_shapes,
+            disable_batch_override,
+            kv_cache_params,
         )
 
-        if self._input_shapes:
-            with override_onnx_input_shapes(
-                self._model_path, self._input_shapes
-            ) as model_path:
+        # Helper
+        def make_engine(self, model_path):
+            if self._kv_cache_params:
+                self._kv_cache = LIB.kv_cache(
+                    self._kv_cache_params.prev_num_tokens,
+                    self._kv_cache_params.num_frozen_tokens,
+                )
+
+                self._eng_net = LIB.deepsparse_engine(
+                    model_path,
+                    self._batch_size,
+                    self._num_cores,
+                    self._num_streams,
+                    self._scheduler.value,
+                    None,
+                    self._kv_cache_params.cached_outputs,
+                    "external",
+                    num_iterations,
+                    num_warmup_iterations,
+                    optimization_level,
+                    imposed_as,
+                    imposed_ks,
+                )
+            else:
+                self._kv_cache = None
+
                 self._eng_net = LIB.deepsparse_engine(
                     model_path,
                     self._batch_size,
@@ -772,26 +826,17 @@ def __init__(
                     imposed_as,
                     imposed_ks,
                 )
+
+        if self._input_shapes:
+            with override_onnx_input_shapes(
+                self._model_path, self._input_shapes
+            ) as model_path:
+                make_engine(self, model_path)
         else:
-            self._eng_net = LIB.deepsparse_engine(
-                self._model_path,
-                self._batch_size,
-                self._num_cores,
-                self._num_streams,
-                self._scheduler.value,
-                None,
-                "external",
-                num_iterations,
-                num_warmup_iterations,
-                optimization_level,
-                imposed_as,
-                imposed_ks,
-            )
+            make_engine(self, self._model_path)
 
     def analyze(
-        self,
-        inp: List[numpy.ndarray],
-        val_inp: bool = True,
+        self, inp: List[numpy.ndarray], val_inp: bool = True
     ) -> List[numpy.ndarray]:
         """
         Function to analyze a model's performance in the DeepSparse Engine.
@@ -804,10 +849,11 @@ def analyze(
             are setup correctly for the DeepSparse Engine
         :return: the analysis structure containing the performance details of each layer
         """
+
         if val_inp:
             self._validate_inputs(inp)
 
-        [out, bench_info] = self._eng_net.benchmark_execute(inp)
+        [_, bench_info] = self._eng_net.benchmark_execute(inp, self._kv_cache)
 
         return bench_info
 
@@ -978,10 +1024,12 @@ def model_debug_analysis(
     num_iterations: int = 20,
     num_warmup_iterations: int = 5,
     optimization_level: int = 1,
+    disable_batch_override: bool = False,
     imposed_as: Optional[float] = None,
     imposed_ks: Optional[float] = None,
     scheduler: Scheduler = None,
     input_shapes: List[List[int]] = None,
+    kv_cache_params: Optional[KVCacheParams] = None,
 ) -> dict:
     """
     Function to analyze a model's performance in the DeepSparse Engine.
@@ -1006,6 +1054,7 @@ def model_debug_analysis(
         before analyzing, default is 5
     :param optimization_level: The amount of graph optimizations to perform.
         The current choices are either 0 (minimal) or 1 (all), default is 1
+    :param disable_batch_override: Indicates whether disable_batch_override was used or not
     :param imposed_as: Imposed activation sparsity, defaults to None.
         Will force the activation sparsity from all ReLu layers in the graph
         to match this desired sparsity level (percentage of 0's in the tensor).
@@ -1015,7 +1064,9 @@ def model_debug_analysis(
         this desired sparsity level (percentage of 0's in the tensor).
         Beneficial for seeing how pruning affects the performance of the model.
     :param scheduler: The kind of scheduler to execute with. Pass None for the default.
-    :return: the analysis structure containing the performance details of each layer
+    :param input_shapes: Overrides input shapes, default to None (no override).
+    :param kv_cache_params: KV cache execution params, defaults to None.
+    :return: The analysis structure containing the performance details of each layer
     """
     model = DebugAnalysisEngine(
         model=model,
@@ -1026,8 +1077,10 @@ def model_debug_analysis(
         num_iterations=num_iterations,
         num_warmup_iterations=num_warmup_iterations,
         optimization_level=optimization_level,
+        disable_batch_override=disable_batch_override,
         imposed_as=imposed_as,
         imposed_ks=imposed_ks,
+        kv_cache_params=kv_cache_params,
     )
 
     return model.analyze(inp)
diff --git a/src/deepsparse/utils/onnx.py b/src/deepsparse/utils/onnx.py
index c34eda9b0e..89d8baf4c9 100644
--- a/src/deepsparse/utils/onnx.py
+++ b/src/deepsparse/utils/onnx.py
@@ -49,6 +49,7 @@
     "override_onnx_input_shapes",
     "truncate_onnx_model",
     "truncate_onnx_embedding_model",
+    "default_cached_outputs",
 ]
 
 _LOGGER = logging.getLogger(__name__)
@@ -472,3 +473,22 @@ def truncate_onnx_embedding_model(
     )
 
     return output_filepath, tmp_file
+
+
+def default_cached_outputs(model_path: str) -> List[bool]:
+    """
+    :param model_path: Path to a model
+    :return A list of bools that indicates caching of all outputs except the first one.
+    """
+
+    outputs = list(onnx.load(model_path).graph.output)
+    assert len(outputs) > 0
+
+    # Create a boolean list of every output of the
+    # model [logits, key0, value0, key1, value1, ..., keyN, valueN]
+    cached_outputs = [True for i in range(len(outputs))]
+
+    # Assume first input is logits and logits ought not to be cached
+    cached_outputs[0] = False
+
+    return cached_outputs
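For reference, a minimal sketch of how the KV-cache debug-analysis path added by this patch might be driven from Python. It only uses symbols the patch itself introduces or imports (KVCacheParams, model_debug_analysis, default_cached_outputs, generate_random_inputs); the ONNX path, batch size, and token counts below are placeholder assumptions, and the model is assumed to be a decoder-style network whose outputs are ordered [logits, key0, value0, ..., keyN, valueN], which is what default_cached_outputs expects.

```python
# Illustrative sketch, not part of the commit above.
from deepsparse import KVCacheParams, model_debug_analysis
from deepsparse.utils import default_cached_outputs, generate_random_inputs

model_path = "deployment/model.onnx"  # hypothetical decoder ONNX with KV-cache outputs

# Mirror what debug_analysis.py does for --use-kvcache: cache every output
# except the first (logits) and start from an empty external cache.
kv_cache_params = KVCacheParams(
    cached_outputs=default_cached_outputs(model_path),
    prev_num_tokens=0,
    num_frozen_tokens=0,
)

input_list = generate_random_inputs(model_path, 1)  # batch size 1

result = model_debug_analysis(
    model_path,
    input_list,
    batch_size=1,
    num_iterations=20,
    num_warmup_iterations=5,
    optimization_level=1,
    kv_cache_params=kv_cache_params,
)

print("Total Time(MS):", result["average_total_time"])
```

The CLI path in debug_analysis.py builds the same KVCacheParams object when --use-kvcache is passed, defaulting both --kv-cache-prev-num-tokens and --kv-cache-num-frozen-tokens to 0 when they are not given.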