Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: add support for question answering benchmark #94

Merged
merged 16 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 166 additions & 35 deletions pyrit/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import dataclass, field
from hashlib import sha256
from pathlib import Path
from typing import Literal, Optional, Type, TypeVar
from typing import Literal, Optional, Type, TypeVar, Union

import yaml
from pydantic import BaseModel, ConfigDict
Expand Down Expand Up @@ -95,64 +95,195 @@ class Prompt:
content: str


@dataclass
class ScoreAnswers:
answers: list[str]
name: str = ""
version: str = ""
description: str = ""
author: str = ""
group: str = ""
source: str = ""
# @dataclass
dlmgary marked this conversation as resolved.
Show resolved Hide resolved
# class ScoreAnswers:
# answers: list[str]
# name: str = ""
# version: str = ""
# description: str = ""
# author: str = ""
# group: str = ""
# source: str = ""

@staticmethod
def from_yaml(file: Path) -> ScoreAnswers:
yaml_data = yaml.safe_load(file.read_text("utf-8"))
return ScoreAnswers(**yaml_data)
# @staticmethod
# def from_yaml(file: Path) -> ScoreAnswers:
# yaml_data = yaml.safe_load(file.read_text("utf-8"))
# return ScoreAnswers(**yaml_data)


@dataclass
class ExamAnswer:
    """One answer record made of three required string fields."""

    # Positional order is: answer, explanation, confidence.
    answer: str
    explanation: str
    # Annotated as a string rather than a numeric type.
    confidence: str
# @dataclass
# class ExamAnswer:
# answer: str
# explanation: str
# confidence: str


@dataclass
class ExamAnswers:
    """Container holding a list of ``ExamAnswer`` records."""

    # default_factory gives every instance its own fresh list instead of a
    # single list shared between instances.
    answer: list[ExamAnswer] = field(default_factory=list)
# @dataclass
# class ExamAnswers:
# answer: list[ExamAnswer] = field(default_factory=list)


@dataclass
class ScoringResults:
# @dataclass
# class ScoringResults:
# failed: int
# passed: int
# # unknown: int
# questions_count: int
# passed_with_partial_credit: float


class ScoringResult(BaseModel):
    """
    Outcome of scoring a single answer.

    Attributes:
        provided_answer (str): The answer that was given.
        correct_answer (str): The answer it is compared against.
        is_correct (bool): Whether the provided answer was judged correct.
    """

    # extra="forbid": fields not declared below are rejected at validation time.
    model_config = ConfigDict(extra="forbid")
    provided_answer: str
    correct_answer: str
    is_correct: bool


class ScoringResults(BaseModel):
    """
    Summary counters produced by one scoring run.

    Attributes:
        failed (int): The number of failed cases.
        passed (int): The number of passed cases.
        questions_count (int): The number of questions.
        passed_with_partial_credit (float): Partial-credit figure, stored as a float.
    """

    # extra="forbid": fields not declared below are rejected at validation time.
    model_config = ConfigDict(extra="forbid")
    failed: int
    passed: int
    # unknown: int
    questions_count: int
    passed_with_partial_credit: float


@dataclass
class CompletionConfig:
    """Pair of required settings: ``temperature`` followed by ``max_tokens``."""

    # NOTE(review): annotated as int; confirm whether fractional values are expected.
    temperature: int
    max_tokens: int
class AggregateScoringResults(BaseModel):
    """
    Running totals accumulated over one or more ScoringResults.

    Attributes:
        total_failed (int): Sum of `failed` over every added result. Starts at 0.
        total_passed (int): Sum of `passed` over every added result. Starts at 0.
        total_questions_count (int): Sum of `failed + passed` over every added result. Starts at 0.

    Methods:
        add_results(results: ScoringResults): Folds one ScoringResults into the totals, in place.
        from_results_list(results_list: list[ScoringResults]) -> AggregateScoringResults: Builds a new
            aggregate from a list of ScoringResults.
    """

    # extra="forbid": fields not declared below are rejected at validation time.
    model_config = ConfigDict(extra="forbid")
    total_failed: int = 0
    total_passed: int = 0
    total_questions_count: int = 0

    def add_results(self, results: ScoringResults):
        """
        Folds a single ScoringResults into the running totals, mutating this instance.

        Args:
            results (ScoringResults): The scoring results to accumulate.
        """
        failed_count = results.failed
        passed_count = results.passed

        self.total_failed += failed_count
        self.total_passed += passed_count
        # The question total is derived from failed + passed; `results.questions_count` is not read here.
        self.total_questions_count += failed_count + passed_count

    @classmethod
    def from_results_list(cls, results_list: list[ScoringResults]) -> AggregateScoringResults:
        """
        Builds an aggregate by adding every entry of `results_list` to a fresh, zeroed instance.

        Args:
            results_list (list[ScoringResults]): The scoring results to aggregate.

        Returns:
            AggregateScoringResults: A new instance holding the accumulated totals (all zeros for an empty list).
        """
        aggregate = cls()
        for single_result in results_list:
            aggregate.add_results(single_result)
        return aggregate


class QuestionChoice(BaseModel):
    """
    One selectable option belonging to a question.

    Attributes:
        index (int): The index of the choice.
        text (str): The text of the choice.
    """

    # extra="forbid": fields not declared below are rejected at validation time.
    model_config = ConfigDict(extra="forbid")
    index: int
    text: str


class QuestionAnsweringEntry(BaseModel):
    """
    A single question together with its expected answer and its choices.

    Attributes:
        question (str): The question text.
        answer_type (Literal["int", "float", "str", "bool"]): Declares how the answer is typed.
            - `int` for integer answers (e.g., when the answer is an index of the correct option in a
              multiple-choice question).
            - `float` for answers that are floating-point numbers.
            - `str` for text-based answers.
            - `bool` for boolean answers.
        correct_answer (Union[int, str, float]): The correct answer.
        choices (list[QuestionChoice]): The list of choices for the question.
    """

    # extra="forbid": fields not declared below are rejected at validation time.
    model_config = ConfigDict(extra="forbid")
    question: str
    answer_type: Literal["int", "float", "str", "bool"]
    # NOTE(review): "bool" is an allowed answer_type but bool is not listed in this Union;
    # confirm how boolean answers are meant to be represented.
    correct_answer: Union[int, str, float]
    choices: list[QuestionChoice]


class QuestionAnsweringDataset(BaseModel):
    """
    A collection of question-answering entries plus descriptive metadata.

    All metadata fields are optional and default to an empty string; `questions` is required.

    Attributes:
        name (str): The name of the dataset.
        version (str): The version of the dataset.
        description (str): A description of the dataset.
        author (str): The author of the dataset.
        group (str): The group associated with the dataset.
        source (str): The source of the dataset.
        questions (list[QuestionAnsweringEntry]): The entries that make up the dataset.
    """

    # extra="forbid": fields not declared below are rejected at validation time.
    model_config = ConfigDict(extra="forbid")
    name: str = ""
    version: str = ""
    description: str = ""
    author: str = ""
    group: str = ""
    source: str = ""
    questions: list[QuestionAnsweringEntry]


T = TypeVar("T", bound="YamlLoadable")


class YamlLoadable(abc.ABC):
"""
Abstract base class for objects that can be loaded from YAML files.
"""

@classmethod
def from_yaml_file(cls: Type[T], file: Path) -> T:
"""
Creates a new object from a file
Creates a new object from a YAML file.

Args:
file: The input file
file: The input file path.

Returns:
A new T object
A new object of type T.

Raises:
FileNotFoundError: if the input YAML file path does not exist
FileNotFoundError: If the input YAML file path does not exist.
ValueError: If the YAML file is invalid.
"""
if not file.exists():
raise FileNotFoundError(f"File '{file}' does not exist.")
Expand Down
Loading
Loading