feat(qa_evaluator): Improve QA evaluator
ktagowski committed Aug 24, 2023
1 parent cb22afe commit 370c990
Showing 5 changed files with 69 additions and 3 deletions.
2 changes: 2 additions & 0 deletions embeddings/evaluator/evaluation_results.py
@@ -110,3 +110,5 @@ class QuestionAnsweringEvaluationResults(EvaluationResults):
NoAns_f1: Optional[float] = None
NoAns_total: Optional[float] = None
data: Optional[Data] = None
golds_text: Optional[Union[List[List[str]], List[str]]] = None
predictions_text: Optional[List[str]] = None
21 changes: 19 additions & 2 deletions embeddings/evaluator/question_answering_evaluator.py
@@ -26,9 +26,22 @@ def __init__(self, no_answer_threshold: float = 1.0):

def metrics(
self,
) -> Dict[str, Metric[Union[List[Any], nptyping.NDArray[Any], torch.Tensor], Dict[Any, Any]]]:
) -> Dict[str, Metric[Union[List[Any], nptyping.NDArray[Any], torch.Tensor], Dict[Any, Any]],]:
return {}

@staticmethod
def get_golds_text(references: List[QA_GOLD_ANSWER_TYPE]) -> Union[List[List[str]], List[str]]:
golds_text = []
for ref in references:
answers = ref["answers"]
assert isinstance(answers, dict)
golds_text.append(answers["text"])
return golds_text

@staticmethod
def get_predictions_text(predictions: List[QA_PREDICTED_ANSWER_TYPE]) -> List[str]:
return [str(it["prediction_text"]) for it in predictions]

def evaluate(
self, data: Union[Dict[str, nptyping.NDArray[Any]], Predictions, Dict[str, Any]]
) -> QuestionAnsweringEvaluationResults:
@@ -51,5 +64,9 @@ def evaluate(
{"id": it_id, **it["predicted_answer"]} for it_id, it in enumerate(outputs)
]
metrics = SQUADv2Metric().calculate(predictions=predictions, references=references)
gold_texts = QuestionAnsweringEvaluator.get_golds_text(references)
predictions_text = QuestionAnsweringEvaluator.get_predictions_text(predictions)

return QuestionAnsweringEvaluationResults(data=outputs, **metrics)
return QuestionAnsweringEvaluationResults(
data=outputs, golds_text=gold_texts, predictions_text=predictions_text, **metrics
)
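
For reference, a minimal sketch (with hypothetical SQuAD-v2-style records, not taken from this repository) of what the new helper methods extract:

from embeddings.evaluator.question_answering_evaluator import QuestionAnsweringEvaluator

# Hypothetical records; field values are illustrative only.
references = [
    {"id": 0, "answers": {"text": ["Warsaw"], "answer_start": [11]}},
    {"id": 1, "answers": {"text": [], "answer_start": []}},  # no-answer case
]
predictions = [
    {"id": 0, "prediction_text": "Warsaw", "no_answer_probability": 0.01},
    {"id": 1, "prediction_text": "", "no_answer_probability": 0.97},
]

QuestionAnsweringEvaluator.get_golds_text(references)         # [["Warsaw"], []]
QuestionAnsweringEvaluator.get_predictions_text(predictions)  # ["Warsaw", ""]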
17 changes: 17 additions & 0 deletions embeddings/pipeline/lightning_question_answering.py
@@ -2,6 +2,8 @@
from typing import Any, Dict, List, Optional, Union

import datasets
import pandas as pd
import yaml
from pytorch_lightning.accelerators import Accelerator

from embeddings.config.lightning_config import LightningQABasicConfig, LightningQAConfig
@@ -14,6 +16,7 @@
from embeddings.pipeline.lightning_pipeline import LightningPipeline
from embeddings.task.lightning_task import question_answering as qa
from embeddings.utils.loggers import LightningLoggingConfig
from embeddings.utils.utils import convert_qa_df_to_bootstrap_html


class LightningQuestionAnsweringPipeline(
@@ -86,3 +89,17 @@ def __init__(
logging_config,
pipeline_kwargs=pipeline_kwargs,
)

def _save_metrics(self) -> None:
metrics = getattr(self.result, "metrics")
with open(self.output_path / "metrics.yaml", "w") as f:
yaml.dump(metrics, stream=f)

predictions_text = getattr(self.result, "predictions_text")
golds_text = getattr(self.result, "golds_text")
with open(self.output_path / "predictions.html", "w") as f:
f.write(
convert_qa_df_to_bootstrap_html(
pd.DataFrame({"predictions": predictions_text, "golds": golds_text})
)
)
@@ -127,7 +127,7 @@ def _get_predicted_text_from_context(
def _get_softmax_scores_with_sort(predictions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
scores = torch.from_numpy(np.array([pred.pop("score") for pred in predictions]))
# Module torch.functional does not explicitly export attribute "F"
softmax_scores = torch.functional.F.softmax(scores) # type: ignore[attr-defined]
softmax_scores = torch.functional.F.softmax(scores, dim=0) # type: ignore[attr-defined]
for prob, pred in zip(softmax_scores, predictions):
pred["softmax_score"] = prob
# mypy thinks the function only returns Any
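
For context, the added dim=0 makes the softmax axis explicit: recent PyTorch versions warn when softmax is called without dim, and here the scores should be normalized across the list of candidate predictions. A minimal sketch with illustrative values:

import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 1.0, 0.5])  # hypothetical raw answer-span scores
probs = F.softmax(scores, dim=0)        # normalize over the candidate axis
print(probs.sum())                      # tensor(1.)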
30 changes: 30 additions & 0 deletions embeddings/utils/utils.py
@@ -9,6 +9,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import pkg_resources
import requests
import yaml
@@ -152,3 +153,32 @@ def compress_and_remove(filepath: T_path) -> None:
) as arc:
arc.write(filepath, arcname=filepath.name)
filepath.unlink()


def convert_qa_df_to_bootstrap_html(df: pd.DataFrame) -> str:
bootstrap_cdn = (
'<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap.min.css" '
'integrity="sha384-BVYiiSIFeK1dGmJRAkycuHAHRg32OmUcww7on3RYdg4Va+PmSTsz/K68vbdEjh4u" crossorigin="anonymous">'
)

output = (
"<!DOCTYPE html>"
+ "\n"
+ "<html>"
+ "\n"
+ "<head>"
+ "\n"
+ bootstrap_cdn
+ "\n"
+ '<meta charset="utf-8">'
+ "\n"
+ "</head>"
+ "\n"
+ "<body>"
+ "\n"
+ df.to_html(classes=["table table-bordered table-striped table-hover"])
+ "\n"
+ "</body>"
)
assert isinstance(output, str)
return output
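
A small usage sketch of the new helper (hypothetical data), mirroring how the pipeline's _save_metrics writes the predictions report:

import pandas as pd
from embeddings.utils.utils import convert_qa_df_to_bootstrap_html

# Illustrative predictions/golds; real values come from the evaluator results.
df = pd.DataFrame({"predictions": ["Warsaw"], "golds": [["Warsaw"]]})
html = convert_qa_df_to_bootstrap_html(df)
with open("predictions.html", "w", encoding="utf-8") as f:
    f.write(html)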
