Skip to content

Commit

Permalink
Inference UX, accept input data (#1285)
Browse files Browse the repository at this point in the history
* [deepsparse.infer] UX improvements, data only mode

* fix bug on main

* draft, load files line by line, return iter, save up memory

* add inference

* pass passing in files

* latest changes'

* revert

* make new folder for inderence

* allow input to pass thru cli

* Update src/deepsparse/transformers/inference/infer.py

Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>

* remove hardcoded

* better error message

* clean up

* clean up, check kwargs

* get rid of breakpoint()

* return type

* typo

---------

Co-authored-by: Benjamin <ben@neuralmagic.com>
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
  • Loading branch information
3 people committed Oct 2, 2023
1 parent 8d103b0 commit 5e425c9
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 23 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def _setup_entry_points() -> Dict:
"console_scripts": [
f"deepsparse.transformers.run_inference={data_api_entrypoint}",
f"deepsparse.transformers.eval_downstream={eval_downstream}",
"deepsparse.infer=deepsparse.transformers.infer:main",
"deepsparse.infer=deepsparse.transformers.inference.infer:main",
"deepsparse.debug_analysis=deepsparse.debug_analysis:main",
"deepsparse.analyze=deepsparse.analyze:main",
"deepsparse.check_hardware=deepsparse.cpu:print_hardware_capability",
Expand Down
13 changes: 13 additions & 0 deletions src/deepsparse/transformers/inference/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,14 @@
deepsparse.infer models/llama/deployment \
--task text-generation
"""

from typing import Optional

import click

from deepsparse import Pipeline
from deepsparse.tasks import SupportedTasks
from deepsparse.transformers.inference.prompt_parser import PromptParser


@click.command(
Expand All @@ -75,6 +79,14 @@
)
)
@click.argument("model_path", type=str)
@click.option(
"--data",
type=str,
default=None,
help="Path to .txt, .csv, .json, or .jsonl file to load data from"
"If provided, runs inference over the entire dataset. If not provided "
"runs an interactive inference session in the console. Default None.",
)
@click.option(
"--sequence_length",
type=int,
Expand Down Expand Up @@ -112,6 +124,7 @@
)
def main(
model_path: str,
data: Optional[str],
sequence_length: int,
sampling_temperature: float,
prompt_sequence_length: int,
Expand All @@ -128,34 +141,76 @@ def main(
session_ids = "chatbot_cli_session"

pipeline = Pipeline.create(
task=task, # let pipeline determine if task is supported
task=task, # let the pipeline determine if task is supported
model_path=model_path,
sequence_length=sequence_length,
sampling_temperature=sampling_temperature,
prompt_sequence_length=prompt_sequence_length,
)

# continue prompts until a keyboard interrupt
while True:
input_text = input("User: ")
pipeline_inputs = {"prompt": [input_text]}

if SupportedTasks.is_chat(task):
pipeline_inputs["session_ids"] = session_ids

response = pipeline(**pipeline_inputs)
print("Bot: ", response.generations[0].text)
if show_tokens_per_sec:
times = pipeline.timer_manager.times
prefill_speed = (
1.0 * prompt_sequence_length / times["engine_prompt_prefill_single"]
)
generation_speed = 1.0 / times["engine_token_generation_single"]
print(
f"[prefill: {prefill_speed:.2f} tokens/sec]",
f"[decode: {generation_speed:.2f} tokens/sec]",
sep="\n",
if data:
prompt_parser = PromptParser(data)
default_prompt_kwargs = {
"sequence_length": sequence_length,
"sampling_temperature": sampling_temperature,
"prompt_sequence_length": prompt_sequence_length,
"show_tokens_per_sec": show_tokens_per_sec,
}

for prompt_kwargs in prompt_parser.parse_as_iterable(**default_prompt_kwargs):
_run_inference(
task=task,
pipeline=pipeline,
session_ids=session_ids,
**prompt_kwargs,
)
return

# continue prompts until a keyboard interrupt
while data is None: # always True in interactive Mode
prompt = input(">>> ")
_run_inference(
pipeline,
sampling_temperature,
task,
session_ids,
show_tokens_per_sec,
prompt_sequence_length,
prompt,
)


def _run_inference(
pipeline,
sampling_temperature,
task,
session_ids,
show_tokens_per_sec,
prompt_sequence_length,
prompt,
**kwargs,
):
pipeline_inputs = dict(
prompt=[prompt],
temperature=sampling_temperature,
**kwargs,
)
if SupportedTasks.is_chat(task):
pipeline_inputs["session_ids"] = session_ids

response = pipeline(**pipeline_inputs)
print("\n", response.generations[0].text)

if show_tokens_per_sec:
times = pipeline.timer_manager.times
prefill_speed = (
1.0 * prompt_sequence_length / times["engine_prompt_prefill_single"]
)
generation_speed = 1.0 / times["engine_token_generation_single"]
print(
f"[prefill: {prefill_speed:.2f} tokens/sec]",
f"[decode: {generation_speed:.2f} tokens/sec]",
sep="\n",
)


if __name__ == "__main__":
Expand Down
108 changes: 108 additions & 0 deletions src/deepsparse/transformers/inference/prompt_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import csv
import json
import os
from enum import Enum
from typing import Iterator


class InvalidPromptSourceDirectoryException(Exception):
pass


class UnableToParseExtentsonException(Exception):
pass


def parse_value_to_appropriate_type(value: str):
if value.isdigit():
return int(value)
if "." in str(value) and all(part.isdigit() for part in value.split(".", 1)):
return float(value)
if value.lower() == "true":
return True
if value.lower() == "false":
return False
return value


class PromptParser:
class Extensions(Enum):
TEXT = ".txt"
CSV = ".csv"
JSON = ".json"
JSONL = ".jsonl"

def __init__(self, filename: str):
self.extention: self.Extensions = self._validate_and_return_extention(filename)
self.filename: str = filename

def parse_as_iterable(self, **kwargs) -> Iterator:
if self.extention == self.Extensions.TEXT:
return self._parse_text(**kwargs)
if self.extention == self.Extensions.CSV:
return self._parse_csv(**kwargs)
if self.extention == self.Extensions.JSON:
return self._parse_json_list(**kwargs)
if self.extention == self.Extensions.JSONL:
return self._parse_jsonl(**kwargs)

raise UnableToParseExtentsonException(
f"Parser for {self.extention} does not exist"
)

def _parse_text(self, **kwargs):
with open(self.filename, "r") as file:
for line in file:
kwargs["prompt"] = line.strip()
yield kwargs

def _parse_csv(self, **kwargs):
with open(self.filename, "r", newline="", encoding="utf-8-sig") as file:
reader = csv.DictReader(file)
for row in reader:
for key, value in row.items():
kwargs.update({key: parse_value_to_appropriate_type(value)})
yield kwargs

def _parse_json_list(self, **kwargs):
with open(self.filename, "r") as file:
json_list = json.load(file)
for json_object in json_list:
kwargs.update(json_object)
yield kwargs

def _parse_jsonl(self, **kwargs):
with open(self.filename, "r") as file:
for jsonl in file:
jsonl_object = json.loads(jsonl)
kwargs.update(jsonl_object)
yield kwargs

def _validate_and_return_extention(self, filename: str):
if os.path.exists(filename):

for extention in self.Extensions:
if filename.endswith(extention.value):
return extention

raise InvalidPromptSourceDirectoryException(
f"{filename} is not compatible. Select file that has "
"extension from "
f"{[key.name for key in self.Extensions]}"
)
raise FileNotFoundError

0 comments on commit 5e425c9

Please sign in to comment.