transformers eval support for conll2003 (#504) (#506)
bfineran committed Jul 12, 2022
1 parent 13d966f commit df7275f
Showing 2 changed files with 91 additions and 3 deletions.
54 changes: 54 additions & 0 deletions src/deepsparse/transformers/eval_downstream.py
@@ -225,6 +225,59 @@ def sst2_eval(args):
return sst2_metrics


def conll2003_eval(args):
    # load conll2003 validation dataset and eval tool
conll2003 = load_dataset("conll2003")["validation"]
conll2003_metrics = load_metric("seqeval")

# load pipeline
token_classify = Pipeline.create(
task="token-classification",
model_path=args.onnx_filepath,
engine_type=args.engine,
num_cores=args.num_cores,
sequence_length=args.max_sequence_length,
)
print(f"Engine info: {token_classify.engine}")

ner_tag_map = {
"O": 0,
"B-PER": 1,
"I-PER": 2,
"B-ORG": 3,
"I-ORG": 4,
"B-LOC": 5,
"I-LOC": 6,
"B-MISC": 7,
"I-MISC": 8,
}
    # map both raw label ids and pipeline config label names to NER tags
label_map = {label_id: ner_tag for ner_tag, label_id in ner_tag_map.items()}
label_map.update(
{
token_classify.config.id2label[label_id]: tag
for tag, label_id in ner_tag_map.items()
}
)

for idx, sample in _enumerate_progress(conll2003, args.max_samples):
if not sample["tokens"]:
continue # invalid dataset item, no tokens
pred = token_classify(inputs=sample["tokens"], is_split_into_words=True)
pred_ids = [label_map[prediction.entity] for prediction in pred.predictions[0]]
label_ids = [label_map[ner_tag] for ner_tag in sample["ner_tags"]]

conll2003_metrics.add_batch(
predictions=[pred_ids],
references=[label_ids],
)

if args.max_samples and idx >= args.max_samples:
break

return conll2003_metrics


def _enumerate_progress(dataset, max_steps):
progress_bar = tqdm(dataset, total=max_steps) if max_steps else tqdm(dataset)
return enumerate(progress_bar)
@@ -242,6 +295,7 @@ def _get_label2id(config_file_path):
"mnli": mnli_eval,
"qqp": qqp_eval,
"sst2": sst2_eval,
"conll2003": conll2003_eval,
}


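For context on how the new entry is exercised: the function above returns the Hugging Face metric object without reducing it, so a caller finishes the evaluation with compute(). A minimal end-to-end sketch of the same flow, assuming an ONNX model fine-tuned on CoNLL-2003 whose config labels are already the CoNLL tag strings (the model path and sample count below are illustrative, not part of this commit):

from datasets import load_dataset, load_metric
from deepsparse import Pipeline

# illustrative path -- any token-classification ONNX model fine-tuned on CoNLL-2003
onnx_path = "./bert-ner-conll2003/model.onnx"

ner_tags = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]

conll2003 = load_dataset("conll2003")["validation"]
metric = load_metric("seqeval")
pipeline = Pipeline.create(task="token-classification", model_path=onnx_path)

for sample in conll2003.select(range(8)):  # small slice for illustration
    if not sample["tokens"]:
        continue
    pred = pipeline(inputs=sample["tokens"], is_split_into_words=True)
    # assumes the model config maps label ids to CoNLL tag strings ("B-PER", ...)
    pred_tags = [p.entity for p in pred.predictions[0]]
    ref_tags = [ner_tags[i] for i in sample["ner_tags"]]
    metric.add_batch(predictions=[pred_tags], references=[ref_tags])

print(metric.compute())  # overall and per-entity precision/recall/F1 from seqeval
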
40 changes: 37 additions & 3 deletions src/deepsparse/transformers/pipelines/token_classification.py
@@ -76,6 +76,14 @@ class TokenClassificationInput(BaseModel):
"a token_classification task"
)
)
is_split_into_words: bool = Field(
default=False,
description=(
"True if the input is a batch size 1 list of strings representing. "
"individual word tokens. Currently only supports batch size 1. "
"Default is False"
),
)


class TokenClassificationResult(BaseModel):
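As a quick illustration of the schema change above, the new flag rides alongside the existing inputs field on the request model; a minimal sketch, assuming the inputs field accepts the pre-split words of a single sentence (batch size 1, as the field description requires):

from deepsparse.transformers.pipelines.token_classification import (
    TokenClassificationInput,
)

# one sentence, already split into word tokens (batch size 1)
request = TokenClassificationInput(
    inputs=["EU", "rejects", "German", "call", "to", "boycott", "British", "lamb", "."],
    is_split_into_words=True,
)
print(request.is_split_into_words)  # True
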
@@ -245,13 +253,17 @@ def process_inputs(
and dictionary containing offset mappings and special tokens mask to
be used during postprocessing
"""
if inputs.is_split_into_words and self.engine.batch_size != 1:
raise ValueError("is_split_into_words=True only supported for batch size 1")

tokens = self.tokenizer(
inputs.inputs,
return_tensors="np",
truncation=TruncationStrategy.LONGEST_FIRST.value,
padding=PaddingStrategy.MAX_LENGTH.value,
return_special_tokens_mask=True,
return_offsets_mapping=self.tokenizer.is_fast,
is_split_into_words=inputs.is_split_into_words,
)

offset_mapping = (
@@ -260,11 +272,29 @@
else [None] * len(inputs.inputs)
)
special_tokens_mask = tokens.pop("special_tokens_mask")

word_start_mask = None
if inputs.is_split_into_words:
            # create a mask over the tokenized words where values are True
            # only for the first sub-token of each split word
word_start_mask = []
word_ids = tokens.word_ids(batch_index=0)
previous_id = None
for word_id in word_ids:
if word_id is None:
continue
if word_id != previous_id:
word_start_mask.append(True)
previous_id = word_id
else:
word_start_mask.append(False)

postprocessing_kwargs = dict(
inputs=inputs,
tokens=tokens,
offset_mapping=offset_mapping,
special_tokens_mask=special_tokens_mask,
word_start_mask=word_start_mask,
)

return self.tokens_to_engine_input(tokens), postprocessing_kwargs
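To see what the mask built above encodes, the same word-id walk can be run standalone against any Hugging Face fast tokenizer; a sketch, using bert-base-uncased purely as an example checkpoint:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

words = ["Nuremberg", "hosted", "the", "match"]
encoding = tokenizer(words, is_split_into_words=True)

word_start_mask = []
previous_id = None
for word_id in encoding.word_ids(batch_index=0):
    if word_id is None:  # special tokens such as [CLS]/[SEP] carry no word id
        continue
    # True only for the first sub-token of each input word
    word_start_mask.append(word_id != previous_id)
    previous_id = word_id

# a word that splits into several sub-tokens contributes one True then Falses,
# e.g. [True, False, True, True, True] depending on the vocabulary
print(word_start_mask)
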
@@ -284,6 +314,7 @@ def process_engine_outputs(
tokens = kwargs["tokens"]
offset_mapping = kwargs["offset_mapping"]
special_tokens_mask = kwargs["special_tokens_mask"]
word_start_mask = kwargs["word_start_mask"]

predictions = [] # type: List[List[TokenClassificationResult]]

@@ -293,6 +324,7 @@
scores = numpy.exp(current_entities) / numpy.exp(current_entities).sum(
-1, keepdims=True
)

pre_entities = self._gather_pre_entities(
inputs.inputs[entities_index],
input_ids,
@@ -303,9 +335,11 @@
grouped_entities = self._aggregate(pre_entities)
# Filter anything that is in self.ignore_labels
current_results = [] # type: List[TokenClassificationResult]
for entity in grouped_entities:
if entity.get("entity") in self.ignore_labels or (
entity.get("entity_group") in self.ignore_labels
for entity_idx, entity in enumerate(grouped_entities):
if (
entity.get("entity") in self.ignore_labels
or (entity.get("entity_group") in self.ignore_labels)
or (word_start_mask and not word_start_mask[entity_idx])
):
continue
if entity.get("entity_group"):
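The filter above is what keeps the pipeline output aligned one-to-one with the input words when is_split_into_words is set: entities for non-initial sub-tokens are dropped along with ignored labels. A standalone sketch of the same selection logic over plain dicts (the entities and mask here are made up for illustration):

from typing import Dict, List, Optional

def filter_entities(
    grouped_entities: List[Dict],
    ignore_labels: List[str],
    word_start_mask: Optional[List[bool]],
) -> List[Dict]:
    kept = []
    for idx, entity in enumerate(grouped_entities):
        if (
            entity.get("entity") in ignore_labels
            or entity.get("entity_group") in ignore_labels
            or (word_start_mask and not word_start_mask[idx])
        ):
            continue
        kept.append(entity)
    return kept

# toy data: the second entry is a continuation sub-token of the first word
entities = [
    {"entity": "B-MISC", "word": "valk"},
    {"entity": "I-MISC", "word": "##yria"},
    {"entity": "O", "word": "is"},
]
print(filter_entities(entities, ignore_labels=["O"], word_start_mask=[True, False, True]))
# -> [{'entity': 'B-MISC', 'word': 'valk'}]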
