Skip to content

Commit

Permalink
Update label mapping for deepsparse.transformers.eval_downstream (#323)…
Browse files Browse the repository at this point in the history
… (#324)

* Update label mapping for deepsparse.transformers.eval_downstream

* Fix MNLI as well
  • Loading branch information
mgoin committed Apr 18, 2022
1 parent 500d132 commit 01a427a
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions src/deepsparse/transformers/eval_downstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,20 +123,22 @@ def mnli_eval(args):
)
print(f"Engine info: {text_classify.model}")

label_map = {"entailment": 0, "neutral": 1, "contradiction": 2}

for idx, sample in enumerate(tqdm(mnli_matched)):
pred = text_classify(sample["premise"], sample["hypothesis"])
pred = text_classify([[sample["premise"], sample["hypothesis"]]])
mnli_metrics.add_batch(
predictions=[int(pred[0]["label"].split("_")[-1])],
predictions=[label_map.get(pred[0]["label"])],
references=[sample["label"]],
)

if args.max_samples and idx > args.max_samples:
break

for idx, sample in enumerate(tqdm(mnli_mismatched)):
pred = text_classify(sample["premise"], sample["hypothesis"])
pred = text_classify([[sample["premise"], sample["hypothesis"]]])
mnli_metrics.add_batch(
predictions=[int(pred[0]["label"].split("_")[-1])],
predictions=[label_map.get(pred[0]["label"])],
references=[sample["label"]],
)

Expand All @@ -161,11 +163,13 @@ def qqp_eval(args):
)
print(f"Engine info: {text_classify.model}")

label_map = {"not_duplicate": 0, "duplicate": 1}

for idx, sample in enumerate(tqdm(qqp)):
pred = text_classify([[sample["question1"], sample["question2"]]])

qqp_metrics.add_batch(
predictions=[int(pred[0]["label"].split("_")[-1])],
predictions=[label_map.get(pred[0]["label"])],
references=[sample["label"]],
)

Expand All @@ -190,13 +194,15 @@ def sst2_eval(args):
)
print(f"Engine info: {text_classify.model}")

label_map = {"negative": 0, "positive": 1}

for idx, sample in enumerate(tqdm(sst2)):
pred = text_classify(
sample["sentence"],
)

sst2_metrics.add_batch(
predictions=[int(pred[0]["label"].split("_")[-1])],
predictions=[label_map.get(pred[0]["label"])],
references=[sample["label"]],
)

Expand Down Expand Up @@ -229,6 +235,7 @@ def parse_args():
"--dataset",
type=str,
choices=list(SUPPORTED_DATASETS.keys()),
required=True,
)

parser.add_argument(
Expand Down

0 comments on commit 01a427a

Please sign in to comment.