Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/Azure/PyRIT into releases/v…
Browse files Browse the repository at this point in the history
…0.2.1
  • Loading branch information
romanlutz committed May 1, 2024
2 parents 7b94e9e + c62b24a commit 9e852f1
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyrit/memory/duckdb_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def get_all_prompt_pieces(self) -> list[PromptRequestPiece]:
Fetches all entries from the specified table and returns them as model instances.
"""
result: list[PromptMemoryEntry] = self.query_entries(PromptMemoryEntry)
return [entry.get_prompt_reqest_piece() for entry in result]
return [entry.get_prompt_request_piece() for entry in result]

def get_all_embeddings(self) -> list[EmbeddingData]:
"""
Expand Down
22 changes: 22 additions & 0 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,28 @@ def get_orchestrator_conversations(self, *, orchestrator_id: int) -> list[Prompt
prompt_pieces = self._get_prompt_pieces_by_orchestrator(orchestrator_id=orchestrator_id)
return sorted(prompt_pieces, key=lambda x: (x.conversation_id, x.timestamp))

def export_conversation_by_orchestrator_id(
self, *, orchestrator_id: int, file_path: Path = None, export_type: str = "json"
):
"""
Exports conversation data with the given orchestrator ID to a specified file.
This will contain all conversations that were sent by the same orchestrator.
Args:
orchestrator_id (str): The ID of the orchestrator from which to export conversations.
file_path (str): The path to the file where the data will be exported.
If not provided, a default path using RESULTS_PATH will be constructed.
export_type (str): The format of the export. Defaults to "json".
"""
data = self.get_orchestrator_conversations(orchestrator_id=orchestrator_id)

# If file_path is not provided, construct a default using the exporter's results_path
if not file_path:
file_name = f"{str(orchestrator_id)}.{export_type}"
file_path = RESULTS_PATH / file_name

self.exporter.export_data(data, file_path=file_path, export_type=export_type)

def add_request_response_to_memory(self, *, request: PromptRequestResponse) -> None:
"""
Inserts a list of prompt request pieces into the memory storage.
Expand Down
2 changes: 1 addition & 1 deletion pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self, *, entry: PromptRequestPiece):

self.response_error = entry.response_error

def get_prompt_reqest_piece(self) -> PromptRequestPiece:
def get_prompt_request_piece(self) -> PromptRequestPiece:
return PromptRequestPiece(
role=self.role,
original_value=self.original_value,
Expand Down
33 changes: 30 additions & 3 deletions tests/memory/test_memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@

from typing import Generator
from unittest.mock import MagicMock, patch
from pathlib import Path
import pytest
import random
from string import ascii_lowercase

from pyrit.common.path import RESULTS_PATH
from pyrit.memory import MemoryInterface
from pyrit.memory.memory_models import PromptRequestPiece
from pyrit.memory.memory_exporter import MemoryExporter
from pyrit.memory.memory_models import PromptRequestPiece, PromptMemoryEntry
from pyrit.models import PromptRequestResponse

from tests.mocks import get_memory_interface, get_sample_conversations
from tests.mocks import get_memory_interface, get_sample_conversations, get_sample_conversation_entries


@pytest.fixture
Expand All @@ -24,6 +27,11 @@ def sample_conversations() -> list[PromptRequestPiece]:
return get_sample_conversations()


@pytest.fixture
def sample_conversation_entries() -> list[PromptMemoryEntry]:
return get_sample_conversation_entries()


def generate_random_string(length: int = 10) -> str:
return "".join(random.choice(ascii_lowercase) for _ in range(length))

Expand Down Expand Up @@ -161,7 +169,7 @@ def test_insert_prompt_memories_not_inserts_embedding(


def test_get_orchestrator_conversation_sorting(memory: MemoryInterface, sample_conversations: list[PromptRequestPiece]):
conversation_id = sample_conversations[0].orchestrator_identifier["id"]
conversation_id = sample_conversations[0].conversation_id

# This new conversation piece should be grouped with other messages in the conversation
sample_conversations.append(
Expand All @@ -185,3 +193,22 @@ def test_get_orchestrator_conversation_sorting(memory: MemoryInterface, sample_c
if new_value != current_value:
if any(o.conversation_id == current_value for o in response[response.index(obj) :]):
assert False, "Conversation IDs are not grouped together"


def test_export_conversation_by_orchestrator_id_file_created(
memory: MemoryInterface, sample_conversation_entries: list[PromptMemoryEntry]
):
orchestrator1_id = sample_conversation_entries[0].get_prompt_request_piece().orchestrator_identifier["id"]

# Default path in export_conversation_by_orchestrator_id()
file_name = f"{str(orchestrator1_id)}.json"
file_path = Path(RESULTS_PATH, file_name)

memory.exporter = MemoryExporter()

with patch("pyrit.memory.duckdb_memory.DuckDBMemory._get_prompt_pieces_by_orchestrator") as mock_get:
mock_get.return_value = sample_conversation_entries
memory.export_conversation_by_orchestrator_id(orchestrator_id=int(orchestrator1_id))

# Verify file was created
assert file_path.exists()

0 comments on commit 9e852f1

Please sign in to comment.