Skip to content

Commit

Permalink
Merge branch 'feature/damian/kv_cache_codegen' of https://github.com/neuralmagic/sparseml into feature/damian/kv_cache_codegen
Browse files Browse the repository at this point in the history
…neuralmagic/sparseml into feature/damian/kv_cache_codegen
  • Loading branch information
dbogunowicz committed Jun 23, 2023
2 parents 836ee7e + a819c27 commit bba3924
Show file tree
Hide file tree
Showing 22 changed files with 186 additions and 276 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
PositionsAdjustmentBase,
)
from sparseml.exporters.transforms.utils.matching import get_structural_matches
from sparseml.onnx.utils import ONNXGraph, find_orphaned_nodes
from sparseml.onnx.utils import ONNXGraph


__all__ = ["PositionsAdjustmentCodeGen"]
Expand Down Expand Up @@ -80,7 +80,7 @@ def _update_position_embeddings_for_graph_input(
if input_name == output_to_replace:
graph.update_node_input(child_node, self.POSITIONS_NAME, idx)

orphaned_nodes = find_orphaned_nodes(model, node)
orphaned_nodes = graph.find_orphaned_nodes(node)
[self.log_match(node) for node in orphaned_nodes]
graph.delete_nodes(orphaned_nodes)
graph.update()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from sparseml.exporters.transforms.kv_cache.positions_adjustment_base import (
PositionsAdjustmentBase,
)
from sparseml.onnx.utils import ONNXGraph, find_orphaned_nodes
from sparseml.onnx.utils import ONNXGraph


__all__ = ["PositionsAdjustmentOPT"]
Expand Down Expand Up @@ -121,8 +121,8 @@ def _update_position_embeddings_for_graph_input(
graph.update()
self.log_match(target_update_node)

nodes_to_delete = find_orphaned_nodes(
model, graph.get_node_by_output_id(old_positions_input)
nodes_to_delete = graph.find_orphaned_nodes(
graph.get_node_by_output_id(old_positions_input)
)
[self.log_match(node) for node in nodes_to_delete]

Expand Down
26 changes: 16 additions & 10 deletions src/sparseml/onnx/optim/quantization/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
"""


import logging
import os
import tempfile
from typing import Dict, Generator, Iterable, List, Tuple, Union

import numpy as np
Expand All @@ -28,6 +28,8 @@
from sparsezoo.utils import save_onnx, validate_onnx


_LOGGER = logging.getLogger(__name__)

__all__ = ["CalibrationSession"]


Expand Down Expand Up @@ -65,11 +67,11 @@ def __init__(
self._model_augmented = self.generate_augmented_model()

if self._augmented_model_path is None:
self._augmented_model_tmp_file = tempfile.NamedTemporaryFile(
suffix=".onnx", delete=True
self._augmented_model_path = os.path.join(
os.getcwd(), "model_augmented.onnx"
)
self._augmented_model_path = self._augmented_model_tmp_file.name
save_onnx(self._model_augmented, self._augmented_model_path)
_LOGGER.debug(f"Created an augmented model at: {self._augmented_model_path}")

self._sessions = {} # batch_size -> session
self._quantization_thresholds = {} # Dict[node.name, Tuple(min_val, max_val)]
Expand Down Expand Up @@ -103,13 +105,15 @@ def _optimize_model(self) -> Union[str, None]:
# no optimization performed, skip the rest of this block
raise Exception()
validate_onnx(model_optimized) # should raise exception if broken
optimized_model_path = tempfile.NamedTemporaryFile(
suffix=".onnx", delete=False
)
save_onnx(model_optimized, optimized_model_path.name)
optimized_model_path = os.path.join(os.getcwd(), "model_optimized.onnx")
save_onnx(model_optimized, optimized_model_path)
self._model = model_optimized
print("Optimization successful")
return optimized_model_path.name
_LOGGER.debug(
"Optimization successful. "
"Created an optimized model at: "
f"{optimized_model_path}"
)
return optimized_model_path
except Exception as e:
print(e)
print(
Expand Down Expand Up @@ -352,3 +356,5 @@ def __del__(self):
"""
if self._optimized_model_path is not None:
os.remove(self._optimized_model_path)
if self._augmented_model_path is not None:
os.remove(self._augmented_model_path)
Loading

0 comments on commit bba3924

Please sign in to comment.