From f8ea6456db6efd66005994de1522b1e342af65c4 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Thu, 9 May 2024 16:56:14 -0700 Subject: [PATCH 1/2] add CSV support --- pyrit/memory/memory_exporter.py | 42 ++++++++++++++- tests/memory/test_memory_exporter.py | 80 ++++++++++++++++++---------- 2 files changed, 92 insertions(+), 30 deletions(-) diff --git a/pyrit/memory/memory_exporter.py b/pyrit/memory/memory_exporter.py index 7147897b5..058a72cad 100644 --- a/pyrit/memory/memory_exporter.py +++ b/pyrit/memory/memory_exporter.py @@ -1,10 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import csv import json import uuid from datetime import datetime from pathlib import Path +from collections.abc import MutableMapping from sqlalchemy.inspection import inspect @@ -21,7 +23,8 @@ def __init__(self): # Using strategy design pattern for export functionality. self.export_strategies = { "json": self.export_to_json, - # Future formats can be added here, e.g., "csv": self._export_to_csv + "csv": self.export_to_csv + # Future formats can be added here } def export_data(self, data: list[Base], *, file_path: Path = None, export_type: str = "json"): # type: ignore @@ -64,6 +67,31 @@ def export_to_json(self, data: list[Base], file_path: Path = None) -> None: # t export_data = [self.model_to_dict(instance) for instance in data] with open(file_path, "w") as f: json.dump(export_data, f, indent=4) + + def export_to_csv(self, data: list[Base], file_path: Path = None) -> None: # type: ignore + """ + Exports the provided data to a CSV file at the specified file path. + Each item in the data list, representing a row from the table, + is converted to a dictionary before being written to the file. + + Args: + data (list[Base]): The data to be exported, as a list of SQLAlchemy model instances. + file_path (Path): The full path, including the file name, where the data will be exported. + + Raises: + ValueError: If no file_path is provided. + """ + if not file_path: + raise ValueError("Please provide a valid file path for exporting data.") + if not data: + raise ValueError("No data to export.") + + export_data = [_flatten_dict(self.model_to_dict(instance)) for instance in data] + fieldnames = list(export_data[0].keys()) + with open(file_path, "w", newline='') as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(export_data) def model_to_dict(self, model_instance: Base): # type: ignore """ @@ -89,3 +117,15 @@ def model_to_dict(self, model_instance: Base): # type: ignore else: model_dict[column.name] = value return model_dict + + + +def _flatten_dict(d: MutableMapping, parent_key: str = '', sep: str = '.') -> MutableMapping: + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, MutableMapping): + items.extend(_flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) diff --git a/tests/memory/test_memory_exporter.py b/tests/memory/test_memory_exporter.py index 380334460..94cb00dca 100644 --- a/tests/memory/test_memory_exporter.py +++ b/tests/memory/test_memory_exporter.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import csv import json import pytest @@ -21,29 +22,51 @@ def model_to_dict(instance): return {c.key: getattr(instance, c.key) for c in inspect(instance).mapper.column_attrs} -def test_export_to_json_creates_file(tmp_path, sample_conversation_entries): +def read_file(file_path, export_type): + if export_type == "json": + with open(file_path, "r") as f: + return json.load(f) + elif export_type == "csv": + with open(file_path, "r", newline='') as f: + reader = csv.DictReader(f) + return [row for row in reader] + else: + raise ValueError(f"Invalid export type: {export_type}") + + +def export(export_type, exporter, data, file_path): + if export_type == "json": + exporter.export_to_json(data, file_path) + elif export_type == "csv": + exporter.export_to_csv(data, file_path) + else: + raise ValueError(f"Invalid export type: {export_type}") + + +@pytest.mark.parametrize("export_type", ["json", "csv"]) +def test_export_to_json_creates_file(tmp_path, sample_conversation_entries, export_type): exporter = MemoryExporter() - file_path = tmp_path / "conversations.json" + file_path = tmp_path / f"conversations.{export_type}" - exporter.export_to_json(sample_conversation_entries, file_path) + export(export_type=export_type, exporter=exporter, data=sample_conversation_entries, file_path=file_path) assert file_path.exists() # Check that the file was created - with open(file_path, "r") as f: - content = json.load(f) - # Perform more detailed checks on content if necessary - assert len(content) == 3 # Simple check for the number of items - # Convert each ConversationStore instance to a dictionary - expected_content = [model_to_dict(conv) for conv in sample_conversation_entries] - - for expected, actual in zip(expected_content, content): - assert expected["role"] == actual["role"] - assert expected["converted_value"] == actual["converted_value"] - assert expected["conversation_id"] == actual["conversation_id"] - assert expected["original_value_data_type"] == actual["original_value_data_type"] - assert expected["original_value"] == actual["original_value"] - - -def test_export_data_with_conversations(tmp_path, sample_conversation_entries): + content = read_file(file_path=file_path, export_type=export_type) + # Perform more detailed checks on content if necessary + assert len(content) == 3 # Simple check for the number of items + # Convert each ConversationStore instance to a dictionary + expected_content = [model_to_dict(conv) for conv in sample_conversation_entries] + + for expected, actual in zip(expected_content, content): + assert expected["role"] == actual["role"] + assert expected["converted_value"] == actual["converted_value"] + assert expected["conversation_id"] == actual["conversation_id"] + assert expected["original_value_data_type"] == actual["original_value_data_type"] + assert expected["original_value"] == actual["original_value"] + + +@pytest.mark.parametrize("export_type", ["json", "csv"]) +def test_export_to_json_data_with_conversations(tmp_path, sample_conversation_entries, export_type): exporter = MemoryExporter() conversation_id = sample_conversation_entries[0].conversation_id @@ -51,18 +74,17 @@ def test_export_data_with_conversations(tmp_path, sample_conversation_entries): file_path = tmp_path / "exported_conversations.json" # Call the method under test - exporter.export_data(sample_conversation_entries, file_path=file_path, export_type="json") + export(export_type=export_type, exporter=exporter, data=sample_conversation_entries, file_path=file_path) # Verify the file was created assert file_path.exists() # Read the file and verify its contents - with open(file_path, "r") as f: - content = json.load(f) - assert len(content) == 3 # Check for the expected number of items - assert content[0]["role"] == "user" - assert content[0]["converted_value"] == "Hello, how are you?" - assert content[0]["conversation_id"] == conversation_id - assert content[1]["role"] == "assistant" - assert content[1]["converted_value"] == "I'm fine, thank you!" - assert content[1]["conversation_id"] == conversation_id + content = read_file(file_path=file_path, export_type=export_type) + assert len(content) == 3 # Check for the expected number of items + assert content[0]["role"] == "user" + assert content[0]["converted_value"] == "Hello, how are you?" + assert content[0]["conversation_id"] == conversation_id + assert content[1]["role"] == "assistant" + assert content[1]["converted_value"] == "I'm fine, thank you!" + assert content[1]["conversation_id"] == conversation_id From e874054def4bae49a9f1db524d041f1599492fc6 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Fri, 10 May 2024 12:52:42 -0700 Subject: [PATCH 2/2] linting and typing fixes --- pyrit/memory/memory_exporter.py | 12 ++++++------ tests/memory/test_memory_exporter.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pyrit/memory/memory_exporter.py b/pyrit/memory/memory_exporter.py index 058a72cad..b31bdfd74 100644 --- a/pyrit/memory/memory_exporter.py +++ b/pyrit/memory/memory_exporter.py @@ -3,6 +3,7 @@ import csv import json +from typing import Any import uuid from datetime import datetime from pathlib import Path @@ -23,7 +24,7 @@ def __init__(self): # Using strategy design pattern for export functionality. self.export_strategies = { "json": self.export_to_json, - "csv": self.export_to_csv + "csv": self.export_to_csv, # Future formats can be added here } @@ -67,7 +68,7 @@ def export_to_json(self, data: list[Base], file_path: Path = None) -> None: # t export_data = [self.model_to_dict(instance) for instance in data] with open(file_path, "w") as f: json.dump(export_data, f, indent=4) - + def export_to_csv(self, data: list[Base], file_path: Path = None) -> None: # type: ignore """ Exports the provided data to a CSV file at the specified file path. @@ -88,7 +89,7 @@ def export_to_csv(self, data: list[Base], file_path: Path = None) -> None: # ty export_data = [_flatten_dict(self.model_to_dict(instance)) for instance in data] fieldnames = list(export_data[0].keys()) - with open(file_path, "w", newline='') as f: + with open(file_path, "w", newline="") as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() writer.writerows(export_data) @@ -119,9 +120,8 @@ def model_to_dict(self, model_instance: Base): # type: ignore return model_dict - -def _flatten_dict(d: MutableMapping, parent_key: str = '', sep: str = '.') -> MutableMapping: - items = [] +def _flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = ".") -> MutableMapping: + items: list[tuple[Any, Any]] = [] for k, v in d.items(): new_key = parent_key + sep + k if parent_key else k if isinstance(v, MutableMapping): diff --git a/tests/memory/test_memory_exporter.py b/tests/memory/test_memory_exporter.py index 94cb00dca..a565c8b20 100644 --- a/tests/memory/test_memory_exporter.py +++ b/tests/memory/test_memory_exporter.py @@ -27,12 +27,12 @@ def read_file(file_path, export_type): with open(file_path, "r") as f: return json.load(f) elif export_type == "csv": - with open(file_path, "r", newline='') as f: + with open(file_path, "r", newline="") as f: reader = csv.DictReader(f) return [row for row in reader] else: raise ValueError(f"Invalid export type: {export_type}") - + def export(export_type, exporter, data, file_path): if export_type == "json": @@ -41,7 +41,7 @@ def export(export_type, exporter, data, file_path): exporter.export_to_csv(data, file_path) else: raise ValueError(f"Invalid export type: {export_type}") - + @pytest.mark.parametrize("export_type", ["json", "csv"]) def test_export_to_json_creates_file(tmp_path, sample_conversation_entries, export_type):