Skip to content

Commit

Permalink
eval_downstream.py
Browse files Browse the repository at this point in the history
  • Loading branch information
natuan committed Jun 27, 2022
1 parent c581741 commit fb673ec
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions src/deepsparse/transformers/eval_downstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,10 @@ def mnli_eval(args):
)
print(f"Engine info: {text_classify.engine}")

label_map = _get_label2id(text_classify.config_path)
try:
label_map = _get_label2id(text_classify.config_path)
except KeyError:
label_map = {"entailment": 0, "neutral": 1, "contradiction": 2}

for idx, sample in _enumerate_progress(mnli_matched, args.max_samples):
pred = text_classify([[sample["premise"], sample["hypothesis"]]])
Expand Down Expand Up @@ -162,7 +165,10 @@ def qqp_eval(args):
)
print(f"Engine info: {text_classify.engine}")

label_map = _get_label2id(text_classify.config_path)
try:
label_map = _get_label2id(text_classify.config_path)
except KeyError:
label_map = {"not_duplicate": 0, "duplicate": 1, "LABEL_0": 0, "LABEL_1": 1}

for idx, sample in _enumerate_progress(qqp, args.max_samples):
pred = text_classify([[sample["question1"], sample["question2"]]])
Expand Down Expand Up @@ -193,7 +199,10 @@ def sst2_eval(args):
)
print(f"Engine info: {text_classify.engine}")

label_map = _get_label2id(text_classify.config_path)
try:
label_map = _get_label2id(text_classify.config_path)
except KeyError:
label_map = {"negative": 0, "positive": 1, "LABEL_0": 0, "LABEL_1": 1}

for idx, sample in _enumerate_progress(sst2, args.max_samples):
pred = text_classify(
Expand Down

0 comments on commit fb673ec

Please sign in to comment.