Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor PSQL incident store into ORM (SQLAlchemy) #211

Merged
merged 7 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ HUB_URL ?= ""
IMPORTER_ARGS ?= ""

run-postgres:
$(CONTAINER_RUNTIME) run -it -v data:/var/lib/postgresql/data -e POSTGRES_USER=kai -e POSTGRES_PASSWORD=dog8code -e POSTGRES_DB=kai -p 5432:5432 docker.io/pgvector/pgvector:pg15
$(CONTAINER_RUNTIME) run -it -v data:/var/lib/postgresql/data -e POSTGRES_USER=kai -e POSTGRES_PASSWORD=dog8code -e POSTGRES_DB=kai -p 5432:5432 docker.io/library/postgres:16.3

run-server:
PYTHONPATH=$(KAI_PYTHON_PATH) LOGLEVEL=$(LOGLEVEL) DEMO_MODE=$(DEMO_MODE) gunicorn --timeout 3600 -w $(NUM_WORKERS) --bind localhost:8080 --worker-class aiohttp.GunicornWebWorker 'kai.server:app()'
Expand Down
18 changes: 18 additions & 0 deletions kai/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,24 @@ log_level = "info"
demo_mode = false
trace_enabled = true

# **Postgresql incident store**
# [incident_store]
# provider = "postgresql"

# [incident_store.args]
# host = "127.0.0.1"
# database = "kai"
# user = "kai"
# password = "dog8code"

# **In-memory sqlite incident store**
# ```
# [incident_store]
# provider = "sqlite"
#
# [incident_store.args]
# connection_string = "sqlite:///:memory:"

[incident_store]
provider = "postgresql"

Expand Down
3 changes: 0 additions & 3 deletions kai/data/sql/add_embedding.sql

This file was deleted.

65 changes: 0 additions & 65 deletions kai/data/sql/create_tables.sql

This file was deleted.

6 changes: 0 additions & 6 deletions kai/data/sql/drop_tables.sql

This file was deleted.

11 changes: 8 additions & 3 deletions kai/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from kai.model_provider import ModelProvider
from kai.models.analyzer_types import Incident
from kai.models.file_solution import guess_language, parse_file_solution_content
from kai.models.kai_config import KaiConfig
from kai.models.kai_config import KaiConfig, KaiConfigIncidentStoreSQLiteArgs
from kai.report import Report
from kai.service.incident_store.in_memory import InMemoryIncidentStore
from kai.service.incident_store.incident_store import Application
from kai.service.incident_store.sqlite import SQLiteIncidentStore

"""
The point of this file is to automatically see if certain prompts make the
Expand Down Expand Up @@ -172,7 +172,12 @@ def evaluate(

created_git_repo = True

incident_store = InMemoryIncidentStore(None)
incident_store = SQLiteIncidentStore(
KaiConfigIncidentStoreSQLiteArgs(
connection_string="sqlite:///:memory:"
),
None,
)
incident_store.load_report(example.application, example.report)

pb_incidents = []
Expand Down
2 changes: 1 addition & 1 deletion kai/llm_io_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def get_prompt(
if not fallback:
raise e

KAI_LOG.error(f"Template '{e.name}' not found. Falling back to main.jinja")
KAI_LOG.warning(f"Template '{e.name}' not found. Falling back to main.jinja")
template = jinja_env.get_template("main.jinja")

KAI_LOG.debug(f"Template {template.filename} loaded")
Expand Down
61 changes: 49 additions & 12 deletions kai/models/kai_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,77 @@
from typing import Literal, Optional, Union

import yaml
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator

# Incident store providers


class KaiConfigIncidentStoreProvider(Enum):
POSTGRESQL = "postgresql"
IN_MEMORY = "in_memory"
SQLITE = "sqlite"


class KaiConfigIncidentStorePostgreSQLArgs(BaseModel):
host: str
database: str
user: str
password: str
host: Optional[str] = None
database: Optional[str] = None
user: Optional[str] = None
password: Optional[str] = None

connection_string: Optional[str] = None

@validator("connection_string", always=True)
def validate_connection_string(cls, v, values):
connection_string_present = v is not None
parameters_present = all(
values.get(key) is not None
for key in ["host", "database", "user", "password"]
)

if connection_string_present == parameters_present:
raise ValueError(
"Must provide one of connection_string or [host, database, user, password]"
)

return v


class KaiConfigIncidentStorePostgreSQL(BaseModel):
provider: Literal["postgresql"]
args: KaiConfigIncidentStorePostgreSQLArgs


class KaiConfigIncidentStoreInMemoryArgs(BaseModel):
dummy: bool
class KaiConfigIncidentStoreSQLiteArgs(BaseModel):
host: Optional[str] = None
database: Optional[str] = None
user: Optional[str] = None
password: Optional[str] = None

connection_string: Optional[str] = None

@validator("connection_string", always=True)
def validate_connection_string(cls, v, values):
connection_string_present = v is not None
parameters_present = all(
values.get(key) is not None
for key in ["host", "database", "user", "password"]
)

if connection_string_present == parameters_present:
raise ValueError(
"Must provide one of connection_string or [host, database, user, password]"
)

return v


class KaiConfigIncidentStoreInMemory(BaseModel):
provider: Literal["in_memory"]
args: KaiConfigIncidentStoreInMemoryArgs
class KaiConfigIncidentStoreSQLIte(BaseModel):
provider: Literal["sqlite"]
args: KaiConfigIncidentStoreSQLiteArgs


KaiConfigIncidentStore = Union[
KaiConfigIncidentStorePostgreSQL,
KaiConfigIncidentStoreInMemory,
KaiConfigIncidentStoreSQLIte,
]

# Model providers
Expand Down
17 changes: 13 additions & 4 deletions kai/report.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
__all__ = ["Report"]

import hashlib
import json
import os
import shutil
from io import StringIO
Expand All @@ -12,9 +14,10 @@
class Report:
report: dict = None

def __init__(self, report_data: dict):
def __init__(self, report_data: dict, report_id: str):
self.workaround_counter_for_missing_ruleset_name = 0
self.report = self._reformat_report(report_data)
self.report_id = report_id

def __str__(self):
return str(self.report)
Expand All @@ -35,19 +38,25 @@ def __len__(self):
return len(self.report)

@classmethod
def load_report_from_object(cls, report_data: dict):
def load_report_from_object(cls, report_data: dict, report_id: str):
"""
Class method to create a Report instance directly from a Python dictionary object.
"""
return cls(report_data=report_data)
return cls(report_data=report_data, report_id=report_id)

@classmethod
def load_report_from_file(cls, file_name: str):
KAI_LOG.info(f"Reading report from {file_name}")
with open(file_name, "r") as f:
report: dict = yaml.safe_load(f)
report_data = report
return cls(report_data)
# report_id is the hash of the json.dumps of the report_data
return cls(
report_data,
hashlib.sha256(
json.dumps(report_data, sort_keys=True).encode()
).hexdigest(),
)

@staticmethod
def get_cleaned_file_path(uri: str):
Expand Down
8 changes: 5 additions & 3 deletions kai/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,15 @@ def app():
if config.demo_mode:
KAI_LOG.info("DEMO_MODE is enabled. LLM responses will be cached")

webapp["incident_store"] = IncidentStore.from_config(config.incident_store)
KAI_LOG.info(f"Selected incident store: {config.incident_store.provider}")

webapp["model_provider"] = ModelProvider(config.models)
KAI_LOG.info(f"Selected provider: {config.models.provider}")
KAI_LOG.info(f"Selected model: {webapp['model_provider'].model_id}")

webapp["incident_store"] = IncidentStore.from_config(
config.incident_store, webapp["model_provider"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing the model provider to the incident store so we can do LLM summaries next

)
KAI_LOG.info(f"Selected incident store: {config.incident_store.provider}")

webapp.add_routes(routes)

return webapp
Expand Down
4 changes: 2 additions & 2 deletions kai/service/incident_store/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from .in_memory import InMemoryIncidentStore
from .incident_store import Application, IncidentStore, Solution
from .psql import PSQLIncidentStore
from .sqlite import SQLiteIncidentStore

__all__ = [
"IncidentStore",
"Solution",
"PSQLIncidentStore",
"InMemoryIncidentStore",
"SQLiteIncidentStore",
"Application",
]
Loading
Loading