Feat: PromptMemoryEntry Table Added for more Extensible Target Logic #125

Merged · 17 commits · Apr 1, 2024
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -81,6 +81,9 @@ dev = [
"types-PyYAML>=6.0.12.9",
]

[tool.pytest.ini_options]
pythonpath = ["."]

[tool.mypy]
plugins = []
ignore_missing_imports = true
18 changes: 9 additions & 9 deletions pyrit/analytics/conversation_analytics.py
@@ -6,7 +6,6 @@
from sklearn.metrics.pairwise import cosine_similarity
from pyrit.memory.memory_interface import MemoryInterface
from pyrit.memory.memory_models import ConversationMessageWithSimilarity, EmbeddingMessageWithSimilarity
from pyrit.memory.memory_models import ConversationData, EmbeddingData


class ConversationAnalytics:
@@ -24,11 +23,11 @@ def __init__(self, *, memory_interface: MemoryInterface):
"""
self.memory_interface = memory_interface

def get_similar_chat_messages_by_content(
def get_prompt_entries_with_same_converted_content(
self, *, chat_message_content: str
) -> list[ConversationMessageWithSimilarity]:
"""
Retrieves chat messages that are similar to the given content based on exact matches.
Retrieves chat messages whose converted prompt text exactly matches the given content.

Args:
chat_message_content (str): The content of the chat message to find similar messages for.
@@ -37,16 +36,16 @@ def get_similar_chat_messages_by_content(
list[ConversationMessageWithSimilarity]: A list of ConversationMessageWithSimilarity objects representing
the similar chat messages based on content.
"""
all_memories = self.memory_interface.get_all_memory(ConversationData)
all_memories = self.memory_interface.get_all_prompt_entries()
similar_messages = []

for memory in all_memories:
if memory.content == chat_message_content:
if memory.converted_prompt_text == chat_message_content:
similar_messages.append(
ConversationMessageWithSimilarity(
score=1.0,
role=memory.role,
content=memory.content,
content=memory.converted_prompt_text,
metric="exact_match", # Exact match
)
)
@@ -67,12 +66,13 @@ def get_similar_chat_messages_by_embedding(
List[ConversationMessageWithSimilarity]: A list of ConversationMessageWithSimilarity objects representing
the similar chat messages based on embedding similarity.
"""
all_memories = self.memory_interface.get_all_memory(EmbeddingData)

all_embedding_memory = self.memory_interface.get_all_embeddings()
similar_messages = []

target_embedding = np.array(chat_message_embedding).reshape(1, -1)

for memory in all_memories:
for memory in all_embedding_memory:
if not hasattr(memory, "embedding") or memory.embedding is None:
continue

@@ -82,7 +82,7 @@ if similarity_score >= threshold:
if similarity_score >= threshold:
similar_messages.append(
EmbeddingMessageWithSimilarity(
score=similarity_score, uuid=memory.uuid, metric="cosine_similarity" # type: ignore
score=similarity_score, uuid=memory.id, metric="cosine_similarity" # type: ignore
)
)

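A rough usage sketch of the renamed analytics methods (illustrative only: it assumes DuckDBMemory as the MemoryInterface implementation and that an embedding vector is produced elsewhere; the default value of threshold is not shown in this diff):

```python
# Illustrative sketch; assumes DuckDBMemory backs the MemoryInterface.
from pyrit.analytics.conversation_analytics import ConversationAnalytics
from pyrit.memory import DuckDBMemory

memory = DuckDBMemory(db_path=":memory:")
analytics = ConversationAnalytics(memory_interface=memory)

# Exact matches against the converted prompt text of stored entries.
exact_matches = analytics.get_prompt_entries_with_same_converted_content(
    chat_message_content="How do I reset my password?"
)

# Cosine-similarity matches against stored embeddings; the vector and threshold
# below are placeholder values.
similar = analytics.get_similar_chat_messages_by_embedding(
    chat_message_embedding=[0.1] * 1536,
    threshold=0.8,
)
```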
4 changes: 2 additions & 2 deletions pyrit/common/print.py
@@ -4,13 +4,13 @@
import textwrap

import termcolor
from pyrit.memory.memory_models import ConversationData
from pyrit.memory.memory_models import PromptMemoryEntry
from pyrit.models import ChatMessage
from termcolor._types import Color


def print_chat_messages_with_color(
messages: list[ChatMessage | ConversationData],
messages: list[ChatMessage | PromptMemoryEntry],
max_content_character_width: int = 80,
left_padding_width: int = 20,
custom_colors: dict[str, Color] = None,
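A minimal sketch of the widened signature (assuming ChatMessage takes role/content keyword arguments, as elsewhere in the library; PromptMemoryEntry objects pulled from memory can now be mixed into the same list):

```python
from pyrit.common.print import print_chat_messages_with_color
from pyrit.models import ChatMessage

messages = [
    ChatMessage(role="user", content="Hello"),
    ChatMessage(role="assistant", content="Hi! How can I help?"),
]

# Entries retrieved from memory (PromptMemoryEntry) may also be passed in this list.
print_chat_messages_with_color(messages=messages, max_content_character_width=60)
```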
4 changes: 2 additions & 2 deletions pyrit/memory/__init__.py
@@ -1,11 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pyrit.memory.memory_models import ConversationData
from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData
from pyrit.memory.duckdb_memory import DuckDBMemory
from pyrit.memory.memory_interface import MemoryInterface
from pyrit.memory.memory_embedding import MemoryEmbedding
from pyrit.memory.memory_exporter import MemoryExporter


__all__ = ["ConversationData", "MemoryInterface", "MemoryEmbedding", "DuckDBMemory", "MemoryExporter"]
__all__ = ["PromptMemoryEntry", "EmbeddingData", "MemoryInterface", "MemoryEmbedding", "DuckDBMemory", "MemoryExporter"]
183 changes: 118 additions & 65 deletions pyrit/memory/duckdb_memory.py
@@ -12,10 +12,8 @@
from sqlalchemy.engine.base import Engine
from contextlib import closing

from pyrit.memory.memory_models import ConversationData, Base
from pyrit.memory.memory_embedding import default_memory_embedding_factory
from pyrit.memory.memory_models import EmbeddingData, PromptMemoryEntry, Base
from pyrit.memory.memory_interface import MemoryInterface
from pyrit.interfaces import EmbeddingSupport
from pyrit.common.path import RESULTS_PATH
from pyrit.common.singleton import Singleton

@@ -33,15 +31,19 @@ class DuckDBMemory(MemoryInterface, metaclass=Singleton):
DEFAULT_DB_FILE_NAME = "pyrit_duckdb_storage.db"

def __init__(
self, *, db_path: Union[Path, str] = None, embedding_model: EmbeddingSupport = None, has_echo: bool = False
self,
*,
db_path: Union[Path, str] = None,
verbose: bool = False,
):
super(DuckDBMemory, self).__init__()
self.memory_embedding = default_memory_embedding_factory(embedding_model=embedding_model)

if db_path == ":memory:":
self.db_path: Union[Path, str] = ":memory:"
else:
self.db_path = Path(db_path or Path(RESULTS_PATH, self.DEFAULT_DB_FILE_NAME)).resolve()
self.engine = self._create_engine(has_echo=has_echo)

self.engine = self._create_engine(has_echo=verbose)
self.SessionFactory = sessionmaker(bind=self.engine)
self._create_tables_if_not_exist()

@@ -76,6 +78,104 @@ def _create_tables_if_not_exist(self):
except Exception as e:
logger.error(f"Error during table creation: {e}")

def get_all_prompt_entries(self) -> list[PromptMemoryEntry]:
"""
Fetches all entries from the PromptMemoryEntry table and returns them as model instances.
"""
result = self.query_entries(PromptMemoryEntry)
return result

def get_all_embeddings(self) -> list[EmbeddingData]:
"""
Fetches all entries from the EmbeddingData table and returns them as model instances.
"""
result = self.query_entries(EmbeddingData)
return result

def get_prompt_entries_with_conversation_id(self, *, conversation_id: str) -> list[PromptMemoryEntry]:
"""
Retrieves a list of PromptMemoryEntry objects that have the specified conversation ID.

Args:
conversation_id (str): The conversation ID to filter the table.

Returns:
list[PromptMemoryEntry]: A list of PromptMemoryEntry objects matching the specified conversation ID.
"""
try:
return self.query_entries(
PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id
)
except Exception as e:
logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}")
return []

def get_prompt_entries_with_normalizer_id(self, *, normalizer_id: str) -> list[PromptMemoryEntry]:
"""
Retrieves a list of PromptMemoryEntry objects that have the specified normalizer ID.

Args:
normalizer_id (str): The normalizer ID to filter the table.

Returns:
list[PromptMemoryEntry]: A list of PromptMemoryEntry objects matching the specified normalizer ID.
"""
try:
return self.query_entries(
PromptMemoryEntry, conditions=PromptMemoryEntry.labels.op("->>")("normalizer_id") == normalizer_id
)
except Exception as e:
logger.exception(
f"Unexpected error: Failed to retrieve ConversationData with normalizer_id {normalizer_id}. {e}"
)
return []

def insert_prompt_entries(self, *, entries: list[PromptMemoryEntry]) -> None:
"""
Inserts a list of prompt entries into the memory storage.
If an embedding model is configured, also generates embedding data for applicable entries.

Args:
entries (list[PromptMemoryEntry]): The list of prompt entries to be inserted.
"""
embedding_entries = []

if self.memory_embedding:
for chat_entry in entries:
embedding_entry = self.memory_embedding.generate_embedding_memory_data(chat_memory=chat_entry)
embedding_entries.append(embedding_entry)

# Ordering note: embeddings are generated first because the references to the entries
# are lost once they are inserted, yet the entries themselves must be inserted before
# the embeddings to satisfy the foreign key constraint.
self.insert_entries(entries=entries)

if embedding_entries:
self.insert_entries(entries=embedding_entries)

def update_entries_by_conversation_id(self, *, conversation_id: str, update_fields: dict) -> bool:
"""
Updates entries for a given conversation ID with the specified field values.

Args:
conversation_id (str): The conversation ID of the entries to be updated.
update_fields (dict): A dictionary of field names and their new values.

Returns:
bool: True if the update was successful, False otherwise.
"""
# Fetch the relevant entries using query_entries
entries_to_update = self.query_entries(
PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id
)

# Check if there are entries to update
if not entries_to_update:
logger.info(f"No entries found with conversation_id {conversation_id} to update.")
return False

# Use the utility function to update the entries
return self.update_entries(entries=entries_to_update, update_fields=update_fields)

def get_all_table_models(self) -> list[Base]: # type: ignore
"""
Returns a list of all table models used in the database by inspecting the Base registry.
@@ -159,70 +259,23 @@ def update_entries(self, *, entries: list[Base], update_fields: dict) -> bool:
logger.exception(f"Error updating entries: {e}")
return False

def get_all_memory(self, model: Base) -> list[Base]:  # type: ignore
"""
Fetches all entries from the specified table and returns them as model instances.
"""
result = self.query_entries(model)
return result

def get_memories_with_conversation_id(self, *, conversation_id: str) -> list[ConversationData]:
"""
Retrieves a list of ConversationData objects that have the specified conversation ID.

Args:
conversation_id (str): The conversation ID to filter the table.

Returns:
list[ConversationData]: A list of ConversationData objects matching the specified conversation ID.
"""
try:
return self.query_entries(ConversationData, conditions=ConversationData.conversation_id == conversation_id)
except Exception as e:
logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}")
return []

def get_memories_with_normalizer_id(self, *, normalizer_id: str) -> list[ConversationData]:
"""
Retrieves a list of ConversationData objects that have the specified normalizer ID.

Args:
normalizer_id (str): The normalizer ID to filter the table.

Returns:
list[ConversationData]: A list of ConversationData objects matching the specified normalizer ID.
"""
try:
return self.query_entries(ConversationData, conditions=ConversationData.normalizer_id == normalizer_id)
except Exception as e:
logger.exception(
f"Unexpected error: Failed to retrieve ConversationData with normalizer_id {normalizer_id}. {e}"
)
return []

def update_entries_by_conversation_id(self, *, conversation_id: str, update_fields: dict) -> bool:
"""
Updates entries for a given conversation ID with the specified field values.

Args:
conversation_id (str): The conversation ID of the entries to be updated.
update_fields (dict): A dictionary of field names and their new values.

Returns:
bool: True if the update was successful, False otherwise.
"""
# Fetch the relevant entries using query_entries
entries_to_update = self.query_entries(
ConversationData, conditions=ConversationData.conversation_id == conversation_id
)

# Check if there are entries to update
if not entries_to_update:
logger.info(f"No entries found with conversation_id {conversation_id} to update.")
return False

# Use the utility function to update the entries
return self.update_entries(entries=entries_to_update, update_fields=update_fields)

def export_all_tables(self, *, export_type: str = "json"):
"""
Exports all table data using the specified exporter.

Iterates over all tables, retrieves their data, and exports each to a file named after the table.

Args:
export_type (str): The format to export the data in (defaults to "json").
"""
table_models = self.get_all_table_models()

for model in table_models:
data = self.query_entries(model)
table_name = model.__tablename__
file_extension = f".{export_type}"
file_path = RESULTS_PATH / f"{table_name}{file_extension}"
self.exporter.export_data(data, file_path=file_path, export_type=export_type)

def dispose_engine(self):
"""
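A hedged end-to-end sketch of the new DuckDBMemory API. PromptMemoryEntry's full column set is not shown in this diff, so the constructor arguments below (role, converted_prompt_text, converted_prompt_data_type, labels) are assumptions inferred from the fields referenced above:

```python
import uuid

from pyrit.memory import DuckDBMemory, PromptMemoryEntry

memory = DuckDBMemory(db_path=":memory:", verbose=False)

conversation_id = str(uuid.uuid4())
entry = PromptMemoryEntry(  # field names are assumptions based on this diff
    conversation_id=conversation_id,
    role="user",
    converted_prompt_text="Tell me a joke about databases.",
    converted_prompt_data_type="text",
    labels={"normalizer_id": "demo-normalizer"},
)

# Inserts the entries and, when an embedding model is configured, their embeddings.
memory.insert_prompt_entries(entries=[entry])

# Query back by conversation ID and by normalizer ID (stored in the labels column).
same_conversation = memory.get_prompt_entries_with_conversation_id(conversation_id=conversation_id)
same_normalizer = memory.get_prompt_entries_with_normalizer_id(normalizer_id="demo-normalizer")

# Bulk-update fields for a conversation, then export every table to JSON files.
memory.update_entries_by_conversation_id(conversation_id=conversation_id, update_fields={"role": "assistant"})
memory.export_all_tables(export_type="json")
```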
25 changes: 16 additions & 9 deletions pyrit/memory/memory_embedding.py
@@ -4,7 +4,7 @@
import os
from pyrit.embedding.azure_text_embedding import AzureTextEmbedding
from pyrit.interfaces import EmbeddingSupport
from pyrit.memory.memory_models import ConversationData, EmbeddingData
from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData


class MemoryEmbedding:
@@ -20,7 +20,7 @@ def __init__(self, *, embedding_model: EmbeddingSupport):
raise ValueError("embedding_model must be set.")
self.embedding_model = embedding_model

def generate_embedding_memory_data(self, *, chat_memory: ConversationData) -> EmbeddingData:
def generate_embedding_memory_data(self, *, chat_memory: PromptMemoryEntry) -> EmbeddingData:
"""
Generates metadata for a chat memory entry.

@@ -30,12 +30,17 @@ def generate_embedding_memory_data(self, *, chat_memory: ConversationData) -> EmbeddingData:
Returns:
EmbeddingData: The generated embedding data.
"""
embedding_data = EmbeddingData(
embedding=self.embedding_model.generate_text_embedding(text=chat_memory.content).data[0].embedding,
embedding_type_name=self.embedding_model.__class__.__name__,
uuid=chat_memory.uuid,
)
return embedding_data
if chat_memory.converted_prompt_data_type == "text":
embedding_data = EmbeddingData(
embedding=self.embedding_model.generate_text_embedding(text=chat_memory.converted_prompt_text)
.data[0]
.embedding,
embedding_type_name=self.embedding_model.__class__.__name__,
id=chat_memory.id,
)
return embedding_data

raise ValueError("Only text data is supported for embedding.")


def default_memory_embedding_factory(embedding_model: EmbeddingSupport = None) -> MemoryEmbedding | None:
@@ -49,4 +54,6 @@ def default_memory_embedding_factory(embedding_model: EmbeddingSupport = None) -> MemoryEmbedding | None:
model = AzureTextEmbedding(api_key=api_key, endpoint=api_base, deployment=deployment)
return MemoryEmbedding(embedding_model=model)
else:
return None
raise ValueError(
"No embedding model was provided and no Azure OpenAI embedding model was found in the environment."
)
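A sketch of the tightened embedding behavior (assuming the Azure OpenAI environment variables read by default_memory_embedding_factory are available; their names are not shown in this diff). The factory now raises instead of silently returning None, and only text entries can be embedded:

```python
from pyrit.memory.memory_embedding import default_memory_embedding_factory

try:
    memory_embedding = default_memory_embedding_factory()
except ValueError:
    memory_embedding = None  # no embedding model configured; callers must handle this

if memory_embedding:
    # `entry` is a PromptMemoryEntry like the one built in the DuckDBMemory sketch
    # above; only converted_prompt_data_type == "text" is accepted, anything else
    # raises ValueError inside generate_embedding_memory_data.
    embedding_data = memory_embedding.generate_embedding_memory_data(chat_memory=entry)
```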