Inference UX, accept input data #1285

Merged
merged 24 commits into main from infer-ux-fixes on Oct 2, 2023
Changes from all commits (24 commits)
6c8dfb3
[deepsparse.infer] UX improvements, data only mode
bfineran Sep 25, 2023
b4b7ec6
fix bug on main
bfineran Sep 25, 2023
17168f6
draft, load files line by line, return iter, save up memory
horheynm Sep 26, 2023
84b03f8
add inference
horheynm Sep 26, 2023
8a69c5f
pass passing in files
horheynm Sep 26, 2023
4091503
latest changes'
horheynm Sep 26, 2023
b0f65af
revert
horheynm Sep 26, 2023
8ee765b
make new folder for inderence
horheynm Sep 26, 2023
8a47e01
allow input to pass thru cli
horheynm Sep 26, 2023
6957067
Merge branch 'main' into infer-ux-fixes
horheynm Sep 26, 2023
b429917
Update src/deepsparse/transformers/inference/infer.py
horheynm Sep 27, 2023
1dc2ee3
remove hardcoded
horheynm Sep 27, 2023
86a2daf
better error message
horheynm Sep 27, 2023
939c6bc
clean up
horheynm Sep 27, 2023
274570b
Merge branch 'main' into infer-ux-fixes
horheynm Sep 27, 2023
7b1edfa
clean up, check kwargs
horheynm Sep 27, 2023
064013a
Merge branch 'infer-ux-fixes' of github.com:neuralmagic/deepsparse in…
horheynm Sep 27, 2023
8cab34f
Merge branch 'main' into infer-ux-fixes
horheynm Sep 27, 2023
4699837
get rid of breakpoint()
horheynm Sep 27, 2023
218b584
Merge branch 'infer-ux-fixes' of github.com:neuralmagic/deepsparse in…
horheynm Sep 27, 2023
ff4b48f
return type
horheynm Sep 27, 2023
8243740
Merge branch 'main' into infer-ux-fixes
horheynm Oct 2, 2023
2a4b972
typo
horheynm Oct 2, 2023
ad9e96d
Merge branch 'infer-ux-fixes' of github.com:neuralmagic/deepsparse in…
horheynm Oct 2, 2023
setup.py (2 changes: 1 addition & 1 deletion)
@@ -298,7 +298,7 @@ def _setup_entry_points() -> Dict:
        "console_scripts": [
            f"deepsparse.transformers.run_inference={data_api_entrypoint}",
            f"deepsparse.transformers.eval_downstream={eval_downstream}",
            "deepsparse.infer=deepsparse.transformers.infer:main",
            "deepsparse.infer=deepsparse.transformers.inference.infer:main",
            "deepsparse.debug_analysis=deepsparse.debug_analysis:main",
            "deepsparse.analyze=deepsparse.analyze:main",
            "deepsparse.check_hardware=deepsparse.cpu:print_hardware_capability",
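The only change here is the console-script target: deepsparse.infer now resolves to the relocated module. A quick sanity-check sketch (assuming an environment with this branch installed):

# sketch: confirm the relocated entry point imports and runs
# (assumes this branch is installed)
from deepsparse.transformers.inference.infer import main

if __name__ == "__main__":
    main()  # click command; parses sys.argv the same way the deepsparse.infer script does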
src/deepsparse/transformers/inference/__init__.py (13 changes: 13 additions & 0 deletions)
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
src/deepsparse/transformers/inference/infer.py
@@ -63,10 +63,14 @@
    deepsparse.infer models/llama/deployment \
        --task text-generation
"""

from typing import Optional

import click

from deepsparse import Pipeline
from deepsparse.tasks import SupportedTasks
from deepsparse.transformers.inference.prompt_parser import PromptParser


@click.command(
@@ -75,6 +79,14 @@
    )
)
@click.argument("model_path", type=str)
@click.option(
    "--data",
    type=str,
    default=None,
    help="Path to a .txt, .csv, .json, or .jsonl file to load data from. "
    "If provided, runs inference over the entire dataset; if not provided, "
    "runs an interactive inference session in the console. Default: None.",
)
@click.option(
    "--sequence_length",
    type=int,
@@ -112,6 +124,7 @@
)
def main(
    model_path: str,
    data: Optional[str],
    sequence_length: int,
    sampling_temperature: float,
    prompt_sequence_length: int,
@@ -128,34 +141,76 @@ def main(
    session_ids = "chatbot_cli_session"

    pipeline = Pipeline.create(
        task=task,  # let pipeline determine if task is supported
        task=task,  # let the pipeline determine if task is supported
        model_path=model_path,
        sequence_length=sequence_length,
        sampling_temperature=sampling_temperature,
        prompt_sequence_length=prompt_sequence_length,
    )

    # continue prompts until a keyboard interrupt
    while True:
        input_text = input("User: ")
        pipeline_inputs = {"prompt": [input_text]}

        if SupportedTasks.is_chat(task):
            pipeline_inputs["session_ids"] = session_ids

        response = pipeline(**pipeline_inputs)
        print("Bot: ", response.generations[0].text)
        if show_tokens_per_sec:
            times = pipeline.timer_manager.times
            prefill_speed = (
                1.0 * prompt_sequence_length / times["engine_prompt_prefill_single"]
            )
            generation_speed = 1.0 / times["engine_token_generation_single"]
            print(
                f"[prefill: {prefill_speed:.2f} tokens/sec]",
                f"[decode: {generation_speed:.2f} tokens/sec]",
                sep="\n",
            )
    if data:
        prompt_parser = PromptParser(data)
        default_prompt_kwargs = {
            "sequence_length": sequence_length,
            "sampling_temperature": sampling_temperature,
            "prompt_sequence_length": prompt_sequence_length,
            "show_tokens_per_sec": show_tokens_per_sec,
        }

        for prompt_kwargs in prompt_parser.parse_as_iterable(**default_prompt_kwargs):
            _run_inference(
                task=task,
                pipeline=pipeline,
                session_ids=session_ids,
                **prompt_kwargs,
            )
        return

    # continue prompts until a keyboard interrupt
    while data is None:  # always True in interactive mode
        prompt = input(">>> ")
        _run_inference(
            pipeline,
            sampling_temperature,
            task,
            session_ids,
            show_tokens_per_sec,
            prompt_sequence_length,
            prompt,
        )


def _run_inference(
    pipeline,
    sampling_temperature,
    task,
    session_ids,
    show_tokens_per_sec,
    prompt_sequence_length,
    prompt,
    **kwargs,
):
    pipeline_inputs = dict(
        prompt=[prompt],
        temperature=sampling_temperature,
        **kwargs,
    )
    if SupportedTasks.is_chat(task):
        pipeline_inputs["session_ids"] = session_ids

    response = pipeline(**pipeline_inputs)
    print("\n", response.generations[0].text)

    if show_tokens_per_sec:
        times = pipeline.timer_manager.times
        prefill_speed = (
            1.0 * prompt_sequence_length / times["engine_prompt_prefill_single"]
        )
        generation_speed = 1.0 / times["engine_token_generation_single"]
        print(
            f"[prefill: {prefill_speed:.2f} tokens/sec]",
            f"[decode: {generation_speed:.2f} tokens/sec]",
            sep="\n",
        )


if __name__ == "__main__":
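For illustration, a minimal end-to-end sketch of the new data mode; the file name prompts.jsonl and its contents are made up for this example, not taken from the PR:

# sketch: build a small JSONL prompt file and run it through the new --data flag
# (assumes this branch is installed; per-line keys override the CLI defaults)
import json

samples = [
    {"prompt": "What is structured sparsity?"},
    {"prompt": "Summarize weight pruning in one sentence.", "sampling_temperature": 0.7},
]
with open("prompts.jsonl", "w") as file:
    for sample in samples:
        file.write(json.dumps(sample) + "\n")

# then, mirroring the docstring example above:
#   deepsparse.infer models/llama/deployment \
#       --task text-generation \
#       --data prompts.jsonl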
src/deepsparse/transformers/inference/prompt_parser.py (108 changes: 108 additions & 0 deletions)
@@ -0,0 +1,108 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import csv
import json
import os
from enum import Enum
from typing import Iterator


class InvalidPromptSourceDirectoryException(Exception):
    pass


class UnableToParseExtensionException(Exception):
    pass


def parse_value_to_appropriate_type(value: str):
    # best-effort conversion of raw string fields to int, float, or bool
    if value.isdigit():
        return int(value)
    if "." in str(value) and all(part.isdigit() for part in value.split(".", 1)):
        return float(value)
    if value.lower() == "true":
        return True
    if value.lower() == "false":
        return False
    return value


class PromptParser:
    class Extensions(Enum):
        TEXT = ".txt"
        CSV = ".csv"
        JSON = ".json"
        JSONL = ".jsonl"

    def __init__(self, filename: str):
        self.extension: self.Extensions = self._validate_and_return_extension(filename)
        self.filename: str = filename

    def parse_as_iterable(self, **kwargs) -> Iterator:
        if self.extension == self.Extensions.TEXT:
            return self._parse_text(**kwargs)
        if self.extension == self.Extensions.CSV:
            return self._parse_csv(**kwargs)
        if self.extension == self.Extensions.JSON:
            return self._parse_json_list(**kwargs)
        if self.extension == self.Extensions.JSONL:
            return self._parse_jsonl(**kwargs)

        raise UnableToParseExtensionException(
            f"Parser for {self.extension} does not exist"
        )

    def _parse_text(self, **kwargs):
        # one prompt per line
        with open(self.filename, "r") as file:
            for line in file:
                kwargs["prompt"] = line.strip()
                yield kwargs

    def _parse_csv(self, **kwargs):
        with open(self.filename, "r", newline="", encoding="utf-8-sig") as file:
            reader = csv.DictReader(file)
            # each row's fields update the shared defaults passed in via kwargs
            for row in reader:
                for key, value in row.items():
                    kwargs.update({key: parse_value_to_appropriate_type(value)})
                yield kwargs

    def _parse_json_list(self, **kwargs):
        with open(self.filename, "r") as file:
            json_list = json.load(file)
            for json_object in json_list:
                kwargs.update(json_object)
                yield kwargs

    def _parse_jsonl(self, **kwargs):
        with open(self.filename, "r") as file:
            for jsonl in file:
                jsonl_object = json.loads(jsonl)
                kwargs.update(jsonl_object)
                yield kwargs

    def _validate_and_return_extension(self, filename: str):
        if os.path.exists(filename):
            for extension in self.Extensions:
                if filename.endswith(extension.value):
                    return extension

            raise InvalidPromptSourceDirectoryException(
                f"{filename} is not compatible. Select a file with an "
                f"extension from {[key.name for key in self.Extensions]}"
            )
        raise FileNotFoundError(f"{filename} does not exist")
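A standalone usage sketch for the parser, assuming the module path above; sample.csv is hypothetical:

# sketch: iterate a small CSV through PromptParser; values are type-coerced per row
from deepsparse.transformers.inference.prompt_parser import PromptParser

with open("sample.csv", "w") as file:
    file.write("prompt,sampling_temperature\n")
    file.write("Hello world,0.7\n")

parser = PromptParser("sample.csv")
for prompt_kwargs in parser.parse_as_iterable(show_tokens_per_sec=False):
    print(prompt_kwargs)
    # {'show_tokens_per_sec': False, 'prompt': 'Hello world', 'sampling_temperature': 0.7}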