Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: Score Prompts Orchestrator #208

Merged
merged 8 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions pyrit/memory/duckdb_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,27 @@ def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> li
logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}")
return []

def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[PromptRequestPiece]:
    """
    Retrieves a list of PromptRequestPiece objects that have the specified prompt ids.

    Args:
        prompt_ids (list[str]): The prompt IDs to match.

    Returns:
        list[PromptRequestPiece]: A list of PromptRequestPiece objects with the specified
            prompt IDs. Returns an empty list if the query fails.
    """
    try:
        # Match on the primary key of the memory entries; SQLAlchemy's in_() builds
        # the equivalent of "WHERE id IN (...)".
        return self.query_entries(
            PromptMemoryEntry,
            conditions=PromptMemoryEntry.id.in_(prompt_ids),
        )
    except Exception as e:
        # Log (with traceback) rather than raise, so a bad query degrades to an
        # empty result instead of crashing the caller.
        logger.exception(
            f"Unexpected error: Failed to retrieve prompt pieces with prompt ids {prompt_ids}. {e}"
        )
        return []

def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: int) -> list[PromptRequestPiece]:
"""
Retrieves a list of PromptRequestPiece objects that have the specified orchestrator ID.
Expand Down
26 changes: 19 additions & 7 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,15 @@ def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: int) -> list[Pr
"""

@abc.abstractmethod
def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None:
def add_request_pieces_to_memory(self, *, request_pieces: list[PromptRequestPiece]) -> None:
"""
Inserts embedding data into memory storage
Inserts a list of prompt request pieces into the memory storage.
"""

@abc.abstractmethod
def add_request_pieces_to_memory(self, *, request_pieces: list[PromptRequestPiece]) -> None:
def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None:
"""
Inserts a list of prompt request pieces into the memory storage.
Inserts embedding data into memory storage
"""

@abc.abstractmethod
Expand Down Expand Up @@ -115,9 +115,21 @@ def get_conversation(self, *, conversation_id: str) -> list[PromptRequestRespons
request_pieces = self._get_prompt_pieces_with_conversation_id(conversation_id=conversation_id)
return group_conversation_request_pieces_by_sequence(request_pieces=request_pieces)

def get_orchestrator_conversations(self, *, orchestrator_id: int) -> list[PromptRequestPiece]:
@abc.abstractmethod
def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[PromptRequestPiece]:
    """
    Retrieves a list of PromptRequestPiece objects that have the specified prompt ids.

    Args:
        prompt_ids (list[str]): The prompt IDs to match.

    Returns:
        list[PromptRequestPiece]: A list of PromptRequestPiece objects with the specified
            prompt IDs.
    """

def get_prompt_request_piece_by_orchestrator_id(self, *, orchestrator_id: int) -> list[PromptRequestPiece]:
"""
Retrieves a list of PromptRequestResponse objects that have the specified orchestrator ID.
Retrieves a list of PromptRequestPiece objects that have the specified orchestrator ID.

Args:
orchestrator_id (int): The orchestrator ID to match.
Expand Down Expand Up @@ -175,7 +187,7 @@ def export_conversation_by_orchestrator_id(
If not provided, a default path using RESULTS_PATH will be constructed.
export_type (str): The format of the export. Defaults to "json".
"""
data = self.get_orchestrator_conversations(orchestrator_id=orchestrator_id)
data = self._get_prompt_pieces_by_orchestrator(orchestrator_id=orchestrator_id)

# If file_path is not provided, construct a default using the exporter's results_path
if not file_path:
Expand Down
2 changes: 1 addition & 1 deletion pyrit/orchestrator/orchestrator_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_memory(self):
"""
Retrieves the memory associated with this orchestrator.
"""
return self._memory.get_orchestrator_conversations(orchestrator_id=id(self))
return self._memory.get_prompt_request_piece_by_orchestrator_id(orchestrator_id=id(self))

def get_identifier(self) -> dict[str, str]:
orchestrator_dict = {}
Expand Down
90 changes: 90 additions & 0 deletions pyrit/orchestrator/scoring_orchestrator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import logging

from pyrit.memory import MemoryInterface
from pyrit.models.prompt_request_piece import PromptRequestPiece
from pyrit.models.score import Score
from pyrit.orchestrator import Orchestrator
from pyrit.prompt_normalizer import PromptNormalizer
from pyrit.score.scorer import Scorer

logger = logging.getLogger(__name__)


class ScoringOrchestrator(Orchestrator):
    """
    This orchestrator scores prompts in a parallelizable and convenient way.
    """

    def __init__(
        self,
        memory: MemoryInterface = None,
        batch_size: int = 10,
        verbose: bool = False,
    ) -> None:
        """
        Args:
            memory (MemoryInterface, optional): The memory interface. Defaults to None.
            batch_size (int, optional): The (max) batch size for sending prompts. Defaults to 10.
            verbose (bool, optional): Whether to enable verbose logging. Defaults to False.
        """
        super().__init__(memory=memory, verbose=verbose)

        self._prompt_normalizer = PromptNormalizer(memory=self._memory)
        self._batch_size = batch_size

    async def score_prompts_by_orchestrator_id_async(
        self,
        *,
        scorer: Scorer,
        orchestrator_ids: list[int],
        responses_only: bool = True,
    ) -> list[Score]:
        """
        Scores prompts using the Scorer for prompts correlated to the orchestrator_ids.

        Args:
            scorer (Scorer): The scorer used to score each prompt piece.
            orchestrator_ids (list[int]): The orchestrator IDs whose prompts should be scored.
            responses_only (bool, optional): If True, only pieces with role "assistant" are
                scored. Defaults to True.

        Returns:
            list[Score]: The scores for all matching prompt pieces.
        """
        request_pieces: list[PromptRequestPiece] = []
        # Named orchestrator_id (not `id`) to avoid shadowing the builtin.
        for orchestrator_id in orchestrator_ids:
            request_pieces.extend(
                self._memory.get_prompt_request_piece_by_orchestrator_id(orchestrator_id=orchestrator_id)
            )

        if responses_only:
            request_pieces = self._extract_responses_only(request_pieces)

        return await self._score_prompts_batch_async(prompts=request_pieces, scorer=scorer)

    async def score_prompts_by_request_id_async(self, *, scorer: Scorer, prompt_ids: list[str]) -> list[Score]:
        """
        Scores prompts using the Scorer for prompts with the prompt_ids.

        Args:
            scorer (Scorer): The scorer used to score each prompt piece.
            prompt_ids (list[str]): The IDs of the prompts to score.

        Returns:
            list[Score]: The scores for the requested prompt pieces.
        """
        requests: list[PromptRequestPiece] = self._memory.get_prompt_request_pieces_by_id(prompt_ids=prompt_ids)

        return await self._score_prompts_batch_async(prompts=requests, scorer=scorer)

    def _extract_responses_only(self, request_responses: list[PromptRequestPiece]) -> list[PromptRequestPiece]:
        """
        Extracts the assistant responses from the list of PromptRequestPiece objects.
        """
        return [response for response in request_responses if response.role == "assistant"]

    async def _score_prompts_batch_async(self, prompts: list[PromptRequestPiece], scorer: Scorer) -> list[Score]:
        """
        Scores prompts in batches of at most self._batch_size, gathering each batch concurrently.
        """
        results = []

        for prompts_batch in self._chunked_prompts(prompts, self._batch_size):
            tasks = [scorer.score_async(request_response=prompt) for prompt in prompts_batch]
            batch_results = await asyncio.gather(*tasks)
            results.extend(batch_results)

        # results is a list[list[Score]] (one list per prompt) and needs to be flattened
        return [score for sublist in results for score in sublist]

    def _chunked_prompts(self, prompts, size):
        """Yields successive chunks of at most `size` prompts."""
        for i in range(0, len(prompts), size):
            yield prompts[i : i + size]
Loading
Loading