
Commit

Merge remote-tracking branch 'origin/main' into feature/damian/no_kv_cache
dbogunowicz committed Dec 11, 2023
2 parents dcab3f9 + e2f2305 commit e0a9dee
Showing 8 changed files with 175 additions and 168 deletions.
8 changes: 4 additions & 4 deletions src/deepsparse/evaluation/cli.py
@@ -75,7 +75,7 @@
from src.deepsparse.evaluation.integrations import ( # noqa: F401
try_import_llm_evaluation_harness,
)
from src.deepsparse.evaluation.results import Result, save_evaluations
from src.deepsparse.evaluation.results import Result, save_result
from src.deepsparse.evaluation.utils import args_to_dict, get_save_path
from src.deepsparse.pipeline import DEEPSPARSE_ENGINE, ORT_ENGINE, TORCHSCRIPT_ENGINE

@@ -210,7 +210,7 @@ def main(
**integration_args,
)

_LOGGER.info(f"Evaluation done. Results:\n{result}")
_LOGGER.info(f"Evaluation done. Result:\n{result.formatted}")

save_path = get_save_path(
save_path=save_path,
@@ -219,8 +219,8 @@
)
if save_path:
_LOGGER.info(f"Saving the evaluation results to {save_path}")
save_evaluations(
evaluations=result.formatted,
save_result(
result=result,
save_path=save_path,
save_format=type_serialization,
)
18 changes: 17 additions & 1 deletion src/deepsparse/evaluation/integrations/llm_evaluation_harness.py
@@ -101,13 +101,29 @@ def integration_eval(
results_raw = evaluator.simple_evaluate(**evaluator_input.dict())

results = Result(
raw=dict(output=results_raw, input=evaluator_input),
raw=dict(output=results_raw, input=filter_evaluator_input(evaluator_input)),
formatted=format_raw_results(results_raw),
)

return results


def filter_evaluator_input(
evaluator_input: "EvaluatorInputSchema",
) -> Dict[str, Any]: # noqa: F821
"""
Filter the evaluator input to remove the model field.
The model field is a complex object that cannot be serialized.
:param evaluator_input: the evaluator input to filter
:return: the filtered evaluator input
"""
evaluator = evaluator_input.dict()
del evaluator["model"]

return evaluator


def format_raw_results(results: Dict[str, Any]) -> List[Evaluation]:
"""
Format the raw results from llm_evaluation_harness into a list of
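
The new filter_evaluator_input exists because the evaluator input carries a live model/pipeline object under its "model" key, which json/yaml serializers cannot handle. Below is a minimal, self-contained sketch of the same idea, assuming a plain dict in place of the actual EvaluatorInputSchema (FakePipeline is an illustrative stand-in, not a deepsparse class):

import json
from typing import Any, Dict


class FakePipeline:
    """Stand-in for the complex, non-serializable model object."""


def filter_evaluator_input(evaluator_input: Dict[str, Any]) -> Dict[str, Any]:
    # drop the "model" entry so the remaining input serializes cleanly
    filtered = dict(evaluator_input)
    del filtered["model"]
    return filtered


evaluator_input = {"model": FakePipeline(), "batch_size": 1, "limit": 5}
# json.dumps(evaluator_input) would raise TypeError (FakePipeline is not
# JSON serializable); filtering first makes the dump succeed:
print(json.dumps(filter_evaluator_input(evaluator_input)))
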
59 changes: 16 additions & 43 deletions src/deepsparse/evaluation/results.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from collections import OrderedDict
from typing import Any, List, Optional

import yaml
@@ -28,7 +26,7 @@
"EvalSample",
"Evaluation",
"Result",
"save_evaluations",
"save_result",
]


@@ -63,68 +61,43 @@ class Evaluation(BaseModel):

class Result(BaseModel):
formatted: List[Evaluation] = Field(
description="Evaluation results represented in the unified, structured format"
description="Evaluation result represented in the unified, structured format"
)
raw: Any = Field(
description="Evaluation results represented in the raw format "
description="Evaluation result represented in the raw format "
"(characteristic for the specific evaluation integration)"
)

def __str__(self):
"""
The string representation of the Result object is
the formatted evaluation results serialized in JSON.
"""
return save_evaluations(self.formatted, save_format="json", save_path=None)


def save_evaluations(
evaluations: List[Evaluation], save_format: str = "json", save_path: str = None
def save_result(
result: Result,
save_path: str,
save_format: str = "json",
):
"""
Saves a list of Evaluation objects to a file in the specified format.
:param evaluations: List of Evaluation objects to save
:param result: Result object to save
:param save_format: Format to save the evaluations in.
:param save_path: Path to save the evaluations to.
If None, the evaluations will not be saved.
:return: The serialized evaluations
"""
# serialize the evaluations
evaluations: List[Evaluation] = prep_for_serialization(evaluations)
# convert to ordered dicts to preserve order
evaluations: List[OrderedDict] = evaluations_to_dicts(evaluations)
# prepare the Result object for serialization
result: Result = prep_for_serialization(result)
if save_format == "json":
return _save_to_json(evaluations, save_path)
_save_to_json(result, save_path)
elif save_format == "yaml":
return _save_to_yaml(evaluations, save_path)
_save_to_yaml(result, save_path)
else:
raise NotImplementedError("Currently only json and yaml formats are supported")


def _save_to_json(evaluations: List[OrderedDict], save_path: Optional[str]) -> str:
data = json.dumps(evaluations, indent=4)
if save_path:
_save(data, save_path, expected_ext=".json")
return data


def _save_to_yaml(evaluations: List[OrderedDict], save_path: Optional[str]) -> str:
# required to properly process OrderedDicts
yaml.add_representer(
OrderedDict,
lambda dumper, data: dumper.represent_mapping(
"tag:yaml.org,2002:map", data.items()
),
)
data = yaml.dump(evaluations, default_flow_style=False)
if save_path:
_save(data, save_path, expected_ext=".yaml")
return data
def _save_to_json(result: Result, save_path: str):
_save(result.json(), save_path, expected_ext=".json")


def evaluations_to_dicts(evaluations: List[Evaluation]):
return [OrderedDict(**evaluation.dict()) for evaluation in evaluations]
def _save_to_yaml(result: Result, save_path: str):
_save(yaml.dump(result.dict()), save_path, expected_ext=".yaml")


def _save(data: str, save_path: str, expected_ext: str):
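
After this refactor, callers hand the whole Result to save_result instead of pre-serializing a list of evaluations. A hedged usage sketch follows; the field names mirror the test fixtures in this commit, and the save paths are illustrative:

from deepsparse.evaluation.results import Evaluation, Result, save_result

result = Result(
    formatted=[
        Evaluation(
            task="task_1",
            dataset={"type": "type_1", "name": "name_1",
                     "config": "config_1", "split": "split_1"},
            metrics=[{"name": "metric_name_1", "value": 1.0}],
            samples=[{"input": [[5]], "output": 5}],
        )
    ],
    raw="dummy_raw_evaluation",
)

# the full Result (raw and formatted) is serialized via pydantic,
# not just the formatted list as save_evaluations did before
save_result(result=result, save_path="result.json", save_format="json")
save_result(result=result, save_path="result.yaml", save_format="yaml")
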
5 changes: 5 additions & 0 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -834,6 +834,11 @@ def engine_forward(
generated_tokens.append(token)
generated_logits.append(logits)

if session.total_num_processed_tokens >= session.capacity:
# if the kv cache is full, stop generation
finished_reason.append(FinishReason.CAPACITY)
break

if (
token == self.tokenizer.eos_token_id
and not self.force_max_tokens
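
The new guard stops decoding once the KV cache session has processed as many tokens as it can hold, rather than overflowing the cache. A stand-alone sketch of the control flow, with FakeSession standing in for the real KV cache session and the engine call reduced to a stub:

from dataclasses import dataclass
from enum import Enum


class FinishReason(Enum):
    # mirrors the CAPACITY member referenced in the diff
    CAPACITY = "capacity"


@dataclass
class FakeSession:
    total_num_processed_tokens: int
    capacity: int


def engine_forward(session: FakeSession, max_tokens: int = 100):
    generated_tokens, finished_reason = [], []
    for _ in range(max_tokens):
        generated_tokens.append("tok")  # stub for a real decode step
        session.total_num_processed_tokens += 1
        if session.total_num_processed_tokens >= session.capacity:
            # if the kv cache is full, stop generation
            finished_reason.append(FinishReason.CAPACITY)
            break
    return generated_tokens, finished_reason


tokens, reasons = engine_forward(FakeSession(total_num_processed_tokens=0, capacity=8))
print(len(tokens), reasons)  # 8 [<FinishReason.CAPACITY: 'capacity'>]
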
2 changes: 1 addition & 1 deletion src/deepsparse/version.py
@@ -39,7 +39,7 @@
from deepsparse.generated_version import is_enterprise, is_release, splash, version
except Exception:
# otherwise, fall back to version info in this file
version = "1.6.0"
version = "1.7.0"
is_release = False
is_enterprise = False
splash = (
12 changes: 7 additions & 5 deletions src/deepsparse/yolov8/annotate.py
@@ -13,22 +13,24 @@
# limitations under the License.

"""
Usage: deepsparse.object_detection.annotate [OPTIONS]
Usage: deepsparse.yolov8.annotate [OPTIONS]
Annotation Script for YOLOv8 with DeepSparse
Options:
--model_filepath, --model-filepath TEXT
Path/SparseZoo stub to the model file to be
used for annotation
--source TEXT File path to an image or directory of image
--subtask TEXT A subtask to run the YOLOv8 model on.
Defaults to 'detection'
--source TEXT File path to image or directory of .jpg
files, a .mp4 video, or an integer (i.e. 0)
for webcam [required]
--engine [deepsparse|onnxruntime|torch]
Inference engine backend to run on. Choices
are 'deepsparse', 'onnxruntime', and
'torch'. Default is 'deepsparse'
--model_input_image_shape, --model-input-shape INTEGER...
--model_input_image_shape, --model-input-image-shape INTEGER...
Image shape to override model with for
inference, must be two integers
--num_cores, --num-cores INTEGER
@@ -51,8 +53,8 @@
then it will be ignored
--no_save, --no-save Set flag when source is from webcam to not
save results. Not supported for non-webcam
sources [default: False]
--help Show this message and exit
sources
--help Show this message and exit.
#######
Examples:
127 changes: 13 additions & 114 deletions tests/deepsparse/evaluation/test_results.py
@@ -23,7 +23,8 @@
EvalSample,
Evaluation,
Metric,
save_evaluations,
Result,
save_result,
)


@@ -56,124 +57,22 @@ def evaluations():


@pytest.fixture()
def evaluations_json():
return """[
{
"task": "task_1",
"dataset": {
"type": "type_1",
"name": "name_1",
"config": "config_1",
"split": "split_1"
},
"metrics": [
{
"name": "metric_name_1",
"value": 1.0
}
],
"samples": [
{
"input": [
[
5
]
],
"output": 5
}
]
},
{
"task": "task_2",
"dataset": {
"type": "type_2",
"name": "name_2",
"config": "config_2",
"split": "split_2"
},
"metrics": [
{
"name": "metric_name_2",
"value": 2.0
},
{
"name": "metric_name_3",
"value": 3.0
}
],
"samples": [
{
"input": [
[
10.0
]
],
"output": 10.0
},
{
"input": [
[
20.0
]
],
"output": 20.0
}
]
}
]""" # noqa: E501
def result(evaluations):
return Result(formatted=evaluations, raw="dummy_raw_evaluation")


@pytest.fixture()
def evaluations_yaml():
return """- task: task_1
dataset:
config: config_1
name: name_1
split: split_1
type: type_1
metrics:
- name: metric_name_1
value: 1.0
samples:
- input:
- - 5
output: 5
- task: task_2
dataset:
config: config_2
name: name_2
split: split_2
type: type_2
metrics:
- name: metric_name_2
value: 2.0
- name: metric_name_3
value: 3.0
samples:
- input:
- - 10.0
output: 10.0
- input:
- - 20.0
output: 20.0
"""


def test_serialize_evaluation_json(tmp_path, evaluations, evaluations_json):
def test_serialize_result_json(tmp_path, result):
path_to_file = tmp_path / "result.json"
evaluations_serialized = save_evaluations(
evaluations=evaluations, save_format="json", save_path=path_to_file.as_posix()
)
save_result(result=result, save_format="json", save_path=path_to_file.as_posix())

with open(path_to_file.as_posix(), "r") as f:
assert json.load(f)
assert evaluations_serialized == evaluations_json
reloaded_results = json.load(f)
assert reloaded_results == result.dict()


def test_serialize_evaluation_yaml(tmp_path, evaluations, evaluations_yaml):
def test_serialize_result_yaml(tmp_path, result):
path_to_file = tmp_path / "result.yaml"
evaluations_serialized = save_evaluations(
evaluations=evaluations, save_format="yaml", save_path=path_to_file.as_posix()
)
save_result(result=result, save_format="yaml", save_path=path_to_file.as_posix())
with open(path_to_file.as_posix(), "r") as f:
assert yaml.safe_load(f)
assert evaluations_serialized == evaluations_yaml
reloaded_results = yaml.safe_load(f)
assert reloaded_results == result.dict()
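
Both rewritten tests lean on the same pydantic round-trip property: serializing a model and re-parsing the file reproduces result.dict(). A tiny stand-alone illustration with a throwaway model (pydantic v1-style API, matching the .json()/.dict() calls used elsewhere in this commit):

import json

from pydantic import BaseModel


class Tiny(BaseModel):
    name: str
    value: float


original = Tiny(name="metric_name_1", value=1.0)
reloaded = json.loads(original.json())
assert reloaded == original.dict()  # same round-trip check as the tests above
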
