FEAT add CSV support #197

Merged: 5 commits, May 13, 2024
42 changes: 41 additions & 1 deletion pyrit/memory/memory_exporter.py
@@ -1,10 +1,13 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import csv
 import json
+from typing import Any
 import uuid
 from datetime import datetime
 from pathlib import Path
+from collections.abc import MutableMapping
 
 from sqlalchemy.inspection import inspect
 
@@ -21,7 +24,8 @@ def __init__(self):
         # Using strategy design pattern for export functionality.
         self.export_strategies = {
             "json": self.export_to_json,
-            # Future formats can be added here, e.g., "csv": self._export_to_csv
+            "csv": self.export_to_csv,
+            # Future formats can be added here
         }
 
     def export_data(self, data: list[Base], *, file_path: Path = None, export_type: str = "json"):  # type: ignore
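As a quick usage sketch (not part of the diff), the strategy table above lets callers pick the output format by name through export_data; `entries` below stands in for a hypothetical list of SQLAlchemy model instances pulled from PyRIT memory:

from pathlib import Path

from pyrit.memory.memory_exporter import MemoryExporter

exporter = MemoryExporter()
# `entries` is assumed to be a list[Base] of conversation rows obtained elsewhere.
exporter.export_data(entries, file_path=Path("conversations.csv"), export_type="csv")
exporter.export_data(entries, file_path=Path("conversations.json"), export_type="json")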
@@ -65,6 +69,31 @@ def export_to_json(self, data: list[Base], file_path: Path = None) -> None:  # type: ignore
         with open(file_path, "w") as f:
             json.dump(export_data, f, indent=4)
 
+    def export_to_csv(self, data: list[Base], file_path: Path = None) -> None:  # type: ignore
+        """
+        Exports the provided data to a CSV file at the specified file path.
+        Each item in the data list, representing a row from the table,
+        is converted to a dictionary before being written to the file.
+
+        Args:
+            data (list[Base]): The data to be exported, as a list of SQLAlchemy model instances.
+            file_path (Path): The full path, including the file name, where the data will be exported.
+
+        Raises:
+            ValueError: If no file_path is provided.
+        """
+        if not file_path:
+            raise ValueError("Please provide a valid file path for exporting data.")
+        if not data:
+            raise ValueError("No data to export.")
+
+        export_data = [_flatten_dict(self.model_to_dict(instance)) for instance in data]
+        fieldnames = list(export_data[0].keys())
+        with open(file_path, "w", newline="") as f:
+            writer = csv.DictWriter(f, fieldnames=fieldnames)
+            writer.writeheader()
+            writer.writerows(export_data)
+
     def model_to_dict(self, model_instance: Base):  # type: ignore
         """
         Converts an SQLAlchemy model instance into a dictionary, serializing
@@ -89,3 +118,14 @@ def model_to_dict(self, model_instance: Base):  # type: ignore
             else:
                 model_dict[column.name] = value
         return model_dict
+
+
+def _flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = ".") -> MutableMapping:
+    items: list[tuple[Any, Any]] = []
+    for k, v in d.items():
+        new_key = parent_key + sep + k if parent_key else k
+        if isinstance(v, MutableMapping):
+            items.extend(_flatten_dict(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
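For orientation only (sample values invented for illustration), the new helper collapses nested dictionaries into dot-separated column names so csv.DictWriter can write each record as a flat row:

nested = {"id": "123", "labels": {"operator": "alice", "phase": "recon"}}
flat = _flatten_dict(nested)
# flat == {"id": "123", "labels.operator": "alice", "labels.phase": "recon"}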
80 changes: 51 additions & 29 deletions tests/memory/test_memory_exporter.py
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import csv
 import json
 import pytest
 
@@ -21,48 +22,69 @@ def model_to_dict(instance):
     return {c.key: getattr(instance, c.key) for c in inspect(instance).mapper.column_attrs}
 
 
-def test_export_to_json_creates_file(tmp_path, sample_conversation_entries):
+def read_file(file_path, export_type):
+    if export_type == "json":
+        with open(file_path, "r") as f:
+            return json.load(f)
+    elif export_type == "csv":
+        with open(file_path, "r", newline="") as f:
+            reader = csv.DictReader(f)
+            return [row for row in reader]
+    else:
+        raise ValueError(f"Invalid export type: {export_type}")
+
+
+def export(export_type, exporter, data, file_path):
+    if export_type == "json":
+        exporter.export_to_json(data, file_path)
+    elif export_type == "csv":
+        exporter.export_to_csv(data, file_path)
+    else:
+        raise ValueError(f"Invalid export type: {export_type}")
+
+
+@pytest.mark.parametrize("export_type", ["json", "csv"])
+def test_export_to_json_creates_file(tmp_path, sample_conversation_entries, export_type):
     exporter = MemoryExporter()
-    file_path = tmp_path / "conversations.json"
+    file_path = tmp_path / f"conversations.{export_type}"
 
-    exporter.export_to_json(sample_conversation_entries, file_path)
+    export(export_type=export_type, exporter=exporter, data=sample_conversation_entries, file_path=file_path)
 
     assert file_path.exists()  # Check that the file was created
-    with open(file_path, "r") as f:
-        content = json.load(f)
-        # Perform more detailed checks on content if necessary
-        assert len(content) == 3  # Simple check for the number of items
-        # Convert each ConversationStore instance to a dictionary
-        expected_content = [model_to_dict(conv) for conv in sample_conversation_entries]
-
-        for expected, actual in zip(expected_content, content):
-            assert expected["role"] == actual["role"]
-            assert expected["converted_value"] == actual["converted_value"]
-            assert expected["conversation_id"] == actual["conversation_id"]
-            assert expected["original_value_data_type"] == actual["original_value_data_type"]
-            assert expected["original_value"] == actual["original_value"]
-
-
-def test_export_data_with_conversations(tmp_path, sample_conversation_entries):
+    content = read_file(file_path=file_path, export_type=export_type)
+    # Perform more detailed checks on content if necessary
+    assert len(content) == 3  # Simple check for the number of items
+    # Convert each ConversationStore instance to a dictionary
+    expected_content = [model_to_dict(conv) for conv in sample_conversation_entries]
+
+    for expected, actual in zip(expected_content, content):
+        assert expected["role"] == actual["role"]
+        assert expected["converted_value"] == actual["converted_value"]
+        assert expected["conversation_id"] == actual["conversation_id"]
+        assert expected["original_value_data_type"] == actual["original_value_data_type"]
+        assert expected["original_value"] == actual["original_value"]
+
+
+@pytest.mark.parametrize("export_type", ["json", "csv"])
+def test_export_to_json_data_with_conversations(tmp_path, sample_conversation_entries, export_type):
     exporter = MemoryExporter()
     conversation_id = sample_conversation_entries[0].conversation_id
 
     # Define the file path using tmp_path
     file_path = tmp_path / "exported_conversations.json"
 
     # Call the method under test
-    exporter.export_data(sample_conversation_entries, file_path=file_path, export_type="json")
+    export(export_type=export_type, exporter=exporter, data=sample_conversation_entries, file_path=file_path)
 
     # Verify the file was created
    assert file_path.exists()
 
     # Read the file and verify its contents
-    with open(file_path, "r") as f:
-        content = json.load(f)
-        assert len(content) == 3  # Check for the expected number of items
-        assert content[0]["role"] == "user"
-        assert content[0]["converted_value"] == "Hello, how are you?"
-        assert content[0]["conversation_id"] == conversation_id
-        assert content[1]["role"] == "assistant"
-        assert content[1]["converted_value"] == "I'm fine, thank you!"
-        assert content[1]["conversation_id"] == conversation_id
+    content = read_file(file_path=file_path, export_type=export_type)
+    assert len(content) == 3  # Check for the expected number of items
+    assert content[0]["role"] == "user"
+    assert content[0]["converted_value"] == "Hello, how are you?"
+    assert content[0]["conversation_id"] == conversation_id
+    assert content[1]["role"] == "assistant"
+    assert content[1]["converted_value"] == "I'm fine, thank you!"
+    assert content[1]["conversation_id"] == conversation_id
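One caveat worth noting when reading these tests (an editorial observation, not something asserted by the diff): csv.DictReader returns every field as a string, so the shared assertions pass because the compared columns are string-valued; a numeric or datetime column would need explicit conversion before comparison. A minimal sketch, assuming a hypothetical exported file and column:

import csv

with open("conversations.csv", newline="") as f:
    rows = list(csv.DictReader(f))

# Every value comes back as str; a hypothetical "timestamp" column would need
# parsing (e.g. datetime.fromisoformat) before comparing with the original value.
assert all(isinstance(v, str) for v in rows[0].values())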