Skip to content

Commit

Permalink
FEAT support duplicating memory when cloning orchestrators (Azure#177)
Browse files Browse the repository at this point in the history
  • Loading branch information
romanlutz authored and QDAP-Fred committed May 13, 2024
1 parent 71cd200 commit 2e5afc2
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 1 deletion.
36 changes: 36 additions & 0 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import abc
from pathlib import Path

from typing import Optional
from uuid import uuid4

from pyrit.memory.memory_models import EmbeddingData
from pyrit.models import PromptRequestResponse, Score, PromptRequestPiece, PromptResponseError, PromptDataType

Expand Down Expand Up @@ -126,6 +129,39 @@ def get_orchestrator_conversations(self, *, orchestrator_id: int) -> list[Prompt
prompt_pieces = self._get_prompt_pieces_by_orchestrator(orchestrator_id=orchestrator_id)
return sorted(prompt_pieces, key=lambda x: (x.conversation_id, x.timestamp))

def duplicate_conversation_for_new_orchestrator(
self,
*,
new_orchestrator_id: str,
conversation_id: str,
new_conversation_id: Optional[str] = None,
) -> None:
"""
Duplicates a conversation from one orchestrator to another.
This can be useful when an attack strategy requires branching out from a particular point in the conversation.
One cannot continue both branches with the same orchestrator and conversation IDs since that would corrupt
the memory. Instead, one needs to duplicate the conversation and continue with the new orchestrator ID.
Args:
new_orchestrator_id (str): The new orchestrator ID to assign to the duplicated conversations.
conversation_id (str): The conversation ID with existing conversations.
new_conversation_id (str): The new conversation ID to assign to the duplicated conversations.
If no new_conversation_id is provided, a new one will be generated.
"""
new_conversation_id = new_conversation_id or str(uuid4())
if conversation_id == new_conversation_id:
raise ValueError("The new conversation ID must be different from the existing conversation ID.")
prompt_pieces = self._get_prompt_pieces_with_conversation_id(conversation_id=conversation_id)
for piece in prompt_pieces:
piece.id = uuid4()
if piece.orchestrator_identifier["id"] == new_orchestrator_id:
raise ValueError("The new orchestrator ID must be different from the existing orchestrator ID.")
piece.orchestrator_identifier["id"] = new_orchestrator_id
piece.conversation_id = new_conversation_id

self._add_request_pieces_to_memory(request_pieces=prompt_pieces)

def export_conversation_by_orchestrator_id(
self, *, orchestrator_id: int, file_path: Path = None, export_type: str = "json"
):
Expand Down
2 changes: 1 addition & 1 deletion pyrit/orchestrator/orchestrator_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_memory(self):
"""
return self._memory.get_orchestrator_conversations(orchestrator_id=id(self))

def get_identifier(self):
def get_identifier(self) -> dict[str, str]:
orchestrator_dict = {}
orchestrator_dict["__type__"] = self.__class__.__name__
orchestrator_dict["__module__"] = self.__class__.__module__
Expand Down
125 changes: 125 additions & 0 deletions tests/memory/test_memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pyrit.memory.memory_exporter import MemoryExporter
from pyrit.memory.memory_models import PromptRequestPiece, PromptMemoryEntry
from pyrit.models import PromptRequestResponse
from pyrit.orchestrator import Orchestrator

from tests.mocks import get_memory_interface, get_sample_conversations, get_sample_conversation_entries

Expand Down Expand Up @@ -60,6 +61,130 @@ def test_add_request_pieces_to_memory(
assert len(memory.get_all_prompt_pieces()) == num_conversations


@pytest.mark.parametrize("new_conversation_id1", [None, "12345"])
@pytest.mark.parametrize("new_conversation_id2", [None, "23456"])
def test_duplicate_memory(memory: MemoryInterface, new_conversation_id1: str | None, new_conversation_id2: str | None):
orchestrator1 = Orchestrator()
orchestrator2 = Orchestrator()
conversation_id_1 = "11111"
conversation_id_2 = "22222"
conversation_id_3 = "33333"
pieces = [
PromptRequestPiece(
role="user",
original_value="original prompt text",
converted_value="Hello, how are you?",
conversation_id=conversation_id_1,
sequence=0,
orchestrator_identifier=orchestrator1.get_identifier(),
),
PromptRequestPiece(
role="assistant",
original_value="original prompt text",
converted_value="I'm fine, thank you!",
conversation_id=conversation_id_1,
sequence=0,
orchestrator_identifier=orchestrator1.get_identifier(),
),
PromptRequestPiece(
role="assistant",
original_value="original prompt text",
converted_value="I'm fine, thank you!",
conversation_id=conversation_id_3,
orchestrator_identifier=orchestrator2.get_identifier(),
),
PromptRequestPiece(
role="user",
original_value="original prompt text",
converted_value="Hello, how are you?",
conversation_id=conversation_id_2,
sequence=0,
orchestrator_identifier=orchestrator1.get_identifier(),
),
PromptRequestPiece(
role="assistant",
original_value="original prompt text",
converted_value="I'm fine, thank you!",
conversation_id=conversation_id_2,
sequence=0,
orchestrator_identifier=orchestrator1.get_identifier(),
),
]
memory._add_request_pieces_to_memory(request_pieces=pieces)
assert len(memory.get_all_prompt_pieces()) == 5
orchestrator3 = Orchestrator()
memory.duplicate_conversation_for_new_orchestrator(
new_orchestrator_id=orchestrator3.get_identifier()["id"],
conversation_id=conversation_id_1,
new_conversation_id=new_conversation_id1,
)
memory.duplicate_conversation_for_new_orchestrator(
new_orchestrator_id=orchestrator3.get_identifier()["id"],
conversation_id=conversation_id_2,
new_conversation_id=new_conversation_id2,
)
all_pieces = memory.get_all_prompt_pieces()
assert len(all_pieces) == 9
assert len([p for p in all_pieces if p.orchestrator_identifier["id"] == orchestrator1.get_identifier()["id"]]) == 4
assert len([p for p in all_pieces if p.orchestrator_identifier["id"] == orchestrator2.get_identifier()["id"]]) == 1
assert len([p for p in all_pieces if p.orchestrator_identifier["id"] == orchestrator3.get_identifier()["id"]]) == 4
assert len([p for p in all_pieces if p.conversation_id == conversation_id_1]) == 2
assert len([p for p in all_pieces if p.conversation_id == conversation_id_2]) == 2
assert len([p for p in all_pieces if p.conversation_id == conversation_id_3]) == 1
if new_conversation_id1:
assert len([p for p in all_pieces if p.conversation_id == new_conversation_id1]) == 2
if new_conversation_id2:
assert len([p for p in all_pieces if p.conversation_id == new_conversation_id2]) == 2


def test_duplicate_memory_conversation_id_collision(memory: MemoryInterface):
orchestrator1 = Orchestrator()
orchestrator2 = Orchestrator()
conversation_id_1 = "11111"
pieces = [
PromptRequestPiece(
role="user",
original_value="original prompt text",
converted_value="Hello, how are you?",
conversation_id=conversation_id_1,
sequence=0,
orchestrator_identifier=orchestrator1.get_identifier(),
),
]
memory._add_request_pieces_to_memory(request_pieces=pieces)
assert len(memory.get_all_prompt_pieces()) == 1
with pytest.raises(ValueError):
memory.duplicate_conversation_for_new_orchestrator(
new_orchestrator_id=str(orchestrator2.get_identifier()["id"]),
conversation_id=conversation_id_1,
new_conversation_id=conversation_id_1,
)


def test_duplicate_memory_orchestrator_id_collision(memory: MemoryInterface):
orchestrator1 = Orchestrator()
conversation_id_1 = "11111"
conversation_id_2 = "22222"
pieces = [
PromptRequestPiece(
role="user",
original_value="original prompt text",
converted_value="Hello, how are you?",
conversation_id=conversation_id_1,
sequence=0,
orchestrator_identifier=orchestrator1.get_identifier(),
),
]
memory._add_request_pieces_to_memory(request_pieces=pieces)
assert len(memory.get_all_prompt_pieces()) == 1
with pytest.raises(ValueError):
memory.duplicate_conversation_for_new_orchestrator(
new_orchestrator_id=str(orchestrator1.get_identifier()["id"]),
conversation_id=conversation_id_1,
new_conversation_id=conversation_id_2,
)


def test_add_request_pieces_to_memory_calls_validate(memory: MemoryInterface):
request_response = MagicMock(PromptRequestResponse)
request_response.request_pieces = [MagicMock(PromptRequestPiece)]
Expand Down

0 comments on commit 2e5afc2

Please sign in to comment.