Perplexity Eval for Text Generation Models (#1073)
* initial commit

* Update src/deepsparse/license.py

* limit to 150mb

* ready to review

* initial commit

* [Codegen][ORT][Static Seq Length] TextGenerationPipeline (#946)

* initial commit

* Corey's simplifications

* finishing the second model static

* ready, time for beautification

* ready for review

* moved the code to examples

* fix eos logic

* add argument num_tokens_to_generate

* [CodeGen][Documentation] (#956)

* initial commit

* Corey's simplifications

* finishing the second model static

* ready, time for beautification

* ready for review

* moved the code to examples

* fix eos logic

* add argument num_tokens_to_generate

* initial commit

* change order

* Update examples/codegen/README.md

Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>

---------

Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>

* reimplementation for generative pipelines

* restore text generation from examples

* [CodeGen] ONNX model loading to support >2Gb models / two engines (#991)

* refactor successful

* Pipeline fully refactored, time to test engine support. Note: Sliding window not yet implemented!

* First iteration with Sage

* Apply suggestions from code review

* ORT agrees with the Engine. But they both give a not entirely correct result. Hey, this is still good news

* dynamic ORT vs static DS

* pipeline handles OPT multitoken pass

* fixes to get static pipeline a little further along

* adjust shapes and slicing to enable static autoregressive pass - ISSUE: tokens past the base seq len are repeated

* migrate from cache_length to positions input

* got it working for multitoken + single-token scenario

* cleanup the pipeline

* further cleanup post merge

* Pipeline working for single-token inference only

* do not load the onnx model with external files twice

* pipeline never redundantly saves the external data + more robust tokenizer

* Stop saving tmp files, otherwise the engine looks for external files in the wrong place

* Left pad support

* cleanup

* cleanup2

* Add in pipeline timing

* add in force tokens logic

* remove input validation for text generation pipelines

* remove multitoken support for now

* remove kv cache engine and other fixes

* nest input shape override

* comment out input shape override

* add non batch override for ORT

* clean up generation pipeline

* initial commit

* Update src/deepsparse/license.py

* limit to 150mb

* ready to review

* fix the erroneous Makefile

* perhaps fixed GHA

* take into consideration that GHA creates four files

* initial commit

* tested with actual model

* remove val_inp argument

* Update README.md

* Apply suggestions from code review

* Update README.md

* [BugFix] Update deepsparse dockerfile (#1069)

* Remove autoinstall triggering commands

* Fix typo

* initial implementation

* working implementation for pipeline input

* [Fix] Fix CLI benchmark errors (#1071)

* initial commit

* ready for review

* Update src/deepsparse/utils/onnx.py

* Clean a typo in the pipeline code

* cleanup the old files

* Update src/deepsparse/transformers/engines/nl_decoder_engine.py

* ready for review

* ready for testing

* assert proper padding on pipeline init

* now also supporting kv cache perplexity. time for cleanup

* ready for review

* correctly print engine info

* work with left padding of the tokenizer

* quality

* fix the multitoken inference

---------

Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>
Co-authored-by: Mark Kurtz <mark.kurtz@neuralmagic.com>
Co-authored-by: Benjamin <ben@neuralmagic.com>
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
5 people committed Jul 5, 2023
1 parent 0809aea commit 10c804a
Showing 4 changed files with 193 additions and 30 deletions.
8 changes: 4 additions & 4 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -154,10 +154,7 @@ def __call__(
        else:
            logits = out[0]

        B, S, V = logits.shape  # batch, sequence, vocab
        logits = logits[:, -1, :].reshape(B, 1, V)  # only take the last token

        token = self.generate_token(logits=logits)
        token = self.generate_token(logits=logits[:, -1, :])

        return token, logits

@@ -253,6 +250,9 @@ def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:

        return numpy.random.choice(len(probs), p=probs)

    def __str__(self):
        return f"{self.__class__.__name__}: {self.engine}"

    def _initialize_kv_cache_state(self, length: int) -> Dict[str, numpy.ndarray]:
        # initialize empty kv cache of size
        # (batch_size, num_attention_heads, length, hidden_dims)
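In short, the `__call__` change slices out only the final position's logits before sampling, instead of reshaping the full (batch, sequence, vocab) tensor first, and the new `__str__` is what lets the perplexity eval print readable engine info. A minimal, self-contained sketch of that last-token slice-and-sample pattern (toy shapes, plain numpy; an illustration, not the engine code itself):

import numpy

batch, seq, vocab = 2, 5, 11  # toy dimensions, not tied to any real model
logits = numpy.random.rand(batch, seq, vocab)

# keep only the final position's logits, shape (batch, vocab)
last_token_logits = logits[:, -1, :]

# softmax + sample for the first sequence, mirroring generate_token's numpy.random.choice call
probs = numpy.exp(last_token_logits[0])
probs /= probs.sum()
next_token = numpy.random.choice(len(probs), p=probs)
print(next_token)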
42 changes: 37 additions & 5 deletions src/deepsparse/transformers/eval_downstream.py
@@ -68,14 +68,43 @@
import numpy
from tqdm.auto import tqdm

from deepsparse import Pipeline
from deepsparse.transformers.metrics import PrecisionRecallF1
from deepsparse import DEEPSPARSE_ENGINE, ORT_ENGINE, Pipeline
from deepsparse.transformers.metrics import Perplexity, PrecisionRecallF1


from datasets import load_dataset, load_metric # isort: skip

DEEPSPARSE_ENGINE = "deepsparse"
ORT_ENGINE = "onnxruntime"

def perplexity_eval(args, batch_size=16, dataset_name="openai_humaneval"):
    dataset = load_dataset(dataset_name)["test"]

    text_generation = Pipeline.create(
        task="text-generation",
        model_path=args.model_path,
        engine_type=args.engine,
        num_cores=args.num_cores,
        sequence_length=args.max_sequence_length,
        prompt_processing_sequence_length=args.max_sequence_length,
        max_generated_tokens=1,
        remove_special_tokens_from_prompt=False,
    )
    perplexity_metrics = Perplexity(pipeline=text_generation, batch_size=batch_size)
    active_engines = [
        engine
        for engine in [text_generation.engine, text_generation.multitoken_engine]
        if engine
    ]
    print("Engine info: ")
    [print(f"{engine}\n") for engine in active_engines]
    predictions = []
    for idx, sample in _enumerate_progress(dataset, args.max_samples):
        predictions.append(sample["prompt"] + sample["canonical_solution"])
        if len(predictions) == batch_size:
            perplexity_metrics.add_batch(predictions)
            predictions = []
        if args.max_samples and idx >= args.max_samples:
            break
    return perplexity_metrics
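# Illustrative only (not part of this diff): a hedged sketch of driving perplexity_eval
# programmatically. The Namespace below is hypothetical; it simply mirrors the attributes
# the function reads above (model_path, engine, num_cores, max_sequence_length, max_samples)
# with placeholder values.
from argparse import Namespace

args = Namespace(
    model_path="/path/to/text-generation/deployment",  # placeholder path
    engine=DEEPSPARSE_ENGINE,
    num_cores=None,
    max_sequence_length=256,
    max_samples=128,
)
metrics = perplexity_eval(args, batch_size=16)
print(metrics.compute()["mean_perplexity"])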


def qa_eval(args, dataset_name="squad"):
@@ -443,11 +472,14 @@ def _split_train_val(train_dataset, val_ratio, seed=42):
    "imdb": imdb_eval,
    "conll2003": conll2003_eval,
    "go_emotions": go_emotions_eval,
    "openai_humaneval": perplexity_eval,
}
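# Illustrative only (not part of this diff): registering perplexity_eval under the
# "openai_humaneval" key makes it selectable via --dataset. Presumably the script then
# dispatches through this mapping roughly as below; the actual call site is outside this hunk.
args = parse_args()
metrics = SUPPORTED_DATASETS[args.dataset](args)
print(metrics.compute())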


def parse_args():
    parser = argparse.ArgumentParser(
        # TODO: It is not BERT anymore, should we
        # have another script or modify the existing one?
        description="Evaluate a BERT ONNX model on a downstream dataset"
    )
    parser.add_argument(
@@ -461,9 +493,9 @@ def parse_args():
    parser.add_argument(
        "-d",
        "--dataset",
        type=str,
        choices=list(SUPPORTED_DATASETS.keys()),
        required=True,
        type=str,
    )
    parser.add_argument(
        "-v",
103 changes: 102 additions & 1 deletion src/deepsparse/transformers/metrics.py
@@ -17,18 +17,119 @@
"""


from typing import Dict, Optional
from typing import Any, Dict, List, Optional

import numpy
from tqdm import tqdm

import torch
from deepsparse import Pipeline
from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline
from sklearn.metrics import precision_recall_fscore_support


__all__ = [
"PrecisionRecallF1",
"Perplexity",
]


class Perplexity:
    def __init__(self, pipeline: Pipeline, batch_size: int = 16):
        """
        Given the pipeline, compute the perplexity of the model
        on the given text input.
        Code adapted from:
        https://huggingface.co/spaces/evaluate-metric/perplexity/blob/main/perplexity.py # noqa: E501
        :param pipeline: The pipeline to use for text generation
        :param batch_size: The batch size to split the input text into
            non-overlapping batches
        """
        if not isinstance(pipeline, TextGenerationPipeline):
            raise ValueError(
                "Perplexity can only be computed for text generation pipelines"
            )
        self._pipeline = pipeline
        self._batch_size = batch_size
        self._sequence_length = pipeline.sequence_length
        self._loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

        self.perplexities = []

    def add_batch(self, predictions: List[str]):
        """
        Run the model on the given input sequences and compute the perplexity.
        The resulting perplexity is appended to the list of perplexities.
        :param predictions: The predictions to compute perplexity on
        """
        # tokenize the input text
        encodings = self._pipeline.tokenizer(
            predictions,
            return_attention_mask=True,
            max_length=self._sequence_length,
            truncation=True,
            padding="max_length",
        )

        encoded_texts = encodings["input_ids"]
        attention_masks = encodings["attention_mask"]

        for start_index in tqdm(range(0, len(encoded_texts), self._batch_size)):
            end_index = min(start_index + self._batch_size, len(encoded_texts))
            encoded_batch = encoded_texts[start_index:end_index]
            attention_mask = attention_masks[start_index:end_index]

            out = self._pipeline(
                sequences=predictions, return_logits=True, truncate=True
            )
            logits = out.logits

            labels = encoded_batch
            labels = numpy.stack(labels)
            attention_mask = numpy.stack(attention_mask)

            # because the tokenizer is left padded, we need to move the meaningful
            # part of the logits and labels to the right
            num_padded_entries = attention_mask.sum(axis=1)

            # shift the values at num_paddings to the top of the array using roll
            for i, num_padded in enumerate(num_padded_entries):
                logits[i] = numpy.roll(logits[i], num_padded, axis=0)
                labels[i] = numpy.roll(labels[i], num_padded, axis=0)
                attention_mask[i] = numpy.roll(attention_mask[i], num_padded, axis=0)

            # shift logits and labels to create the input and target for the loss function
            shift_logits = logits[:, :-1, :]
            shift_labels = labels[:, 1:]
            shift_attention_mask_batch = attention_mask[:, 1:]

            # compute perplexity for this batch
            perplexity_batch = torch.exp(
                (
                    self._loss_fct(
                        torch.tensor(shift_logits.transpose(0, 2, 1)),
                        torch.tensor(shift_labels),
                    )
                    * torch.tensor(shift_attention_mask_batch)
                ).sum(1)
                / torch.tensor(shift_attention_mask_batch).sum(1)
            )
            self.perplexities.extend(perplexity_batch.numpy().tolist())

    def compute(self) -> Dict[str, Any]:
        """
        :return: A dictionary containing the mean perplexity
            and the list of perplexities
        """
        return {
            "mean_perplexity": numpy.mean(self.perplexities),
            "perplexities": self.perplexities,
        }


class PrecisionRecallF1:
    def __init__(self, id_to_label: Optional[Dict[int, str]] = None):
        self._id_to_label = id_to_label
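Taken together, `add_batch` records one perplexity per sequence as the exponential of the attention-masked mean cross-entropy over the shifted (next-token) positions, and `compute` reports their mean. A standalone toy check of that formula, using synthetic logits and labels and no pipeline (shapes and values are illustrative only):

import numpy
import torch

batch, seq, vocab = 2, 6, 13  # toy shapes
logits = numpy.random.rand(batch, seq, vocab).astype(numpy.float32)
labels = numpy.random.randint(0, vocab, size=(batch, seq))
mask = numpy.ones((batch, seq), dtype=numpy.int64)  # pretend nothing is padded

# next-token prediction: the logits at position t are scored against the token at t + 1
shift_logits = logits[:, :-1, :]
shift_labels = labels[:, 1:]
shift_mask = mask[:, 1:]

loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
ce = loss_fct(
    torch.tensor(shift_logits.transpose(0, 2, 1)),  # (batch, vocab, seq - 1)
    torch.tensor(shift_labels),
)
ppl = torch.exp((ce * torch.tensor(shift_mask)).sum(1) / torch.tensor(shift_mask).sum(1))
print(ppl)  # one perplexity value per sequence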
