Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add kvcache support to debug_analysis.py and engine.py #1132

Merged
merged 6 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 60 additions & 10 deletions src/deepsparse/debug_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@
import json
import os

from deepsparse import model_debug_analysis
from deepsparse import KVCacheParams, model_debug_analysis
from deepsparse.utils import (
default_cached_outputs,
generate_random_inputs,
model_to_path,
override_onnx_input_shapes,
Expand Down Expand Up @@ -140,6 +141,29 @@ def parse_args():
type=str,
default="",
)
parser.add_argument(
"--disable-batch-override",
help="Ignores the batch_size parameter",
action="store_true",
default=False,
)
parser.add_argument(
"--use-kvcache", help="Enable KVCache", action="store_true", default=False
)
parser.add_argument(
"--kv-cache-prev-num-tokens",
help="KVCache: The amount of previous tokens that will be read"
" from the external KV cache on the first inference",
type=int,
default=None,
)
parser.add_argument(
"--kv-cache-num-frozen-tokens",
help="KVCache: The amount of first tokens that we want to keep"
" permanently in the KV cache",
alexm-neuralmagic marked this conversation as resolved.
Show resolved Hide resolved
type=int,
default=None,
)
parser.add_argument(
"-q",
"--quiet",
Expand Down Expand Up @@ -186,14 +210,10 @@ def construct_layer_table(result):
"{: >#08.4f} | {: >#08.4f} | {: >#08.4f} | {:12}"
)
for li in result["layer_info"]:
table_str += layer_info_to_string(
li,
"{:28}| " + info_format_base + "\n",
)
table_str += layer_info_to_string(li, "{:28}| " + info_format_base + "\n")
for sub_li in li["sub_layer_info"]:
table_str += layer_info_to_string(
sub_li,
" {:26}| " + info_format_base + "\n",
sub_li, " {:26}| " + info_format_base + "\n"
)

table_str += "Total Time(MS): {:05f}\n".format(result["average_total_time"])
Expand Down Expand Up @@ -295,11 +315,39 @@ def main():

print("Analyzing model: {}".format(orig_model_path))

batch_size = args.batch_size
if args.disable_batch_override:
batch_size = None
os.environ["NM_DISABLE_BATCH_OVERRIDE"] = "1"
print("Disable batch override: ON")

if input_shapes:
with override_onnx_input_shapes(model_path, input_shapes) as tmp_path:
input_list = generate_random_inputs(tmp_path, args.batch_size)
input_list = generate_random_inputs(tmp_path, batch_size)
else:
input_list = generate_random_inputs(model_path, args.batch_size)
input_list = generate_random_inputs(model_path, batch_size)

kv_cache_params = None
if args.use_kvcache:
kv_cache_prev_num_tokens = 0
if args.kv_cache_prev_num_tokens is not None:
kv_cache_prev_num_tokens = args.kv_cache_prev_num_tokens

kv_cache_num_frozen_tokens = 0
if args.kv_cache_num_frozen_tokens is not None:
kv_cache_num_frozen_tokens = args.kv_cache_num_frozen_tokens

kv_cache_params = KVCacheParams(
default_cached_outputs(model_path),
kv_cache_prev_num_tokens,
kv_cache_num_frozen_tokens,
)

print(
"Enable KVCache: prev_num_tokens = {}, num_frozen_tokens = {}".format(
kv_cache_params.prev_num_tokens, kv_cache_params.num_frozen_tokens
)
)

result = model_debug_analysis(
model_path,
Expand All @@ -308,9 +356,11 @@ def main():
num_cores=args.num_cores,
num_iterations=args.num_iterations,
num_warmup_iterations=args.num_warmup_iterations,
optimization_level=args.optimization,
optimization_level=int(args.optimization),
disable_batch_override=args.disable_batch_override,
imposed_ks=imposed_kernel_sparsity,
input_shapes=input_shapes,
kv_cache_params=kv_cache_params,
)

if not args.quiet:
Expand Down
Loading
Loading