Inference UX, accept input data #1285

Merged: 24 commits into main from infer-ux-fixes, Oct 2, 2023
Changes from 10 commits

Commits (24)
6c8dfb3
[deepsparse.infer] UX improvements, data only mode
bfineran Sep 25, 2023
b4b7ec6
fix bug on main
bfineran Sep 25, 2023
17168f6
draft, load files line by line, return iter, save up memory
horheynm Sep 26, 2023
84b03f8
add inference
horheynm Sep 26, 2023
8a69c5f
pass passing in files
horheynm Sep 26, 2023
4091503
latest changes'
horheynm Sep 26, 2023
b0f65af
revert
horheynm Sep 26, 2023
8ee765b
make new folder for inderence
horheynm Sep 26, 2023
8a47e01
allow input to pass thru cli
horheynm Sep 26, 2023
6957067
Merge branch 'main' into infer-ux-fixes
horheynm Sep 26, 2023
b429917
Update src/deepsparse/transformers/inference/infer.py
horheynm Sep 27, 2023
1dc2ee3
remove hardcoded
horheynm Sep 27, 2023
86a2daf
better error message
horheynm Sep 27, 2023
939c6bc
clean up
horheynm Sep 27, 2023
274570b
Merge branch 'main' into infer-ux-fixes
horheynm Sep 27, 2023
7b1edfa
clean up, check kwargs
horheynm Sep 27, 2023
064013a
Merge branch 'infer-ux-fixes' of github.com:neuralmagic/deepsparse in…
horheynm Sep 27, 2023
8cab34f
Merge branch 'main' into infer-ux-fixes
horheynm Sep 27, 2023
4699837
get rid of breakpoint()
horheynm Sep 27, 2023
218b584
Merge branch 'infer-ux-fixes' of github.com:neuralmagic/deepsparse in…
horheynm Sep 27, 2023
ff4b48f
return type
horheynm Sep 27, 2023
8243740
Merge branch 'main' into infer-ux-fixes
horheynm Oct 2, 2023
2a4b972
typo
horheynm Oct 2, 2023
ad9e96d
Merge branch 'infer-ux-fixes' of github.com:neuralmagic/deepsparse in…
horheynm Oct 2, 2023
2 changes: 1 addition & 1 deletion setup.py
@@ -298,7 +298,7 @@ def _setup_entry_points() -> Dict:
"console_scripts": [
f"deepsparse.transformers.run_inference={data_api_entrypoint}",
f"deepsparse.transformers.eval_downstream={eval_downstream}",
"deepsparse.infer=deepsparse.transformers.infer:main",
"deepsparse.infer=deepsparse.transformers.inference.infer:main",
"deepsparse.debug_analysis=deepsparse.debug_analysis:main",
"deepsparse.analyze=deepsparse.analyze:main",
"deepsparse.check_hardware=deepsparse.cpu:print_hardware_capability",
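With this one-line change, the deepsparse.infer console script resolves to the relocated module. Since infer.py keeps its __main__ guard (see the end of that file below), invoking the module directly should be equivalent; this is an assumption based on the guard, not something shown in this PR:

python -m deepsparse.transformers.inference.infer models/llama/deployment --task text-generation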
13 changes: 13 additions & 0 deletions src/deepsparse/transformers/inference/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
src/deepsparse/transformers/inference/infer.py
@@ -63,10 +63,14 @@
    deepsparse.infer models/llama/deployment \
        --task text-generation
"""

from typing import Iterator, Optional

import click

from deepsparse import Pipeline
from deepsparse.tasks import SupportedTasks
from deepsparse.transformers.inference.prompt_parser import PromptParser


@click.command(
@@ -75,6 +79,14 @@
    )
)
@click.argument("model_path", type=str)
@click.option(
    "--data",
    type=str,
    default=None,
    help="Path to a .txt, .csv, .json, or .jsonl file to load data from. "
    "If provided, runs inference over the entire dataset; if not provided, "
    "runs an interactive inference session in the console. Default None.",
)
@click.option(
    "--sequence_length",
    type=int,
@@ -112,6 +124,7 @@
)
def main(
    model_path: str,
    data: Optional[str],
    sequence_length: int,
    sampling_temperature: float,
    prompt_sequence_length: int,
@@ -131,31 +144,75 @@ def main(
        task=task,  # let pipeline determine if task is supported
        model_path=model_path,
        sequence_length=sequence_length,
        sampling_temperature=sampling_temperature,
        prompt_sequence_length=prompt_sequence_length,
    )

    # continue prompts until a keyboard interrupt
    while True:
        input_text = input("User: ")
        pipeline_inputs = {"prompt": [input_text]}

        if SupportedTasks.is_chat(task):
            pipeline_inputs["session_ids"] = session_ids

        response = pipeline(**pipeline_inputs)
        print("Bot: ", response.generations[0].text)
        if show_tokens_per_sec:
            times = pipeline.timer_manager.times
            prefill_speed = (
                1.0 * prompt_sequence_length / times["engine_prompt_prefill_single"]
            )
            generation_speed = 1.0 / times["engine_token_generation_single"]
            print(
                f"[prefill: {prefill_speed:.2f} tokens/sec]",
                f"[decode: {generation_speed:.2f} tokens/sec]",
                sep="\n",
    if data:
        for prompt, prompt_kwargs in _iter_prompt_from_file(data):
            # pass the per-prompt kwargs parsed from the file row alongside the prompt
            _run_inference(
                pipeline,
                sampling_temperature,
                task,
                session_ids,
                show_tokens_per_sec,
                prompt_sequence_length,
                prompt,
                **prompt_kwargs,
            )
        return

    # continue prompts until a keyboard interrupt
    while True:  # data is None here, so run an interactive session
        prompt_input = input(">>> ")
        _run_inference(
            pipeline,
            sampling_temperature,
            task,
            session_ids,
            show_tokens_per_sec,
            prompt_sequence_length,
            prompt_input,
        )


def _iter_prompt_from_file(data: str) -> Iterator:
    parser = PromptParser(data)
    return parser.parse_as_iterable()


def _run_inference(
    pipeline,
    sampling_temperature,
    task,
    session_ids,
    show_tokens_per_sec,
    prompt_sequence_length,
    prompt,
    **kwargs,
):
    pipeline_inputs = dict(
        prompt=[prompt],
        sampling_temperature=sampling_temperature,
        # **kwargs,  # extra per-prompt kwargs are accepted but not yet forwarded
    )
    if SupportedTasks.is_chat(task):
        pipeline_inputs["session_ids"] = session_ids

    response = pipeline(**pipeline_inputs)
    print("\n", response.generations[0].text)

    if show_tokens_per_sec:
        times = pipeline.timer_manager.times
        prefill_speed = (
            1.0 * prompt_sequence_length / times["engine_prompt_prefill_single"]
        )
        generation_speed = 1.0 / times["engine_token_generation_single"]
        print(
            f"[prefill: {prefill_speed:.2f} tokens/sec]",
            f"[decode: {generation_speed:.2f} tokens/sec]",
            sep="\n",
        )


if __name__ == "__main__":
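Taken together, the new --data flow lets batch inference run straight from a file. A hypothetical invocation based on this diff (prompts.txt is an assumed local file with one prompt per line, not part of the PR):

deepsparse.infer models/llama/deployment \
    --task text-generation \
    --data prompts.txt

Omitting --data preserves the interactive console session, as before.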
81 changes: 81 additions & 0 deletions src/deepsparse/transformers/inference/prompt_parser.py
@@ -0,0 +1,81 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import csv
import json
import os
from enum import Enum


class InvalidPromptSourceDirectoryException(Exception):
    pass


class PromptParser:
    class Extensions(Enum):
        TEXT = ".txt"
        CSV = ".csv"
        JSON = ".json"
        JSONL = ".jsonl"

    def __init__(self, filename: str):
        self.extension = self._validate_and_return_extension(filename)
        self.filename: str = filename

    def parse_as_iterable(self):
        if self.extension == self.Extensions.TEXT:
            return self._parse_text()
        if self.extension == self.Extensions.CSV:
            return self._parse_csv()
        if self.extension == self.Extensions.JSON:
            return self._parse_json_list()
        if self.extension == self.Extensions.JSONL:
            return self._parse_jsonl()

    def _parse_text(self):
        with open(self.filename, "r") as file:
            for line in file:
                yield line.strip(), {}

    def _parse_csv(self):
        with open(self.filename, "r", newline="", encoding="utf-8-sig") as file:
            reader = csv.DictReader(file)
            for row in reader:
                yield row.get("prompt"), row

    def _parse_json_list(self):
        with open(self.filename, "r") as file:
            json_list = json.load(file)
            for json_object in json_list:
                yield json_object.get("prompt"), json_object

    def _parse_jsonl(self):
        with open(self.filename, "r") as file:
            for jsonl in file:
                jsonl_object = json.loads(jsonl)
                yield jsonl_object.get("prompt"), jsonl_object

    def _validate_and_return_extension(self, filename: str):
        if os.path.exists(filename):
            for extension in self.Extensions:
                if filename.endswith(extension.value):
                    return extension

            raise InvalidPromptSourceDirectoryException(
                f"{filename} is not a valid source to extract batched prompts from"
            )
        raise FileNotFoundError(filename)
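To illustrate the (prompt, kwargs) contract the parser yields, a minimal sketch; the prompts.jsonl file and its contents are hypothetical, not part of this PR:

from deepsparse.transformers.inference.prompt_parser import PromptParser

# each JSONL line is an object such as {"prompt": "What is sparsity?"}
parser = PromptParser("prompts.jsonl")
for prompt, prompt_kwargs in parser.parse_as_iterable():
    # prompt comes from the "prompt" key; prompt_kwargs is the full row/object
    print(prompt, prompt_kwargs)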