diff --git a/Makefile b/Makefile index 9575d239..9466ab0c 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ HUB_URL ?= "" IMPORTER_ARGS ?= "" run-postgres: - $(CONTAINER_RUNTIME) run -it -v data:/var/lib/postgresql/data -e POSTGRES_USER=kai -e POSTGRES_PASSWORD=dog8code -e POSTGRES_DB=kai -p 5432:5432 docker.io/pgvector/pgvector:pg15 + $(CONTAINER_RUNTIME) run -it -v data:/var/lib/postgresql/data -e POSTGRES_USER=kai -e POSTGRES_PASSWORD=dog8code -e POSTGRES_DB=kai -p 5432:5432 docker.io/library/postgres:16.3 run-server: PYTHONPATH=$(KAI_PYTHON_PATH) LOGLEVEL=$(LOGLEVEL) DEMO_MODE=$(DEMO_MODE) gunicorn --timeout 3600 -w $(NUM_WORKERS) --bind localhost:8080 --worker-class aiohttp.GunicornWebWorker 'kai.server:app()' diff --git a/kai/config.toml b/kai/config.toml index 56731729..f507822e 100644 --- a/kai/config.toml +++ b/kai/config.toml @@ -2,6 +2,24 @@ log_level = "info" demo_mode = false trace_enabled = true +# **Postgresql incident store** +# [incident_store] +# provider = "postgresql" + +# [incident_store.args] +# host = "127.0.0.1" +# database = "kai" +# user = "kai" +# password = "dog8code" + +# **In-memory sqlite incident store** +# ``` +# [incident_store] +# provider = "sqlite" +# +# [incident_store.args] +# connection_string = "sqlite:///:memory:" + [incident_store] provider = "postgresql" diff --git a/kai/data/sql/add_embedding.sql b/kai/data/sql/add_embedding.sql deleted file mode 100644 index 01d3fd14..00000000 --- a/kai/data/sql/add_embedding.sql +++ /dev/null @@ -1,3 +0,0 @@ -ALTER TABLE accepted_solutions ADD COLUMN IF NOT EXISTS small_diff_embedding vector(%s); -ALTER TABLE accepted_solutions ADD COLUMN IF NOT EXISTS original_code_embedding vector(%s); -ALTER TABLE incidents ADD COLUMN IF NOT EXISTS incident_snip_embedding vector(%s); \ No newline at end of file diff --git a/kai/data/sql/create_tables.sql b/kai/data/sql/create_tables.sql deleted file mode 100644 index 128e085e..00000000 --- a/kai/data/sql/create_tables.sql +++ /dev/null @@ -1,65 +0,0 @@ -CREATE EXTENSION IF NOT EXISTS vector; - -CREATE TABLE IF NOT EXISTS applications ( - application_id SERIAL PRIMARY KEY, - application_name TEXT NOT NULL, - repo_uri_origin TEXT NOT NULL, - repo_uri_local TEXT NOT NULL, - current_branch TEXT NOT NULL, - current_commit TEXT NOT NULL, - generated_at TIMESTAMP NOT NULL -); - -CREATE TABLE IF NOT EXISTS rulesets ( - ruleset_id SERIAL PRIMARY KEY, - ruleset_name TEXT NOT NULL, - -- application_id INT REFERENCES applications, - tags JSONB NOT NULL -); - -DO $$ BEGIN - CREATE TYPE violation_category_t AS ENUM ( - 'potential', 'optional', 'mandatory' - ); -EXCEPTION - WHEN duplicate_object THEN null; -END $$; - -CREATE TABLE IF NOT EXISTS violations ( - violation_id SERIAL PRIMARY KEY, - violation_name TEXT NOT NULL, - ruleset_id INT REFERENCES rulesets, - category violation_category_t NOT NULL, - labels JSONB NOT NULL -); - -CREATE TABLE IF NOT EXISTS accepted_solutions ( - solution_id SERIAL PRIMARY KEY, - generated_at TIMESTAMP DEFAULT current_timestamp, - solution_big_diff TEXT NOT NULL, - solution_small_diff TEXT NOT NULL, - solution_original_code TEXT NOT NULL, - solution_updated_code TEXT NOT NULL - -- small_diff_embedding vector(%s) - -- original_code_embedding vector(%s) -); - -CREATE TABLE IF NOT EXISTS incidents ( - incident_id SERIAL PRIMARY KEY, - violation_id INT REFERENCES violations, - application_id INT REFERENCES applications, - incident_uri TEXT NOT NULL, - incident_snip TEXT NOT NULL, - incident_line INT NOT NULL, - incident_variables JSONB NOT NULL, - solution_id INT REFERENCES accepted_solutions - -- incident_snip_embedding vector(%s) -); - -CREATE TABLE IF NOT EXISTS potential_solutions ( - solution_id SERIAL PRIMARY KEY, - generated_at TIMESTAMP DEFAULT current_timestamp, - solution_big_diff TEXT NOT NULL, - solution_small_diff TEXT NOT NULL, - incident_id INT REFERENCES Incidents -); diff --git a/kai/data/sql/drop_tables.sql b/kai/data/sql/drop_tables.sql deleted file mode 100644 index 764b58a9..00000000 --- a/kai/data/sql/drop_tables.sql +++ /dev/null @@ -1,6 +0,0 @@ -DROP TABLE IF EXISTS applications CASCADE; -DROP TABLE IF EXISTS rulesets CASCADE; -DROP TABLE IF EXISTS violations CASCADE; -DROP TABLE IF EXISTS accepted_solutions CASCADE; -DROP TABLE IF EXISTS incidents CASCADE; -DROP TABLE IF EXISTS potential_solutions CASCADE; \ No newline at end of file diff --git a/kai/evaluation.py b/kai/evaluation.py index 28488b0b..d2b3dfd0 100644 --- a/kai/evaluation.py +++ b/kai/evaluation.py @@ -12,10 +12,10 @@ from kai.model_provider import ModelProvider from kai.models.analyzer_types import Incident from kai.models.file_solution import guess_language, parse_file_solution_content -from kai.models.kai_config import KaiConfig +from kai.models.kai_config import KaiConfig, KaiConfigIncidentStoreSQLiteArgs from kai.report import Report -from kai.service.incident_store.in_memory import InMemoryIncidentStore from kai.service.incident_store.incident_store import Application +from kai.service.incident_store.sqlite import SQLiteIncidentStore """ The point of this file is to automatically see if certain prompts make the @@ -172,7 +172,12 @@ def evaluate( created_git_repo = True - incident_store = InMemoryIncidentStore(None) + incident_store = SQLiteIncidentStore( + KaiConfigIncidentStoreSQLiteArgs( + connection_string="sqlite:///:memory:" + ), + None, + ) incident_store.load_report(example.application, example.report) pb_incidents = [] diff --git a/kai/llm_io_handler.py b/kai/llm_io_handler.py index f2408a64..f6b4d824 100644 --- a/kai/llm_io_handler.py +++ b/kai/llm_io_handler.py @@ -82,7 +82,7 @@ def get_prompt( if not fallback: raise e - KAI_LOG.error(f"Template '{e.name}' not found. Falling back to main.jinja") + KAI_LOG.warning(f"Template '{e.name}' not found. Falling back to main.jinja") template = jinja_env.get_template("main.jinja") KAI_LOG.debug(f"Template {template.filename} loaded") diff --git a/kai/models/kai_config.py b/kai/models/kai_config.py index 2eb47e87..ab606842 100644 --- a/kai/models/kai_config.py +++ b/kai/models/kai_config.py @@ -4,21 +4,38 @@ from typing import Literal, Optional, Union import yaml -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator # Incident store providers class KaiConfigIncidentStoreProvider(Enum): POSTGRESQL = "postgresql" - IN_MEMORY = "in_memory" + SQLITE = "sqlite" class KaiConfigIncidentStorePostgreSQLArgs(BaseModel): - host: str - database: str - user: str - password: str + host: Optional[str] = None + database: Optional[str] = None + user: Optional[str] = None + password: Optional[str] = None + + connection_string: Optional[str] = None + + @validator("connection_string", always=True) + def validate_connection_string(cls, v, values): + connection_string_present = v is not None + parameters_present = all( + values.get(key) is not None + for key in ["host", "database", "user", "password"] + ) + + if connection_string_present == parameters_present: + raise ValueError( + "Must provide one of connection_string or [host, database, user, password]" + ) + + return v class KaiConfigIncidentStorePostgreSQL(BaseModel): @@ -26,18 +43,38 @@ class KaiConfigIncidentStorePostgreSQL(BaseModel): args: KaiConfigIncidentStorePostgreSQLArgs -class KaiConfigIncidentStoreInMemoryArgs(BaseModel): - dummy: bool +class KaiConfigIncidentStoreSQLiteArgs(BaseModel): + host: Optional[str] = None + database: Optional[str] = None + user: Optional[str] = None + password: Optional[str] = None + + connection_string: Optional[str] = None + + @validator("connection_string", always=True) + def validate_connection_string(cls, v, values): + connection_string_present = v is not None + parameters_present = all( + values.get(key) is not None + for key in ["host", "database", "user", "password"] + ) + + if connection_string_present == parameters_present: + raise ValueError( + "Must provide one of connection_string or [host, database, user, password]" + ) + + return v -class KaiConfigIncidentStoreInMemory(BaseModel): - provider: Literal["in_memory"] - args: KaiConfigIncidentStoreInMemoryArgs +class KaiConfigIncidentStoreSQLIte(BaseModel): + provider: Literal["sqlite"] + args: KaiConfigIncidentStoreSQLiteArgs KaiConfigIncidentStore = Union[ KaiConfigIncidentStorePostgreSQL, - KaiConfigIncidentStoreInMemory, + KaiConfigIncidentStoreSQLIte, ] # Model providers diff --git a/kai/report.py b/kai/report.py index 64dce251..7383b5cd 100644 --- a/kai/report.py +++ b/kai/report.py @@ -1,5 +1,7 @@ __all__ = ["Report"] +import hashlib +import json import os import shutil from io import StringIO @@ -12,9 +14,10 @@ class Report: report: dict = None - def __init__(self, report_data: dict): + def __init__(self, report_data: dict, report_id: str): self.workaround_counter_for_missing_ruleset_name = 0 self.report = self._reformat_report(report_data) + self.report_id = report_id def __str__(self): return str(self.report) @@ -35,11 +38,11 @@ def __len__(self): return len(self.report) @classmethod - def load_report_from_object(cls, report_data: dict): + def load_report_from_object(cls, report_data: dict, report_id: str): """ Class method to create a Report instance directly from a Python dictionary object. """ - return cls(report_data=report_data) + return cls(report_data=report_data, report_id=report_id) @classmethod def load_report_from_file(cls, file_name: str): @@ -47,7 +50,13 @@ def load_report_from_file(cls, file_name: str): with open(file_name, "r") as f: report: dict = yaml.safe_load(f) report_data = report - return cls(report_data) + # report_id is the hash of the json.dumps of the report_data + return cls( + report_data, + hashlib.sha256( + json.dumps(report_data, sort_keys=True).encode() + ).hexdigest(), + ) @staticmethod def get_cleaned_file_path(uri: str): diff --git a/kai/server.py b/kai/server.py index 2ee36ce3..d5db510f 100644 --- a/kai/server.py +++ b/kai/server.py @@ -295,13 +295,15 @@ def app(): if config.demo_mode: KAI_LOG.info("DEMO_MODE is enabled. LLM responses will be cached") - webapp["incident_store"] = IncidentStore.from_config(config.incident_store) - KAI_LOG.info(f"Selected incident store: {config.incident_store.provider}") - webapp["model_provider"] = ModelProvider(config.models) KAI_LOG.info(f"Selected provider: {config.models.provider}") KAI_LOG.info(f"Selected model: {webapp['model_provider'].model_id}") + webapp["incident_store"] = IncidentStore.from_config( + config.incident_store, webapp["model_provider"] + ) + KAI_LOG.info(f"Selected incident store: {config.incident_store.provider}") + webapp.add_routes(routes) return webapp diff --git a/kai/service/incident_store/__init__.py b/kai/service/incident_store/__init__.py index 85583c52..73c8503f 100644 --- a/kai/service/incident_store/__init__.py +++ b/kai/service/incident_store/__init__.py @@ -1,11 +1,11 @@ -from .in_memory import InMemoryIncidentStore from .incident_store import Application, IncidentStore, Solution from .psql import PSQLIncidentStore +from .sqlite import SQLiteIncidentStore __all__ = [ "IncidentStore", "Solution", "PSQLIncidentStore", - "InMemoryIncidentStore", + "SQLiteIncidentStore", "Application", ] diff --git a/kai/service/incident_store/in_memory.py b/kai/service/incident_store/in_memory.py deleted file mode 100644 index 4a251b15..00000000 --- a/kai/service/incident_store/in_memory.py +++ /dev/null @@ -1,210 +0,0 @@ -import os -from dataclasses import dataclass -from urllib.parse import unquote, urlparse - -from git import Repo - -from kai.models.kai_config import KaiConfigIncidentStoreInMemoryArgs -from kai.report import Report -from kai.service.incident_store.incident_store import ( - Application, - IncidentStore, - Solution, - remove_known_prefixes, -) - - -# These classes are just for the in-memory data store. Once we figure out the -# best way to model the incidents, violations, etc... we can remove these -@dataclass -class InMemoryIncident: - uri: str - snip: str - line: int - variables: dict - - def __hash__(self): - return hash((self.uri, self.snip, self.line, frozenset(self.variables.items()))) - - -@dataclass(frozen=True) -class InMemorySolvedIncident: - uri: str - snip: str - line: int - variables: dict - file_diff: str - repo_diff: str - original_code: str - updated_code: str - - -@dataclass -class InMemoryViolation: - unsolved_incidents: list[InMemoryIncident] - solved_incidents: list[InMemorySolvedIncident] - - -@dataclass -class InMemoryRuleset: - violations: dict[str, InMemoryViolation] - - -@dataclass -class InMemoryApplication: - current_commit: str - rulesets: dict[str, InMemoryRuleset] - - -class InMemoryIncidentStore(IncidentStore): - def __init__(self, args: KaiConfigIncidentStoreInMemoryArgs): - self.store: dict[str, InMemoryApplication] = {} - - def delete_store(self): - self.store = {} - - def load_report(self, app: Application, report: Report) -> tuple[int, int, int]: - """ - Returns: (number_new_incidents, number_unsolved_incidents, - number_solved_incidents): tuple[int, int, int] - """ - # FIXME: Only does stuff within the same application. Maybe fixed? - - # create entries if not exists - # reference the old-new matrix - # old - # | NO | YES - # --------|--------+----------------------------- - # new NO | - | update (SOLVED, embeddings) - # YES | insert | update (line number, etc...) - - repo_path = unquote(urlparse(app.repo_uri_local).path) - print(f"repo_path: {repo_path}") - repo = Repo(repo_path) - old_commit: str - new_commit = app.current_commit - - number_new_incidents = 0 - number_unsolved_incidents = 0 - number_solved_incidents = 0 - - application = self.store.setdefault( - app.application_name, InMemoryApplication(app.current_commit, {}) - ) - - old_commit = application.current_commit - report_dict = dict(report) - - number_new_incidents = 0 - number_unsolved_incidents = 0 - number_solved_incidents = 0 - - for ruleset_name, ruleset_dict in report_dict.items(): - ruleset = application.rulesets.setdefault(ruleset_name, InMemoryRuleset({})) - - for violation_name, violation_dict in ruleset_dict.get( - "violations", {} - ).items(): - violation = ruleset.violations.setdefault( - violation_name, InMemoryViolation([], []) - ) - - store_incidents = set(violation.unsolved_incidents) - report_incidents = set( - InMemoryIncident( - x.get("uri", ""), - x.get("codeSnip", ""), - x.get("lineNumber", 0), - x.get("variables", {}), - ) - for x in violation_dict.get("incidents", []) - ) - - new_incidents = report_incidents.difference(store_incidents) - number_new_incidents += len(new_incidents) - for incident in new_incidents: - violation.unsolved_incidents.append(incident) - - unsolved_incidents = report_incidents.intersection(store_incidents) - number_unsolved_incidents += len(unsolved_incidents) - - solved_incidents = store_incidents.difference(report_incidents) - number_solved_incidents += len(solved_incidents) - for incident in solved_incidents: - file_path = os.path.join( - repo_path, - # NOTE: When retrieving uris from the report, some of - # them had "/tmp/source-code/" as their root path. - # Unsure where it originates from. - remove_known_prefixes(unquote(urlparse(incident.uri).path)), - ) - - try: - original_code = repo.git.show(f"{old_commit}:{file_path}") - except Exception: - original_code = "" - - try: - updated_code = repo.git.show(f"{new_commit}:{file_path}") - except Exception: - updated_code = "" - - repo_diff = repo.git.diff(old_commit, new_commit) - file_diff = repo.git.diff(old_commit, new_commit, "--", file_path) - - violation.solved_incidents.append( - InMemorySolvedIncident( - uri=incident.uri, - snip=incident.snip, - line=incident.line, - variables=incident.variables, - original_code=original_code, - updated_code=updated_code, - file_diff=file_diff, - repo_diff=repo_diff, - ) - ) - - return number_new_incidents, number_unsolved_incidents, number_solved_incidents - - def find_solutions( - self, - ruleset_name: str, - violation_name: str, - incident_variables: dict, - incident_snip: str | None = None, - ) -> list[Solution]: - result: list[Solution] = [] - incident_variables_set = set(incident_variables.items()) - incident_variables_set_len = len(incident_variables_set) - - for _, application in self.store.items(): - ruleset = application.rulesets.get(ruleset_name, None) - if ruleset is None: - continue - - violation = ruleset.violations.get(violation_name, None) - if violation is None: - continue - - for solved_incident in violation.solved_incidents: - if incident_snip is not None and solved_incident.snip != incident_snip: - continue - - common = set(solved_incident.variables.items()).intersection( - incident_variables_set - ) - if len(common) != incident_variables_set_len: - continue - - result.append( - Solution( - solved_incident.uri, - solved_incident.file_diff, - solved_incident.repo_diff, - solved_incident.original_code, - solved_incident.updated_code, - ) - ) - - return result diff --git a/kai/service/incident_store/incident_store.py b/kai/service/incident_store/incident_store.py index 17ba1bf4..66de755e 100644 --- a/kai/service/incident_store/incident_store.py +++ b/kai/service/incident_store/incident_store.py @@ -1,15 +1,36 @@ +import argparse import datetime +import enum import os from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional, TypeVar +from urllib.parse import unquote, urlparse import yaml from git import Repo +from sqlalchemy import ( + Column, + DateTime, + Engine, + ForeignKey, + ForeignKeyConstraint, + String, + func, + select, +) +from sqlalchemy.dialects import postgresql, sqlite +from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship +from sqlalchemy.types import JSON from kai.constants import PATH_LOCAL_REPO from kai.kai_logging import KAI_LOG -from kai.models.kai_config import KaiConfigIncidentStore, KaiConfigIncidentStoreProvider +from kai.model_provider import ModelProvider +from kai.models.kai_config import ( + KaiConfig, + KaiConfigIncidentStore, + KaiConfigIncidentStoreProvider, +) from kai.report import Report # These prefixes are sometimes in front of the paths, strip them. @@ -42,6 +63,17 @@ def filter_incident_vars(incident_vars: dict): return incident_vars +T = TypeVar("T") + + +def deep_sort(obj: T) -> T: + if isinstance(obj, dict): + return {k: deep_sort(v) for k, v in sorted(obj.items())} + if isinstance(obj, list): + return sorted(deep_sort(x) for x in obj) + return obj + + def __get_repo_path(app_name): """ Get the repo path @@ -122,8 +154,6 @@ def load_reports_from_directory(store: "IncidentStore", path: str): generated_at=datetime.datetime.now(), ) - KAI_LOG.info(f"Loading application {app}\n") - store.load_report(app_initial, Report.load_report_from_file(report_path)) KAI_LOG.info(f"Loaded application - initial {app}\n") @@ -176,31 +206,171 @@ class Solution: updated_code: Optional[str] = None +class SQLBase(DeclarativeBase): + type_annotation_map = { + dict[str, Any]: JSON() + .with_variant(postgresql.JSONB(), "postgresql") + .with_variant(sqlite.JSON(), "sqlite"), + list[str]: JSON() + .with_variant(postgresql.JSONB(), "postgresql") + .with_variant(sqlite.JSON(), "sqlite"), + } + + +class SQLUnmodifiedReport(SQLBase): + __tablename__ = "unmodified_reports" + + application_name: Mapped[str] = mapped_column(primary_key=True) + report_id: Mapped[str] = mapped_column(primary_key=True) + + generated_at: Mapped[datetime.datetime] = mapped_column( + DateTime(), server_default=func.now() + ) + report: Mapped[dict[str, Any]] + + +class ViolationCategory(enum.Enum): + potential = "potential" + optional = "optional" + mandatory = "mandatory" + + +class SQLApplication(SQLBase): + __tablename__ = "applications" + + application_name: Mapped[str] = mapped_column(primary_key=True) + + repo_uri_origin: Mapped[str] + repo_uri_local: Mapped[str] + current_branch: Mapped[str] + current_commit: Mapped[str] + generated_at: Mapped[datetime.datetime] + + incidents: Mapped[list["SQLIncident"]] = relationship( + back_populates="application", cascade="all, delete-orphan" + ) + + +class SQLRuleset(SQLBase): + __tablename__ = "rulesets" + + ruleset_name: Mapped[str] = mapped_column(primary_key=True) + + tags: Mapped[list[str]] + + violations: Mapped[list["SQLViolation"]] = relationship( + back_populates="ruleset", cascade="all, delete-orphan" + ) + + +class SQLViolation(SQLBase): + __tablename__ = "violations" + + violation_name: Mapped[str] = mapped_column(primary_key=True) + ruleset_name: Mapped[int] = mapped_column( + ForeignKey("rulesets.ruleset_name"), primary_key=True + ) + + category: Mapped[ViolationCategory] + labels: Mapped[list[str]] + + ruleset: Mapped[SQLRuleset] = relationship(back_populates="violations") + incidents: Mapped[list["SQLIncident"]] = relationship( + back_populates="violation", cascade="all, delete-orphan" + ) + + +class SQLAcceptedSolution(SQLBase): + __tablename__ = "accepted_solutions" + + solution_id: Mapped[int] = mapped_column(primary_key=True) + + generated_at: Mapped[datetime.datetime] = mapped_column( + DateTime(), server_default=func.now() + ) + solution_big_diff: Mapped[str] + solution_small_diff: Mapped[str] + solution_original_code: Mapped[str] + solution_updated_code: Mapped[str] + + llm_summary: Mapped[Optional[str]] + + incidents: Mapped[list["SQLIncident"]] = relationship( + back_populates="solution", cascade="all, delete-orphan" + ) + + def __repr__(self): + return f"SQLAcceptedSolution(solution_id={self.solution_id}, generated_at={self.generated_at}, solution_big_diff={self.solution_big_diff:.10}, solution_small_diff={self.solution_small_diff:.10}, solution_original_code={self.solution_original_code:.10}, solution_updated_code={self.solution_updated_code:.10})" + + +class SQLIncident(SQLBase): + __tablename__ = "incidents" + + incident_id: Mapped[int] = mapped_column(primary_key=True) + + violation_name = Column(String) + ruleset_name = Column(String) + application_name: Mapped[str] = mapped_column( + ForeignKey("applications.application_name") + ) + incident_uri: Mapped[str] + incident_snip: Mapped[str] + incident_line: Mapped[int] + incident_variables: Mapped[dict[str, Any]] + solution_id: Mapped[Optional[str]] = mapped_column( + ForeignKey("accepted_solutions.solution_id") + ) + + __table_args__ = ( + ForeignKeyConstraint( + [violation_name, ruleset_name], + [SQLViolation.violation_name, SQLViolation.ruleset_name], + ), + {}, + ) + + violation: Mapped[SQLViolation] = relationship(back_populates="incidents") + application: Mapped[SQLApplication] = relationship(back_populates="incidents") + solution: Mapped[SQLAcceptedSolution] = relationship(back_populates="incidents") + + def __repr__(self) -> str: + return f"SQLIncident(violation_name={self.violation_name}, ruleset_name={self.ruleset_name}, application_name={self.application_name}, incident_uri={self.incident_uri}, incident_snip={self.incident_snip:.10}, incident_line={self.incident_line}, incident_variables={self.incident_variables}, solution_id={self.solution_id})" + + class IncidentStore(ABC): + """ + Responsible for 3 main things: + - Incident/Solution storage + - Solution detection + - Solution generation + """ + + engine: Engine + model_provider: ModelProvider @staticmethod - def from_config(config: KaiConfigIncidentStore): + def from_config(config: KaiConfigIncidentStore, model_provider: ModelProvider): """ Factory method to produce whichever incident store is needed. """ + # TODO: Come up with some sort of "solution generator strategy" so we + # don't blow up our llm API usage. Lazy, immediate, other etc... + if config.provider == "postgresql": from kai.service.incident_store.psql import PSQLIncidentStore - return PSQLIncidentStore(config.args) - elif config.provider == "in_memory": - from kai.service.incident_store.in_memory import InMemoryIncidentStore + return PSQLIncidentStore(config.args, model_provider) + elif config.provider == "sqlite": + from kai.service.incident_store.sqlite import SQLiteIncidentStore - return InMemoryIncidentStore(config.args) + return SQLiteIncidentStore(config.args, model_provider) else: raise ValueError( f"Unsupported provider: {config.provider}\ntype: {type(config.provider)}\nlmao: {KaiConfigIncidentStoreProvider.POSTGRESQL}" ) - @abstractmethod - def load_report( - self, application: Application, report: Report - ) -> tuple[int, int, int]: + def load_report(self, app: Application, report: Report) -> tuple[int, int, int]: """ Load incidents from a report and given application object. Returns a tuple containing (# of new incidents, # of unsolved incidents, # of @@ -208,16 +378,223 @@ def load_report( NOTE: This application object is more like metadata than anything. """ - pass - @abstractmethod + # FIXME: Only does stuff within the same application. Maybe fixed? + + # NEW: Store whole report in table + # - if we get the same report again, we should skip adding it. Have some identifier + # - But should still check incidents. + + # - have two tables `unsolved_incidents` and `solved_incidents` + + # Iterate through all incidents in the report + # - change so theres an identified like "commit application ruleset violation" + + # create entries if not exists + # reference the old-new matrix + # old + # | NO | YES + # --------|--------+----------------------------- + # new NO | - | update (SOLVED, embeddings) + # YES | insert | update (line number, etc...) + + repo_path = unquote(urlparse(app.repo_uri_local).path) + repo = Repo(repo_path) + old_commit: str + new_commit = app.current_commit + + number_new_incidents = 0 + number_unsolved_incidents = 0 + number_solved_incidents = 0 + + with Session(self.engine) as session: + incidents_temp: list[SQLIncident] = [] + + select_application_stmt = select(SQLApplication).where( + SQLApplication.application_name == app.application_name + ) + + application = session.scalars(select_application_stmt).first() + + if application is None: + application = SQLApplication( + application_name=app.application_name, + repo_uri_origin=app.repo_uri_origin, + repo_uri_local=app.repo_uri_local, + current_branch=app.current_branch, + current_commit=app.current_commit, + generated_at=app.generated_at, + ) + session.add(application) + session.commit() + + # TODO: Determine if we want to have this check + # if application.generated_at >= app.generated_at: + # return 0, 0, 0 + + old_commit = application.current_commit + + report_dict = dict(report) + + for ruleset_name, ruleset_dict in report_dict.items(): + select_ruleset_stmt = select(SQLRuleset).where( + SQLRuleset.ruleset_name == ruleset_name + ) + + ruleset = session.scalars(select_ruleset_stmt).first() + + if ruleset is None: + ruleset = SQLRuleset( + ruleset_name=ruleset_name, + tags=ruleset_dict.get("tags", []), + ) + session.add(ruleset) + session.commit() + + for violation_name, violation_dict in ruleset_dict.get( + "violations", {} + ).items(): + select_violation_stmt = ( + select(SQLViolation) + .where(SQLViolation.violation_name == violation_name) + .where(SQLViolation.ruleset_name == ruleset.ruleset_name) + ) + + violation = session.scalars(select_violation_stmt).first() + + if violation is None: + violation = SQLViolation( + violation_name=violation_name, + ruleset_name=ruleset.ruleset_name, + category=violation_dict.get("category", "potential"), + labels=violation_dict.get("labels", []), + ) + session.add(violation) + session.commit() + + for incident in violation_dict.get("incidents", []): + incidents_temp.append( + SQLIncident( + violation_name=violation.violation_name, + ruleset_name=ruleset.ruleset_name, + application_name=application.application_name, + incident_uri=incident.get("uri", ""), + incident_snip=incident.get("codeSnip", ""), + incident_line=incident.get("lineNumber", 0), + incident_variables=deep_sort( + incident.get("variables", {}) + ), + ) + ) + + # incidents_temp - incidents + new_incidents = set(incidents_temp) - set(application.incidents) + number_new_incidents = len(new_incidents) + + for new_incident in new_incidents: + session.add(new_incident) + + session.commit() + + # incidents `intersect` incidents_temp + unsolved_incidents = set(application.incidents).intersection(incidents_temp) + number_unsolved_incidents = len(unsolved_incidents) + + # incidents - incidents_temp + solved_incidents = set(application.incidents) - set(incidents_temp) + number_solved_incidents = len(solved_incidents) + KAI_LOG.debug(f"Number of solved incidents: {len(solved_incidents)}") + # KAI_LOG.debug(f"{solved_incidents=}") + + for solved_incident in solved_incidents: + file_path = os.path.join( + repo_path, + # NOTE: When retrieving uris from the report, some of them + # had "/tmp/source-code/" as their root path. Unsure where + # it originates from. + unquote(urlparse(solved_incident.incident_uri).path).removeprefix( + "/tmp/source-code/" # trunk-ignore(bandit/B108) + ), + ) + + # NOTE: The `big_diff` functionality is currently disabled + + # big_diff: str = repo.git.diff(old_commit, new_commit).encode('utf-8', errors="ignore").decode() + big_diff = "" + + # TODO: Some of the sample repos have invalid utf-8 characters, + # thus the encode-then-decode hack. Not very performant, there's + # probably a better way to handle this. + + try: + original_code = ( + repo.git.show(f"{old_commit}:{file_path}") + .encode("utf-8", errors="ignore") + .decode() + ) + except Exception: + original_code = "" + + try: + updated_code = ( + repo.git.show(f"{new_commit}:{file_path}") + .encode("utf-8", errors="ignore") + .decode() + ) + except Exception: + updated_code = "" + + small_diff = ( + repo.git.diff(old_commit, new_commit, "--", file_path) + .encode("utf-8", errors="ignore") + .decode() + ) + + # TODO: Strings must be utf-8 encodable, so I'm removing the `big_diff` functionality for now + solved_incident.solution = SQLAcceptedSolution( + generated_at=app.generated_at, + solution_big_diff=big_diff, + solution_small_diff=small_diff, + solution_original_code=original_code, + solution_updated_code=updated_code, + ) + + session.commit() + + application.repo_uri_origin = app.repo_uri_origin + application.repo_uri_local = app.repo_uri_local + application.current_branch = app.current_branch + application.current_commit = app.current_commit + application.generated_at = app.generated_at + + session.commit() + + unmodified_report = SQLUnmodifiedReport( + application_name=app.application_name, + report_id=report.report_id, + generated_at=app.generated_at, + report=report_dict, + ) + + # Upsert the unmodified report + session.merge(unmodified_report) + session.commit() + + return number_new_incidents, number_unsolved_incidents, number_solved_incidents + + def create_tables(self): + """ + Create tables in the incident store. + """ + SQLBase.metadata.create_all(self.engine) + def delete_store(self): """ Clears all data within the incident store. Non-reversible! """ - pass + SQLBase.metadata.drop_all(self.engine) + self.create_tables() - @abstractmethod def find_solutions( self, ruleset_name: str, @@ -228,4 +605,97 @@ def find_solutions( """ Returns a list of solutions for the given incident. Exact matches only. """ + incident_variables = deep_sort(filter_incident_vars(incident_variables)) + + if incident_snip is None: + incident_snip = "" + + with Session(self.engine) as session: + select_violation_stmt = ( + select(SQLViolation) + .where(SQLViolation.violation_name == violation_name) + .where(SQLViolation.ruleset_name == ruleset_name) + ) + + violation = session.scalars(select_violation_stmt).first() + + if violation is None: + return [] + + select_incidents_with_solutions_stmt = ( + select(SQLIncident) + .where(SQLIncident.violation_name == violation.violation_name) + .where(SQLIncident.ruleset_name == violation.ruleset_name) + .where(SQLIncident.solution_id.isnot(None)) + .where(self.json_exactly_equal(incident_variables)) + ) + + result: list[Solution] = [] + for incident in session.execute( + select_incidents_with_solutions_stmt + ).scalars(): + select_accepted_solution_stmt = select(SQLAcceptedSolution).where( + SQLAcceptedSolution.solution_id == incident.solution_id + ) + + accepted_solution = session.scalars( + select_accepted_solution_stmt + ).first() + + result.append( + Solution( + uri=incident.incident_uri, + file_diff=accepted_solution.solution_small_diff, + repo_diff=accepted_solution.solution_big_diff, + ) + ) + + return result + + @abstractmethod + def json_exactly_equal(self, json_dict: dict): + """ + Each incident store must implement this method as JSON is handled + slightly differently between SQLite and PostgreSQL. + """ pass + + +def cmd(provider: str = None): + KAI_LOG.setLevel("debug".upper()) + + parser = argparse.ArgumentParser(description="Process some parameters.") + parser.add_argument( + "--config_filepath", + type=str, + default="../../config.toml", + help="Path to the config file.", + ) + parser.add_argument( + "--drop_tables", type=str, default="False", help="Whether to drop tables." + ) + parser.add_argument( + "--analysis_dir_path", + type=str, + default="../../samples/analysis_reports", + help="Path to analysis reports folder", + ) + + args = parser.parse_args() + + config = KaiConfig.model_validate_filepath(args.config_filepath) + + if provider is not None and config.incident_store.provider != provider: + raise Exception(f"This script only works with {provider} incident store.") + + model_provider = ModelProvider(config.models) + incident_store = IncidentStore.from_config(config.incident_store, model_provider) + + if args.drop_tables: + incident_store.delete_store() + + load_reports_from_directory(incident_store, args.analysis_dir_path) + + +if __name__ == "__main__": + cmd() diff --git a/kai/service/incident_store/psql.py b/kai/service/incident_store/psql.py index 1caf5edf..5744d693 100644 --- a/kai/service/incident_store/psql.py +++ b/kai/service/incident_store/psql.py @@ -1,855 +1,30 @@ -import argparse -import datetime -import json -import os -from functools import wraps -from inspect import signature -from urllib.parse import unquote, urlparse +from sqlalchemy import and_, create_engine -import psycopg2 -from git import Repo -from psycopg2.extensions import connection -from psycopg2.extras import DictCursor, DictRow +from kai.model_provider import ModelProvider +from kai.models.kai_config import KaiConfigIncidentStorePostgreSQLArgs +from kai.service.incident_store.incident_store import IncidentStore, SQLIncident, cmd -from kai.constants import PATH_SQL -from kai.embedding_provider import EmbeddingNone -from kai.kai_logging import KAI_LOG -from kai.models.kai_config import KaiConfig, KaiConfigIncidentStorePostgreSQLArgs -from kai.report import Report -from kai.service.incident_store.incident_store import ( - Application, - IncidentStore, - Solution, - filter_incident_vars, - load_reports_from_directory, - remove_known_prefixes, -) - -def supply_cursor_if_none(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - sig = signature(func) - bound_args = sig.bind(self, *args, **kwargs) - bound_args.apply_defaults() - - if "cur" not in bound_args.arguments or bound_args.arguments["cur"] is None: - with self.conn.cursor() as cur: - bound_args.arguments["cur"] = cur - return func(*bound_args.args, **bound_args.kwargs) - else: - return func(*bound_args.args, **bound_args.kwargs) - - return wrapper - - -# TODO(@JonahSussman): Migrate this to use an ORM class PSQLIncidentStore(IncidentStore): - def __init__(self, args: KaiConfigIncidentStorePostgreSQLArgs): - self.emb_provider = EmbeddingNone() - - try: - with psycopg2.connect( - cursor_factory=DictCursor, **args.model_dump() - ) as conn: - KAI_LOG.info("Connected to the PostgreSQL server.") - self.conn: connection = conn - self.conn.autocommit = True - - self.create_tables() - - except (psycopg2.DatabaseError, Exception) as error: - KAI_LOG.error(f"Error initializing PSQLIncidentStore: {error}") - - def create_tables(self): - # TODO: Figure out portable way to install the pgvector extension. - # Containerize? "CREATE EXTENSION IF NOT EXISTS vector;" Only works as - # superuser - - # TODO: along with analyzer_types.py, we should really use something - # like openapi to nail down the spec and autogenerate the types - sql_create_tables = open( - os.path.join(PATH_SQL, "create_tables.sql"), "r" - ).read() - sql_add_embedding = open( - os.path.join(PATH_SQL, "add_embedding.sql"), "r" - ).readlines() - - with self.conn.cursor() as cur: - cur.execute(sql_create_tables) - - dim = self.emb_provider.get_dimension() - for q in sql_add_embedding: - cur.execute(q, (dim,)) - - # Abstract base class implementations - - def delete_store(self): - with self.conn.cursor() as cur: - cur.execute(open(os.path.join(PATH_SQL, "drop_tables.sql"), "r").read()) - - self.create_tables() - - def load_report(self, app: Application, report: Report) -> tuple[int, int, int]: - """ - Returns: (number_new_incidents, number_unsolved_incidents, - number_solved_incidents): tuple[int, int, int] - """ - # FIXME: Only does stuff within the same application. Maybe fixed? - - # create entries if not exists - # reference the old-new matrix - # old - # | NO | YES - # --------|--------+----------------------------- - # new NO | - | update (SOLVED, embeddings) - # YES | insert | update (line number, etc...) - - repo_path = unquote(urlparse(app.repo_uri_local).path) - repo = Repo(repo_path) - old_commit: str - new_commit = app.current_commit - - number_new_incidents = 0 - number_unsolved_incidents = 0 - number_solved_incidents = 0 - - with self.conn.cursor() as cur: - cur.execute("DROP TABLE IF EXISTS incidents_temp;") - cur.execute("CREATE TABLE incidents_temp (LIKE incidents INCLUDING ALL);") - - query_app = self.select_application(None, app.application_name, cur) - - if len(query_app) >= 2: - raise Exception(f"Multiple applications found for {app}.") - elif len(query_app) == 0: - application = self.insert_application(app, cur) - else: - application = query_app[0] - - old_commit = application["current_commit"] - - report_dict = dict(report) - - for ruleset_name, ruleset_dict in report_dict.items(): - query_ruleset = self.select_ruleset( - ruleset_name=ruleset_name, - # application_id=application['application_id'], - cur=cur, - ) - - if len(query_ruleset) >= 2: - raise Exception("Multiple rulesets found.") - elif len(query_ruleset) == 0: - ruleset = self.insert_ruleset( - ruleset_name=ruleset_name, - # application_id=application['application_id'], - tags=ruleset_dict.get("tags", []), - cur=cur, - ) - else: - ruleset = query_ruleset[0] - - for violation_name, violation_dict in ruleset_dict.get( - "violations", {} - ).items(): - query_violation = self.select_violation( - violation_name=violation_name, - ruleset_id=ruleset["ruleset_id"], - cur=cur, - ) - - if len(query_violation) >= 2: - raise Exception("Multiple rulesets found.") - elif len(query_violation) == 0: - violation = self.insert_violation( - violation_name=violation_name, - ruleset_id=ruleset["ruleset_id"], - category=violation_dict.get("category", "potential"), - labels=violation_dict.get("labels", []), - cur=cur, - ) - else: - violation = query_violation[0] - - for incident in violation_dict.get("incidents", []): - incident_vars = filter_incident_vars( - incident.get("variables", {}) - ) - cur.execute( - """INSERT INTO incidents_temp(violation_id, application_id, incident_uri, incident_snip, incident_line, incident_variables) - VALUES (%s, %s, %s, %s, %s, %s);""", - ( - violation["violation_id"], - application["application_id"], - incident.get("uri", ""), - incident.get("codeSnip", ""), - incident.get("lineNumber", 0), - json.dumps(incident_vars), - ), - ) - - # incidents_temp - incidents - cur.execute( - """WITH filtered_incidents_temp AS ( - SELECT * FROM incidents_temp WHERE application_id = %s -), -filtered_incidents AS ( - SELECT * FROM incidents WHERE application_id = %s -) -SELECT fit.incident_id AS incidents_temp_id, fi.incident_id AS incidents_id, fit.violation_id, fit.application_id, fit.incident_uri, fit.incident_snip, fit.incident_line, fit.incident_variables -FROM filtered_incidents_temp fit -LEFT JOIN filtered_incidents fi ON fit.violation_id = fi.violation_id - AND fit.incident_uri = fi.incident_uri - AND fit.incident_snip = fi.incident_snip - AND fit.incident_line = fi.incident_line - AND fit.incident_variables = fi.incident_variables -WHERE fi.incident_id IS NULL;""", - ( - application["application_id"], - application["application_id"], - ), - ) - - new_incidents = cur.fetchall() - number_new_incidents = len(new_incidents) - - self.conn.autocommit = False - for ni in new_incidents: - self.insert_incident( - ni["violation_id"], - ni["application_id"], - ni["incident_uri"], - ni["incident_snip"], - ni["incident_line"], - ni["incident_variables"], - None, - cur, - ) - self.conn.commit() - cur.fetchall() - self.conn.autocommit = True - - # incidents `intersect` incidents_temp - cur.execute( - """-- incidents `intersect` incidents_temp with application_id match first -WITH filtered_incidents AS ( - SELECT * FROM incidents WHERE application_id = %s -), -filtered_incidents_temp AS ( - SELECT * FROM incidents_temp WHERE application_id = %s -) -SELECT fi.incident_id AS incidents_id, fit.incident_id AS incidents_temp_id, fi.violation_id, fi.application_id, fi.incident_uri, fi.incident_snip, fi.incident_line, fi.incident_variables -FROM filtered_incidents fi -JOIN filtered_incidents_temp fit ON fi.violation_id = fit.violation_id - AND fi.incident_uri = fit.incident_uri - AND fi.incident_snip = fit.incident_snip - AND fi.incident_line = fit.incident_line - AND fi.incident_variables = fit.incident_variables; -""", - ( - application["application_id"], - application["application_id"], - ), - ) - - unsolved_incidents = cur.fetchall() - number_unsolved_incidents = len(unsolved_incidents) - - # incidents - incidents_temp - cur.execute( - """WITH filtered_incidents AS ( - SELECT * FROM incidents WHERE application_id = %s -), -filtered_incidents_temp AS ( - SELECT * FROM incidents_temp WHERE application_id = %s -) -SELECT fi.incident_id AS incidents_id, fit.incident_id AS incidents_temp_id, fi.violation_id, fi.application_id, fi.incident_uri, fi.incident_snip, fi.incident_line, fi.incident_variables -FROM filtered_incidents fi -LEFT JOIN filtered_incidents_temp fit ON fi.violation_id = fit.violation_id - AND fi.incident_uri = fit.incident_uri - AND fi.incident_snip = fit.incident_snip - AND fi.incident_line = fit.incident_line - AND fi.incident_variables = fit.incident_variables -WHERE fit.incident_id IS NULL;""", - ( - application["application_id"], - application["application_id"], - ), - ) - - solved_incidents = cur.fetchall() - number_solved_incidents = len(solved_incidents) - KAI_LOG.debug(f"# of solved inc: {len(solved_incidents)}") - - self.conn.autocommit = False - for si in solved_incidents: - # NOTE: When retrieving uris from the report, some of them - # had "/tmp/source-code/" as their root path. Unsure where - # it originates from. - file_path = remove_known_prefixes(unquote(urlparse(si[4]).path)) - # file_path = os.path.join( - # repo_path, - # in_repo_path, - # ) - big_diff = repo.git.diff(old_commit, new_commit) - - try: - original_code = repo.git.show(f"{old_commit}:{file_path}") - except Exception as e: - KAI_LOG.error(e) - original_code = "" - - try: - updated_code = repo.git.show(f"{new_commit}:{file_path}") - except Exception as e: - KAI_LOG.error(e) - updated_code = "" - - # file_path = pathlib.Path(os.path.join(repo_path, unquote(urlparse(si[3]).path).removeprefix('/tmp/source-code'))).as_uri() - small_diff = repo.git.diff(old_commit, new_commit, "--", file_path) - KAI_LOG.debug(small_diff) - - sln = self.insert_accepted_solution( - app.generated_at, - big_diff, - small_diff, - original_code, - updated_code, - cur, - ) - - cur.execute( - "UPDATE incidents SET solution_id = %s WHERE incident_id = %s;", - (sln["solution_id"], si[0]), - ) - - self.conn.commit() - self.conn.autocommit = True - - cur.execute("DROP TABLE IF EXISTS incidents_temp;") - application = self.update_application( - application["application_id"], app, cur - ) - - return number_new_incidents, number_unsolved_incidents, number_solved_incidents - - def find_solutions( - self, - ruleset_name: str, - violation_name: str, - incident_variables: dict, - incident_snip: str | None = None, - ) -> list[Solution]: - if incident_snip is None: - incident_snip = "" - - with self.conn.cursor() as cur: - incident_vars_str = json.dumps(filter_incident_vars(incident_variables)) - - cur.execute( - """ - SELECT v.* - FROM violations v - JOIN rulesets r ON v.ruleset_id = r.ruleset_id - WHERE v.violation_name = %s - AND r.ruleset_name = %s; - """, - (violation_name, ruleset_name), - ) - - violation_query = cur.fetchall() - if len(violation_query) > 1: - raise Exception( - f"More than one violation with name '{violation_name}' and ruleset name {ruleset_name}" - ) - if len(violation_query) == 0: - return [] - - violation = violation_query[0] - - cur.execute( - """ - SELECT COUNT(*) - FROM incidents - WHERE violation_id = %s - AND solution_id IS NOT NULL; - """, - (violation["violation_id"],), - ) - - number_of_slns = cur.fetchone()[0] - if number_of_slns == 0: - return [] - - cur.execute( - """ - SELECT * - FROM incidents - WHERE violation_id = %s - AND solution_id IS NOT NULL - AND incident_variables <@ %s - AND incident_variables @> %s; - """, - (violation["violation_id"], incident_vars_str, incident_vars_str), - ) - - incidents_with_solutions = cur.fetchall() - result: list[Solution] = [] - - for incident in incidents_with_solutions: - accepted_solution = self.select_accepted_solution( - incident["solution_id"], cur - ) - - result.append( - Solution( - uri=incident["incident_uri"], - file_diff=accepted_solution["solution_small_diff"], - repo_diff=accepted_solution["solution_big_diff"], - ) - ) - - return result - - # Implementation specific to PSQLIncidentStore Methods - - @supply_cursor_if_none - def select_application( - self, app_id: int, app_name: str = None, cur: DictCursor = None - ) -> list[DictRow]: - if app_id is None and app_name is None: - return [] - - if app_id is not None: - cur.execute( - "SELECT * FROM applications WHERE application_id = %s;", (app_id,) - ) - elif app_name is not None: - cur.execute( - "SELECT * FROM applications WHERE application_name = %s;", (app_name,) - ) - else: - raise Exception("At least one of app_id or app_name must be not None.") - - # return [Application.from_dict_row(row) for row in cur.fetchall()] - return cur.fetchall() - - @supply_cursor_if_none - def insert_application(self, app: Application, cur: DictCursor = None) -> DictRow: - cur.execute( - """INSERT INTO applications(application_name, repo_uri_origin, repo_uri_local, current_branch, current_commit, generated_at) - VALUES (%s, %s, %s, %s, %s, %s) RETURNING *;""", - ( - app.application_name, - app.repo_uri_origin, - app.repo_uri_local, - app.current_branch, - app.current_commit, - app.generated_at, - ), - ) - - return cur.fetchone() - - @supply_cursor_if_none - def update_application( - self, application_id: int, app: Application, cur: DictCursor = None - ) -> DictRow: - cur.execute( - """UPDATE applications - SET application_name = %s, - repo_uri_origin = %s, - repo_uri_local = %s, - current_branch = %s, - current_commit = %s, - generated_at = %s - WHERE application_id = %s - RETURNING *;""", - ( - app.application_name, - app.repo_uri_origin, - app.repo_uri_local, - app.current_branch, - app.current_commit, - app.generated_at, - application_id, - ), - ) - - return cur.fetchone() - - @supply_cursor_if_none - def select_ruleset( - self, ruleset_id: int = None, ruleset_name: str = None, cur: DictCursor = None - ) -> list[DictRow]: - # def select_ruleset(self, ruleset_id: int = None, ruleset_name: str = None, application_id: int = None, cur: DictCursor = None) -> list[DictRow]: - if ruleset_id is not None: - cur.execute("SELECT * FROM rulesets WHERE ruleset_id = %s;", (ruleset_id,)) - # elif ruleset_name is not None and application_id is not None: - elif ruleset_name is not None: - cur.execute( - # "SELECT * FROM rulesets WHERE ruleset_name = %s AND application_id = %s;", - "SELECT * FROM rulesets WHERE ruleset_name = %s;", - # (ruleset_name, application_id,) - (ruleset_name,), - ) - else: - raise Exception( - "At least one of ruleset_id or ruleset_name must be not None." - ) - - return cur.fetchall() - - @supply_cursor_if_none - def insert_ruleset( - self, ruleset_name: str, tags: list[str], cur: DictCursor = None - ) -> DictRow: - # def insert_ruleset(self, ruleset_name: str, application_id: int, tags: list[str], cur: DictCursor = None) -> DictRow: - cur.execute( - # """INSERT INTO rulesets(ruleset_name, application_id, tags) - """INSERT INTO rulesets(ruleset_name, tags) - VALUES (%s, %s) RETURNING*;""", - # VALUES (%s, %s, %s) RETURNING*;""", - # (ruleset_name, application_id, json.dumps(tags)) - (ruleset_name, json.dumps(tags)), - ) - - return cur.fetchone() - - @supply_cursor_if_none - def select_violation( - self, - violation_id: int = None, - violation_name: str = None, - ruleset_id: int = None, - cur: DictCursor = None, - ) -> list[DictRow]: - if violation_id is not None: - cur.execute( - "SELECT * FROM violations WHERE violation_id = %s;", (violation_id,) - ) - elif violation_name is not None and ruleset_id is not None: - cur.execute( - "SELECT * FROM violations WHERE violation_name = %s AND ruleset_id = %s;", - ( - violation_name, - ruleset_id, - ), - ) - else: - raise Exception( - "At least one of violation_id or (violation_name, ruleset_id) must be not None." - ) - - return cur.fetchall() - - @supply_cursor_if_none - def insert_violation( - self, - violation_name: str, - ruleset_id: int, - category: str, - labels: list[str], - cur: DictCursor = None, - ) -> DictRow: - cur.execute( - """INSERT INTO violations(violation_name, ruleset_id, category, labels) - VALUES (%s, %s, %s, %s) RETURNING *;""", - (violation_name, ruleset_id, category, json.dumps(sorted(labels))), - ) - - return cur.fetchone() - - @supply_cursor_if_none - def insert_incident( - self, - violation_id: int, - application_id: int, - incident_uri: str, - incident_snip: str, - incident_line: int, - incident_variables: dict, - solution_id: int = None, - cur: DictCursor = None, - ) -> DictRow: - # if isinstance(incident_variables, str): - # incident_variables = json.loads(incident_variables) - # if not isinstance(incident_variables, list): - # raise Exception(f"incident_variables must be of type list. Got type '{type(incident_variables)}'") - - vars_str = json.dumps(filter_incident_vars(incident_variables)) - truncated_vars = (vars_str[:75] + "...") if len(vars_str) > 75 else vars_str - - KAI_LOG.info( - f"Inserting incident {(violation_id, application_id, incident_uri, incident_line, truncated_vars, solution_id,)}" - ) - - cur.execute( - """INSERT INTO incidents(violation_id, application_id, incident_uri, incident_snip, incident_line, incident_variables, solution_id, incident_snip_embedding) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s) RETURNING *;""", - ( - violation_id, - application_id, - incident_uri, - incident_snip, - incident_line, - json.dumps(incident_variables), - solution_id, - str(self.emb_provider.get_embedding(incident_snip)), - ), - ) - - return cur.fetchone() - - @supply_cursor_if_none - def insert_accepted_solution( - self, - generated_at: datetime.datetime, - solution_big_diff: str, - solution_small_diff: str, - solution_original_code: str, - solution_updated_code: str, - cur: DictCursor = None, - ): - KAI_LOG.info(f"Inserting accepted solution {((generated_at))}") - small_diff_embedding = str(self.emb_provider.get_embedding(solution_small_diff)) - original_code_embedding = str( - self.emb_provider.get_embedding(solution_original_code) - ) - - # Encode the strings using the appropriate encoding method - # to avoid unicode errors TODO: validate if this is the right way to do it - solution_big_diff = solution_big_diff.encode("utf-8", "ignore").decode("utf-8") - solution_small_diff = solution_small_diff.encode("utf-8", "ignore").decode( - "utf-8" - ) - solution_original_code = ( - solution_original_code.encode("utf-8", "ignore") - .decode("utf-8") - .replace("\x00", "\uFFFD") - ) - solution_updated_code = ( - solution_updated_code.encode("utf-8", "ignore") - .decode("utf-8") - .replace("\x00", "\uFFFD") - ) - - cur.execute( - """INSERT INTO accepted_solutions(generated_at, solution_big_diff, - solution_small_diff, solution_original_code, solution_updated_code, - small_diff_embedding, original_code_embedding) - VALUES (%s, %s, %s, %s, %s, %s, %s) RETURNING *;""", - ( - generated_at, - solution_big_diff, - solution_small_diff, - solution_original_code, - solution_updated_code, - small_diff_embedding, - original_code_embedding, - ), - ) - - return cur.fetchone() - - @supply_cursor_if_none - def select_accepted_solution(self, solution_id: int, cur: DictCursor = None): - cur.execute( - "SELECT * FROM accepted_solutions WHERE solution_id = %s;", (solution_id,) - ) - - return cur.fetchone() - - @supply_cursor_if_none - def get_fuzzy_similar_incident( - self, - violation_name: str, - ruleset_name: str, - incident_snip: str, - incident_vars: dict, - cur: DictCursor = None, + def __init__( + self, args: KaiConfigIncidentStorePostgreSQLArgs, model_provider: ModelProvider ): - """ - Returns tuple[DictRow | None, str] - First element is the match if it exists - - Second element is whether it is an exact match or not. Values can be: - - 'exact': exact match. From the same violation and has the same - variables. Filtered using similarity search - - 'variables_mismatch': From the same violation but does not have the same - variables. - - 'similarity_only': Not from the same violation, only based on snip - similarity search - - 'unseen_violation': We haven't seen this violation before. Same result - as 'similarity_only' - - 'ambiguous_violation': violation_name and ruleset_name did not uniquely - identify a violation. Same result as 'similarity_only' - """ - - # # Pseudo-code - # this_violation = get_violation_from_params() - # if this_violation_dne: - # return get_snip_with_highest_embedding_similarity_from_all_violations(), 'similarity_only' - - # this_violation_slns = get_solutions_for_this_violation() - - # if len(this_violation_slns) == 0: - # return get_snip_with_highest_embedding_similarity_from_all_violations(), 'similarity_only' - - # # The violation we are looking at has at least one solution - # filter_on_vars = this_violation_slns.filter_exact(inp_vars) - - # if len(filter_on_vars) == 0: - # return get_snip_with_highest_embedding_similarity_from_all_solutions(), 'variables_mismatch' - # if len(filter_on_vars) == 1: - # return filter_on_vars[0], 'exact' - # if len(filter_on_vars) > 1: - # return get_snip_with_highest_embedding_similarity_from_filtered_set(), 'exact' - - KAI_LOG.debug("get_fuzzy_similar_incident") - - emb = self.emb_provider.get_embedding(incident_snip) - emb_str = str(emb) - - incident_vars_str = json.dumps(filter_incident_vars(incident_vars)) - - def highest_embedding_similarity_from_all(): - cur.execute( - """ - SELECT * - FROM incidents - WHERE solution_id IS NOT NULL - ORDER BY incident_snip_embedding <-> %s LIMIT 1;""", - (emb_str,), - ) - return dict(cur.fetchone()) - - cur.execute( - """ - SELECT v.* - FROM violations v - JOIN rulesets r ON v.ruleset_id = r.ruleset_id - WHERE v.violation_name = %s - AND r.ruleset_name = %s; - """, - (violation_name, ruleset_name), - ) - - violation_query = cur.fetchall() - - if len(violation_query) > 1: - KAI_LOG.info("Ambiguous violation based on ruleset_name and violation_name") - return highest_embedding_similarity_from_all(), "ambiguous_violation" - if len(violation_query) == 0: - KAI_LOG.info(f"No violations matched: {ruleset_name=} {violation_name=}") - return highest_embedding_similarity_from_all(), "unseen_violation" - - violation = violation_query[0] - - cur.execute( - """ - SELECT COUNT(*) - FROM incidents - WHERE violation_id = %s - AND solution_id IS NOT NULL; - """, - (violation["violation_id"],), - ) - - number_of_slns = cur.fetchone()[0] - if number_of_slns == 0: - KAI_LOG.info( - f"No solutions for violation: {ruleset_name=} {violation_name=}" - ) - return highest_embedding_similarity_from_all(), "similarity_only" - - cur.execute( - """ - SELECT * - FROM incidents - WHERE violation_id = %s - AND solution_id IS NOT NULL - AND incident_variables <@ %s - AND incident_variables @> %s; - """, - (violation["violation_id"], incident_vars_str, incident_vars_str), - ) - - exact_variables_query = cur.fetchall() - - if len(exact_variables_query) == 1: - return dict(exact_variables_query[0]), "exact" - elif len(exact_variables_query) == 0: - cur.execute( - """ - SELECT * - FROM incidents - WHERE violation_id = %s - AND solution_id IS NOT NULL - ORDER BY incident_snip_embedding <-> %s - LIMIT 1; - """, - ( - violation["violation_id"], - emb_str, - ), - ) - return dict(cur.fetchone()), "variables_mismatch" + if args.connection_string: + self.engine = create_engine(args.connection_string) else: - cur.execute( - """ - SELECT * - FROM incidents - WHERE violation_id = %s - AND solution_id IS NOT NULL - AND incident_variables <@ %s - AND incident_variables @> %s - ORDER BY incident_snip_embedding <-> %s - LIMIT 1; - """, - ( - violation["violation_id"], - incident_vars_str, - incident_vars_str, - emb_str, - ), + self.engine = create_engine( + f"postgresql://{args.user}:{args.password}@{args.host}:5432/{args.database}", + client_encoding="utf8", ) - return dict(cur.fetchone()), "exact" - - -def main(): - KAI_LOG.setLevel("debug".upper()) - - parser = argparse.ArgumentParser(description="Process some parameters.") - parser.add_argument( - "--config_filepath", - type=str, - default="../../config.toml", - help="Path to the config file.", - ) - parser.add_argument( - "--drop_tables", type=str, default="False", help="Whether to drop tables." - ) - parser.add_argument( - "--analysis_dir_path", - type=str, - default="../../samples/analysis_reports", - help="Path to analysis reports folder", - ) - - args = parser.parse_args() - config = KaiConfig.model_validate_filepath(args.config_filepath) + self.model_provider = model_provider - if config.incident_store.provider != "postgresql": - raise Exception("This script only works with PostgreSQL incident store.") - - incident_store = PSQLIncidentStore(config.incident_store.args) - - if args.drop_tables: - incident_store.delete_store() - - load_reports_from_directory(incident_store, args.analysis_dir_path) + def json_exactly_equal(self, json_dict: dict): + return and_( + SQLIncident.incident_variables.op("<@")(json_dict), + SQLIncident.incident_variables.op("@>")(json_dict), + ) if __name__ == "__main__": - main() + cmd("postgresql") diff --git a/kai/service/incident_store/sqlite.py b/kai/service/incident_store/sqlite.py new file mode 100644 index 00000000..cf9aa931 --- /dev/null +++ b/kai/service/incident_store/sqlite.py @@ -0,0 +1,37 @@ +from sqlalchemy import bindparam, create_engine, text + +from kai.model_provider import ModelProvider +from kai.models.kai_config import KaiConfigIncidentStoreSQLiteArgs +from kai.service.incident_store.incident_store import IncidentStore + + +class SQLiteIncidentStore(IncidentStore): + def __init__( + self, args: KaiConfigIncidentStoreSQLiteArgs, model_provider: ModelProvider + ): + if args.connection_string: + self.engine = create_engine(args.connection_string) + else: + self.engine = create_engine( + f"sqlite://{args.user}:{args.password}@{args.host}:5432/{args.database}", + client_encoding="utf8", + ) + + self.model_provider = model_provider + + def json_exactly_equal(self, json_dict: dict): + return text( + """ + ( + SELECT key, value + FROM json_tree(SQLIncident.incident_variables) + WHERE type != 'object' + ORDER BY key, value + ) = ( + SELECT key, value + FROM json_tree(:json_dict) + WHERE type != 'object' + ORDER BY key, value + ) + """ + ).bindparams(bindparam("json_dict", json_dict))