-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
311 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
import sqlite3 | ||
from typing import Any, List, Optional, Tuple | ||
|
||
__all__ = ["SQLiteDB"] | ||
|
||
|
||
class SQLiteDB: | ||
""" | ||
A class for interacting with an SQLite3 database. | ||
Parameters: | ||
db_name (str): The name of the SQLite database file. | ||
Attributes: | ||
db_name (str): The name of the SQLite database file. | ||
connection (sqlite3.Connection): The connection object for the database. | ||
""" | ||
|
||
def __init__(self, db_name: str) -> None: | ||
self.db_name = db_name | ||
self.connection: Optional[sqlite3.Connection] = None | ||
|
||
def connect(self) -> None: | ||
""" | ||
Connects to the SQLite database. | ||
""" | ||
|
||
try: | ||
self.connection = sqlite3.connect(self.db_name) | ||
except sqlite3.OperationalError as exc: | ||
raise sqlite3.OperationalError(exc) | ||
|
||
def disconnect(self) -> None: | ||
""" | ||
Disconnects from the SQLite database. | ||
""" | ||
|
||
if self.connection: | ||
self.connection.close() | ||
|
||
def execute_query(self, query: str, params: Optional[Tuple[Any, ...]] = None) -> sqlite3.Cursor: | ||
""" | ||
Executes an SQL query. | ||
Parameters: | ||
query (str): The SQL query to execute. | ||
params (tuple, optional): The parameters to be passed to the query. | ||
Returns: | ||
cursor (sqlite3.Cursor): The cursor object. | ||
""" | ||
|
||
cursor = self.connection.cursor() | ||
if params: | ||
cursor.execute(query, params) | ||
else: | ||
cursor.execute(query) | ||
self.connection.commit() | ||
return cursor | ||
|
||
def create_table(self, table_name: str, columns: List[str]) -> None: | ||
""" | ||
Creates a table in the database. | ||
Parameters: | ||
table_name (str): The name of the table to create. | ||
columns (list): The list of column definitions. | ||
""" | ||
|
||
query = f"CREATE TABLE IF NOT EXISTS {table_name} ({', '.join(columns)})" | ||
self.execute_query(query) | ||
|
||
def add_column(self, table_name: str, column_name: str, column_type: str) -> None: | ||
""" | ||
Adds a column to an existing table. | ||
Parameters: | ||
table_name (str): The name of the table. | ||
column_name (str): The name of the column to add. | ||
column_type (str): The data type of the column. | ||
""" | ||
|
||
query = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}" | ||
self.execute_query(query) | ||
|
||
def remove_column(self, table_name: str, column_name: str) -> None: | ||
""" | ||
Removes a column from an existing table. | ||
Parameters: | ||
table_name (str): The name of the table. | ||
column_name (str): The name of the column to remove. | ||
""" | ||
|
||
try: | ||
query = f"ALTER TABLE {table_name} DROP COLUMN {column_name}" | ||
self.execute_query(query) | ||
except sqlite3.OperationalError as exc: | ||
query = f"PRAGMA table_info({table_name})" | ||
cursor = self.execute_query(query) | ||
columns = [column[1] for column in cursor.fetchall()] | ||
if column_name not in columns: | ||
raise ValueError(f"Column '{column_name}' does not exist in table '{table_name}'") | ||
raise sqlite3.OperationalError(exc) | ||
|
||
def update_data( | ||
self, | ||
table_name: str, | ||
column_name: str, | ||
new_value: Any, | ||
condition_column: str, | ||
condition_value: Any | ||
) -> None: | ||
""" | ||
Updates data in a table. | ||
Parameters: | ||
table_name (str): The name of the table. | ||
column_name (str): The name of the column to update. | ||
new_value (any): The new value for the column. | ||
condition_column (str): The column to use for the condition. | ||
condition_value (any): The value to use in the condition. | ||
""" | ||
|
||
query = f"UPDATE {table_name} SET {column_name} = ? WHERE {condition_column} = ?" | ||
self.execute_query(query, (new_value, condition_value)) | ||
|
||
def insert_data(self, table_name: str, values: List[Any]) -> None: | ||
""" | ||
Inserts data into a table. | ||
Parameters: | ||
table_name (str): The name of the table. | ||
values (list): The values to insert. | ||
""" | ||
|
||
placeholders = ", ".join(["?"] * len(values)) | ||
query = f"INSERT INTO {table_name} VALUES ({placeholders})" | ||
self.execute_query(query, values) | ||
|
||
def fetch_data( | ||
self, | ||
table_name: str, | ||
columns: Optional[List[str]] = None, | ||
condition: Optional[str] = None | ||
) -> List[Tuple]: | ||
""" | ||
Fetches data from a table. | ||
Parameters: | ||
table_name (str): The name of the table. | ||
columns (list, optional): The list of columns to fetch. | ||
condition (str, optional): The condition to use in the query. | ||
Returns: | ||
result (list): The fetched data. | ||
""" | ||
|
||
column_names = "*" if not columns else ", ".join(columns) | ||
query = f"SELECT {column_names} FROM {table_name}" | ||
if condition: | ||
query += f" WHERE {condition}" | ||
cursor = self.execute_query(query) | ||
return cursor.fetchall() | ||
|
||
def remove_data(self, table_name: str, condition_column: str, condition_value: Any) -> None: | ||
""" | ||
Removes data from a table. | ||
Parameters: | ||
table_name (str): The name of the table. | ||
condition_column (str): The column to use for the condition. | ||
condition_value (any): The value to use in the condition. | ||
""" | ||
|
||
query = f"DELETE FROM {table_name} WHERE {condition_column} = ?" | ||
self.execute_query(query, (condition_value,)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import pytest | ||
|
||
from wxflow import SQLiteDB | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def db(): | ||
# Create an in-memory SQLite database for testing | ||
db = SQLiteDB(":memory:") | ||
db.connect() | ||
|
||
# Create a test table | ||
table_name = "test_table" | ||
columns = ["id INTEGER PRIMARY KEY", "name TEXT", "age INTEGER"] | ||
db.create_table(table_name, columns) | ||
|
||
yield db | ||
|
||
# Disconnect from the database | ||
db.disconnect() | ||
|
||
|
||
def test_create_table(db): | ||
# Verify that the test table exists | ||
assert table_exists(db, "test_table") | ||
|
||
|
||
def test_add_column(db): | ||
# Add a new column to the test table | ||
column_name = "address" | ||
column_type = "TEXT" | ||
db.add_column("test_table", column_name, column_type) | ||
|
||
# Verify that the column exists in the test table | ||
assert column_exists(db, "test_table", column_name) | ||
|
||
|
||
def test_update_data(db): | ||
# Insert test data into the table | ||
values = [1, "Alice", 25, 'Apt 101'] | ||
db.insert_data("test_table", values) | ||
|
||
# Update the age of the record | ||
new_age = 30 | ||
db.update_data("test_table", "age", new_age, "name", "Alice") | ||
|
||
# Fetch the updated data | ||
result = db.fetch_data("test_table", condition="name='Alice'") | ||
|
||
# Verify that the age is updated correctly | ||
assert result[0][2] == new_age | ||
|
||
|
||
def test_remove_column(db): | ||
# Removes a column from the test table | ||
column_name = "address" | ||
db.remove_column("test_table", column_name) | ||
|
||
# Verify that the column exists in the test table | ||
assert not column_exists(db, "test_table", column_name) | ||
|
||
|
||
def test_remove_column_raises_error_when_column_not_exists(db): | ||
table_name = "test_table" | ||
column_name = "vacation address" | ||
|
||
with pytest.raises(ValueError, match=f"Column '{column_name}' does not exist in table '{table_name}'"): | ||
db.remove_column("test_table", column_name) | ||
|
||
|
||
def test_insert_data(db): | ||
# Insert test data into the table | ||
values = [2, "Bob", 35] | ||
db.insert_data("test_table", values) | ||
|
||
# Fetch all data from the table | ||
result = db.fetch_data("test_table") | ||
|
||
# Verify that the inserted data is present in the table | ||
assert len(result) == 2 | ||
|
||
|
||
def test_fetch_data(db): | ||
# Insert test data into the table | ||
values = [3, "Charlie", 40] | ||
db.insert_data("test_table", values) | ||
|
||
# Fetch data from the table | ||
result = db.fetch_data("test_table", condition="age > 30") | ||
|
||
# Verify that the fetched data meets the condition | ||
assert len(result) == 2 | ||
|
||
|
||
def test_remove_data(db): | ||
# Insert test data into the table | ||
values = [4, "David", 45] | ||
db.insert_data("test_table", values) | ||
|
||
# Remove a record from the table | ||
db.remove_data("test_table", "name", "David") | ||
|
||
# Fetch all data from the table | ||
result = db.fetch_data("test_table") | ||
|
||
# Verify that the removed data is not present in the table | ||
assert len(result) == 3 | ||
|
||
|
||
# Helper functions | ||
|
||
def table_exists(db, table_name): | ||
query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}'" | ||
cursor = db.execute_query(query) | ||
return cursor.fetchone() is not None | ||
|
||
|
||
def column_exists(db, table_name, column_name): | ||
query = f"PRAGMA table_info({table_name})" | ||
cursor = db.execute_query(query) | ||
columns = [column[1] for column in cursor.fetchall()] | ||
return column_name in columns |