From 747a4b1616a8bcdbe39782f40f4d525885f4a20d Mon Sep 17 00:00:00 2001 From: JonahSussman Date: Wed, 10 Jul 2024 15:25:45 -0400 Subject: [PATCH] Implemented report_id and SQLite provider Signed-off-by: JonahSussman --- kai/config.toml | 18 + kai/evaluation.py | 11 +- kai/models/kai_config.py | 61 ++- kai/report.py | 17 +- kai/server.py | 8 +- kai/service/incident_store/__init__.py | 4 +- kai/service/incident_store/in_memory.py | 210 -------- kai/service/incident_store/incident_store.py | 511 ++++++++++++++++++- kai/service/incident_store/psql.py | 485 +----------------- kai/service/incident_store/sqlite.py | 37 ++ 10 files changed, 648 insertions(+), 714 deletions(-) delete mode 100644 kai/service/incident_store/in_memory.py create mode 100644 kai/service/incident_store/sqlite.py diff --git a/kai/config.toml b/kai/config.toml index 56731729..f507822e 100644 --- a/kai/config.toml +++ b/kai/config.toml @@ -2,6 +2,24 @@ log_level = "info" demo_mode = false trace_enabled = true +# **Postgresql incident store** +# [incident_store] +# provider = "postgresql" + +# [incident_store.args] +# host = "127.0.0.1" +# database = "kai" +# user = "kai" +# password = "dog8code" + +# **In-memory sqlite incident store** +# ``` +# [incident_store] +# provider = "sqlite" +# +# [incident_store.args] +# connection_string = "sqlite:///:memory:" + [incident_store] provider = "postgresql" diff --git a/kai/evaluation.py b/kai/evaluation.py index 28488b0b..d2b3dfd0 100644 --- a/kai/evaluation.py +++ b/kai/evaluation.py @@ -12,10 +12,10 @@ from kai.model_provider import ModelProvider from kai.models.analyzer_types import Incident from kai.models.file_solution import guess_language, parse_file_solution_content -from kai.models.kai_config import KaiConfig +from kai.models.kai_config import KaiConfig, KaiConfigIncidentStoreSQLiteArgs from kai.report import Report -from kai.service.incident_store.in_memory import InMemoryIncidentStore from kai.service.incident_store.incident_store import Application +from kai.service.incident_store.sqlite import SQLiteIncidentStore """ The point of this file is to automatically see if certain prompts make the @@ -172,7 +172,12 @@ def evaluate( created_git_repo = True - incident_store = InMemoryIncidentStore(None) + incident_store = SQLiteIncidentStore( + KaiConfigIncidentStoreSQLiteArgs( + connection_string="sqlite:///:memory:" + ), + None, + ) incident_store.load_report(example.application, example.report) pb_incidents = [] diff --git a/kai/models/kai_config.py b/kai/models/kai_config.py index 2eb47e87..ab606842 100644 --- a/kai/models/kai_config.py +++ b/kai/models/kai_config.py @@ -4,21 +4,38 @@ from typing import Literal, Optional, Union import yaml -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator # Incident store providers class KaiConfigIncidentStoreProvider(Enum): POSTGRESQL = "postgresql" - IN_MEMORY = "in_memory" + SQLITE = "sqlite" class KaiConfigIncidentStorePostgreSQLArgs(BaseModel): - host: str - database: str - user: str - password: str + host: Optional[str] = None + database: Optional[str] = None + user: Optional[str] = None + password: Optional[str] = None + + connection_string: Optional[str] = None + + @validator("connection_string", always=True) + def validate_connection_string(cls, v, values): + connection_string_present = v is not None + parameters_present = all( + values.get(key) is not None + for key in ["host", "database", "user", "password"] + ) + + if connection_string_present == parameters_present: + raise ValueError( + 
"Must provide one of connection_string or [host, database, user, password]" + ) + + return v class KaiConfigIncidentStorePostgreSQL(BaseModel): @@ -26,18 +43,38 @@ class KaiConfigIncidentStorePostgreSQL(BaseModel): args: KaiConfigIncidentStorePostgreSQLArgs -class KaiConfigIncidentStoreInMemoryArgs(BaseModel): - dummy: bool +class KaiConfigIncidentStoreSQLiteArgs(BaseModel): + host: Optional[str] = None + database: Optional[str] = None + user: Optional[str] = None + password: Optional[str] = None + + connection_string: Optional[str] = None + + @validator("connection_string", always=True) + def validate_connection_string(cls, v, values): + connection_string_present = v is not None + parameters_present = all( + values.get(key) is not None + for key in ["host", "database", "user", "password"] + ) + + if connection_string_present == parameters_present: + raise ValueError( + "Must provide one of connection_string or [host, database, user, password]" + ) + + return v -class KaiConfigIncidentStoreInMemory(BaseModel): - provider: Literal["in_memory"] - args: KaiConfigIncidentStoreInMemoryArgs +class KaiConfigIncidentStoreSQLIte(BaseModel): + provider: Literal["sqlite"] + args: KaiConfigIncidentStoreSQLiteArgs KaiConfigIncidentStore = Union[ KaiConfigIncidentStorePostgreSQL, - KaiConfigIncidentStoreInMemory, + KaiConfigIncidentStoreSQLIte, ] # Model providers diff --git a/kai/report.py b/kai/report.py index 64dce251..7383b5cd 100644 --- a/kai/report.py +++ b/kai/report.py @@ -1,5 +1,7 @@ __all__ = ["Report"] +import hashlib +import json import os import shutil from io import StringIO @@ -12,9 +14,10 @@ class Report: report: dict = None - def __init__(self, report_data: dict): + def __init__(self, report_data: dict, report_id: str): self.workaround_counter_for_missing_ruleset_name = 0 self.report = self._reformat_report(report_data) + self.report_id = report_id def __str__(self): return str(self.report) @@ -35,11 +38,11 @@ def __len__(self): return len(self.report) @classmethod - def load_report_from_object(cls, report_data: dict): + def load_report_from_object(cls, report_data: dict, report_id: str): """ Class method to create a Report instance directly from a Python dictionary object. """ - return cls(report_data=report_data) + return cls(report_data=report_data, report_id=report_id) @classmethod def load_report_from_file(cls, file_name: str): @@ -47,7 +50,13 @@ def load_report_from_file(cls, file_name: str): with open(file_name, "r") as f: report: dict = yaml.safe_load(f) report_data = report - return cls(report_data) + # report_id is the hash of the json.dumps of the report_data + return cls( + report_data, + hashlib.sha256( + json.dumps(report_data, sort_keys=True).encode() + ).hexdigest(), + ) @staticmethod def get_cleaned_file_path(uri: str): diff --git a/kai/server.py b/kai/server.py index 2ee36ce3..d5db510f 100644 --- a/kai/server.py +++ b/kai/server.py @@ -295,13 +295,15 @@ def app(): if config.demo_mode: KAI_LOG.info("DEMO_MODE is enabled. 
LLM responses will be cached") - webapp["incident_store"] = IncidentStore.from_config(config.incident_store) - KAI_LOG.info(f"Selected incident store: {config.incident_store.provider}") - webapp["model_provider"] = ModelProvider(config.models) KAI_LOG.info(f"Selected provider: {config.models.provider}") KAI_LOG.info(f"Selected model: {webapp['model_provider'].model_id}") + webapp["incident_store"] = IncidentStore.from_config( + config.incident_store, webapp["model_provider"] + ) + KAI_LOG.info(f"Selected incident store: {config.incident_store.provider}") + webapp.add_routes(routes) return webapp diff --git a/kai/service/incident_store/__init__.py b/kai/service/incident_store/__init__.py index 85583c52..73c8503f 100644 --- a/kai/service/incident_store/__init__.py +++ b/kai/service/incident_store/__init__.py @@ -1,11 +1,11 @@ -from .in_memory import InMemoryIncidentStore from .incident_store import Application, IncidentStore, Solution from .psql import PSQLIncidentStore +from .sqlite import SQLiteIncidentStore __all__ = [ "IncidentStore", "Solution", "PSQLIncidentStore", - "InMemoryIncidentStore", + "SQLiteIncidentStore", "Application", ] diff --git a/kai/service/incident_store/in_memory.py b/kai/service/incident_store/in_memory.py deleted file mode 100644 index 4a251b15..00000000 --- a/kai/service/incident_store/in_memory.py +++ /dev/null @@ -1,210 +0,0 @@ -import os -from dataclasses import dataclass -from urllib.parse import unquote, urlparse - -from git import Repo - -from kai.models.kai_config import KaiConfigIncidentStoreInMemoryArgs -from kai.report import Report -from kai.service.incident_store.incident_store import ( - Application, - IncidentStore, - Solution, - remove_known_prefixes, -) - - -# These classes are just for the in-memory data store. Once we figure out the -# best way to model the incidents, violations, etc... we can remove these -@dataclass -class InMemoryIncident: - uri: str - snip: str - line: int - variables: dict - - def __hash__(self): - return hash((self.uri, self.snip, self.line, frozenset(self.variables.items()))) - - -@dataclass(frozen=True) -class InMemorySolvedIncident: - uri: str - snip: str - line: int - variables: dict - file_diff: str - repo_diff: str - original_code: str - updated_code: str - - -@dataclass -class InMemoryViolation: - unsolved_incidents: list[InMemoryIncident] - solved_incidents: list[InMemorySolvedIncident] - - -@dataclass -class InMemoryRuleset: - violations: dict[str, InMemoryViolation] - - -@dataclass -class InMemoryApplication: - current_commit: str - rulesets: dict[str, InMemoryRuleset] - - -class InMemoryIncidentStore(IncidentStore): - def __init__(self, args: KaiConfigIncidentStoreInMemoryArgs): - self.store: dict[str, InMemoryApplication] = {} - - def delete_store(self): - self.store = {} - - def load_report(self, app: Application, report: Report) -> tuple[int, int, int]: - """ - Returns: (number_new_incidents, number_unsolved_incidents, - number_solved_incidents): tuple[int, int, int] - """ - # FIXME: Only does stuff within the same application. Maybe fixed? - - # create entries if not exists - # reference the old-new matrix - # old - # | NO | YES - # --------|--------+----------------------------- - # new NO | - | update (SOLVED, embeddings) - # YES | insert | update (line number, etc...) 
- - repo_path = unquote(urlparse(app.repo_uri_local).path) - print(f"repo_path: {repo_path}") - repo = Repo(repo_path) - old_commit: str - new_commit = app.current_commit - - number_new_incidents = 0 - number_unsolved_incidents = 0 - number_solved_incidents = 0 - - application = self.store.setdefault( - app.application_name, InMemoryApplication(app.current_commit, {}) - ) - - old_commit = application.current_commit - report_dict = dict(report) - - number_new_incidents = 0 - number_unsolved_incidents = 0 - number_solved_incidents = 0 - - for ruleset_name, ruleset_dict in report_dict.items(): - ruleset = application.rulesets.setdefault(ruleset_name, InMemoryRuleset({})) - - for violation_name, violation_dict in ruleset_dict.get( - "violations", {} - ).items(): - violation = ruleset.violations.setdefault( - violation_name, InMemoryViolation([], []) - ) - - store_incidents = set(violation.unsolved_incidents) - report_incidents = set( - InMemoryIncident( - x.get("uri", ""), - x.get("codeSnip", ""), - x.get("lineNumber", 0), - x.get("variables", {}), - ) - for x in violation_dict.get("incidents", []) - ) - - new_incidents = report_incidents.difference(store_incidents) - number_new_incidents += len(new_incidents) - for incident in new_incidents: - violation.unsolved_incidents.append(incident) - - unsolved_incidents = report_incidents.intersection(store_incidents) - number_unsolved_incidents += len(unsolved_incidents) - - solved_incidents = store_incidents.difference(report_incidents) - number_solved_incidents += len(solved_incidents) - for incident in solved_incidents: - file_path = os.path.join( - repo_path, - # NOTE: When retrieving uris from the report, some of - # them had "/tmp/source-code/" as their root path. - # Unsure where it originates from. - remove_known_prefixes(unquote(urlparse(incident.uri).path)), - ) - - try: - original_code = repo.git.show(f"{old_commit}:{file_path}") - except Exception: - original_code = "" - - try: - updated_code = repo.git.show(f"{new_commit}:{file_path}") - except Exception: - updated_code = "" - - repo_diff = repo.git.diff(old_commit, new_commit) - file_diff = repo.git.diff(old_commit, new_commit, "--", file_path) - - violation.solved_incidents.append( - InMemorySolvedIncident( - uri=incident.uri, - snip=incident.snip, - line=incident.line, - variables=incident.variables, - original_code=original_code, - updated_code=updated_code, - file_diff=file_diff, - repo_diff=repo_diff, - ) - ) - - return number_new_incidents, number_unsolved_incidents, number_solved_incidents - - def find_solutions( - self, - ruleset_name: str, - violation_name: str, - incident_variables: dict, - incident_snip: str | None = None, - ) -> list[Solution]: - result: list[Solution] = [] - incident_variables_set = set(incident_variables.items()) - incident_variables_set_len = len(incident_variables_set) - - for _, application in self.store.items(): - ruleset = application.rulesets.get(ruleset_name, None) - if ruleset is None: - continue - - violation = ruleset.violations.get(violation_name, None) - if violation is None: - continue - - for solved_incident in violation.solved_incidents: - if incident_snip is not None and solved_incident.snip != incident_snip: - continue - - common = set(solved_incident.variables.items()).intersection( - incident_variables_set - ) - if len(common) != incident_variables_set_len: - continue - - result.append( - Solution( - solved_incident.uri, - solved_incident.file_diff, - solved_incident.repo_diff, - solved_incident.original_code, - 
solved_incident.updated_code, - ) - ) - - return result diff --git a/kai/service/incident_store/incident_store.py b/kai/service/incident_store/incident_store.py index 98d1675c..3f86f9fd 100644 --- a/kai/service/incident_store/incident_store.py +++ b/kai/service/incident_store/incident_store.py @@ -1,15 +1,36 @@ +import argparse import datetime +import enum import os from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional, TypeVar +from urllib.parse import unquote, urlparse import yaml from git import Repo +from sqlalchemy import ( + Column, + DateTime, + Engine, + ForeignKey, + ForeignKeyConstraint, + String, + func, + select, +) +from sqlalchemy.dialects import postgresql, sqlite +from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship +from sqlalchemy.types import JSON from kai.constants import PATH_LOCAL_REPO from kai.kai_logging import KAI_LOG -from kai.models.kai_config import KaiConfigIncidentStore, KaiConfigIncidentStoreProvider +from kai.model_provider import ModelProvider +from kai.models.kai_config import ( + KaiConfig, + KaiConfigIncidentStore, + KaiConfigIncidentStoreProvider, +) from kai.report import Report # These prefixes are sometimes in front of the paths, strip them. @@ -41,6 +62,17 @@ def filter_incident_vars(incident_vars: dict): return incident_vars +T = TypeVar("T") + + +def deep_sort(obj: T) -> T: + if isinstance(obj, dict): + return {k: deep_sort(v) for k, v in sorted(obj.items())} + if isinstance(obj, list): + return sorted(deep_sort(x) for x in obj) + return obj + + def __get_repo_path(app_name): """ Get the repo path @@ -173,31 +205,180 @@ class Solution: updated_code: Optional[str] = None +class SQLBase(DeclarativeBase): + type_annotation_map = { + dict[str, Any]: JSON() + .with_variant(postgresql.JSONB(), "postgresql") + .with_variant(sqlite.JSON(), "sqlite"), + list[str]: JSON() + .with_variant(postgresql.JSONB(), "postgresql") + .with_variant(sqlite.JSON(), "sqlite"), + } + + +class SQLUnmodifiedReport(SQLBase): + __tablename__ = "unmodified_reports" + + application_name: Mapped[str] = mapped_column(primary_key=True) + report_id: Mapped[str] = mapped_column(primary_key=True) + + generated_at: Mapped[datetime.datetime] = mapped_column( + DateTime(), server_default=func.now() + ) + report: Mapped[dict[str, Any]] + + +class ViolationCategory(enum.Enum): + potential = "potential" + optional = "optional" + mandatory = "mandatory" + + +class SQLApplication(SQLBase): + __tablename__ = "applications" + + application_name: Mapped[str] = mapped_column(primary_key=True) + + repo_uri_origin: Mapped[str] + repo_uri_local: Mapped[str] + current_branch: Mapped[str] + current_commit: Mapped[str] + generated_at: Mapped[datetime.datetime] + + incidents: Mapped[list["SQLIncident"]] = relationship( + back_populates="application", cascade="all, delete-orphan" + ) + + +class SQLRuleset(SQLBase): + __tablename__ = "rulesets" + + ruleset_name: Mapped[str] = mapped_column(primary_key=True) + + tags: Mapped[list[str]] + + violations: Mapped[list["SQLViolation"]] = relationship( + back_populates="ruleset", cascade="all, delete-orphan" + ) + + +class SQLViolation(SQLBase): + __tablename__ = "violations" + + violation_name: Mapped[str] = mapped_column(primary_key=True) + ruleset_name: Mapped[int] = mapped_column( + ForeignKey("rulesets.ruleset_name"), primary_key=True + ) + + category: Mapped[ViolationCategory] + labels: Mapped[list[str]] + + ruleset: Mapped[SQLRuleset] = 
relationship(back_populates="violations") + incidents: Mapped[list["SQLIncident"]] = relationship( + back_populates="violation", cascade="all, delete-orphan" + ) + + +class SQLAcceptedSolution(SQLBase): + __tablename__ = "accepted_solutions" + + solution_id: Mapped[int] = mapped_column(primary_key=True) + + generated_at: Mapped[datetime.datetime] = mapped_column( + DateTime(), server_default=func.now() + ) + solution_big_diff: Mapped[str] + solution_small_diff: Mapped[str] + solution_original_code: Mapped[str] + solution_updated_code: Mapped[str] + + llm_summary: Mapped[Optional[str]] + + incidents: Mapped[list["SQLIncident"]] = relationship( + back_populates="solution", cascade="all, delete-orphan" + ) + + def __repr__(self): + return f"SQLAcceptedSolution(solution_id={self.solution_id}, generated_at={self.generated_at}, solution_big_diff={self.solution_big_diff:.10}, solution_small_diff={self.solution_small_diff:.10}, solution_original_code={self.solution_original_code:.10}, solution_updated_code={self.solution_updated_code:.10})" + + +class SQLIncident(SQLBase): + __tablename__ = "incidents" + + incident_id: Mapped[int] = mapped_column(primary_key=True) + + violation_name = Column(String) + ruleset_name = Column(String) + application_name: Mapped[str] = mapped_column( + ForeignKey("applications.application_name") + ) + incident_uri: Mapped[str] + incident_snip: Mapped[str] + incident_line: Mapped[int] + incident_variables: Mapped[dict[str, Any]] + solution_id: Mapped[Optional[str]] = mapped_column( + ForeignKey("accepted_solutions.solution_id") + ) + + __table_args__ = ( + ForeignKeyConstraint( + [violation_name, ruleset_name], + [SQLViolation.violation_name, SQLViolation.ruleset_name], + ), + {}, + ) + + violation: Mapped[SQLViolation] = relationship(back_populates="incidents") + application: Mapped[SQLApplication] = relationship(back_populates="incidents") + solution: Mapped[SQLAcceptedSolution] = relationship(back_populates="incidents") + + def __repr__(self) -> str: + return f"SQLIncident(violation_name={self.violation_name}, ruleset_name={self.ruleset_name}, application_name={self.application_name}, incident_uri={self.incident_uri}, incident_snip={self.incident_snip:.10}, incident_line={self.incident_line}, incident_variables={self.incident_variables}, solution_id={self.solution_id})" + + +# def dump(sql, *multiparams, **params): +# print(sql.compile(dialect=engine.dialect)) + +# engine = create_engine('postgresql://', strategy='mock', executor=dump) +# Base.metadata.create_all(engine, checkfirst=False) + +# exit() + + class IncidentStore(ABC): + """ + Responsible for 3 main things: + - Incident/Solution storage + - Solution detection + - Solution generation + """ + + engine: Engine + model_provider: ModelProvider @staticmethod - def from_config(config: KaiConfigIncidentStore): + def from_config(config: KaiConfigIncidentStore, model_provider: ModelProvider): """ Factory method to produce whichever incident store is needed. """ + # TODO: Come up with some sort of "solution generator strategy" so we + # don't blow up our llm API usage. Lazy, immediate, other etc... 
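+        #
+        # Both concrete stores share the SQLAlchemy-based implementation in
+        # this base class; they differ only in how the engine is created and
+        # in the dialect-specific json_exactly_equal() comparison.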
+
         if config.provider == "postgresql":
             from kai.service.incident_store.psql import PSQLIncidentStore
 
-            return PSQLIncidentStore(config.args)
-        elif config.provider == "in_memory":
-            from kai.service.incident_store.in_memory import InMemoryIncidentStore
+            return PSQLIncidentStore(config.args, model_provider)
+        elif config.provider == "sqlite":
+            from kai.service.incident_store.sqlite import SQLiteIncidentStore
 
-            return InMemoryIncidentStore(config.args)
+            return SQLiteIncidentStore(config.args, model_provider)
         else:
             raise ValueError(
-                f"Unsupported provider: {config.provider}\ntype: {type(config.provider)}\nlmao: {KaiConfigIncidentStoreProvider.POSTGRESQL}"
+                f"Unsupported provider: {config.provider} "
+                f"(type: {type(config.provider)}); expected one of "
+                f"{[e.value for e in KaiConfigIncidentStoreProvider]}"
             )
 
-    @abstractmethod
-    def load_report(
-        self, application: Application, report: Report
-    ) -> tuple[int, int, int]:
+    def load_report(self, app: Application, report: Report) -> tuple[int, int, int]:
         """
         Load incidents from a report and given application object. Returns a
         tuple containing (# of new incidents, # of unsolved incidents, # of
@@ -205,16 +386,223 @@ def load_report(
         NOTE: This application object is more like metadata than anything.
         """
-        pass
 
-    @abstractmethod
+
+        # FIXME: Only does stuff within the same application. Maybe fixed?
+
+        # NEW: Store whole report in table
+        # - if we get the same report again, we should skip adding it. Have some identifier
+        # - But should still check incidents.
+
+        # - have two tables `unsolved_incidents` and `solved_incidents`
+
+        # Iterate through all incidents in the report
+        # - change so there's an identifier like "commit application ruleset violation"
+
+        # create entries if they don't exist
+        # reference the old-new matrix
+        #             old
+        #         | NO     | YES
+        # --------|--------+-----------------------------
+        # new NO  | -      | update (SOLVED, embeddings)
+        #     YES | insert | update (line number, etc...)
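+        # i.e. incidents that appear only in the new report are inserted as
+        # unsolved; incidents that were in the store but are absent from the
+        # new report are marked solved, and a solution is reconstructed from
+        # the repository's git history below.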
+ + repo_path = unquote(urlparse(app.repo_uri_local).path) + repo = Repo(repo_path) + old_commit: str + new_commit = app.current_commit + + number_new_incidents = 0 + number_unsolved_incidents = 0 + number_solved_incidents = 0 + + with Session(self.engine) as session: + incidents_temp: list[SQLIncident] = [] + + select_application_stmt = select(SQLApplication).where( + SQLApplication.application_name == app.application_name + ) + + application = session.scalars(select_application_stmt).first() + + if application is None: + application = SQLApplication( + application_name=app.application_name, + repo_uri_origin=app.repo_uri_origin, + repo_uri_local=app.repo_uri_local, + current_branch=app.current_branch, + current_commit=app.current_commit, + generated_at=app.generated_at, + ) + session.add(application) + session.commit() + + # TODO: Determine if we want to have this check + # if application.generated_at >= app.generated_at: + # return 0, 0, 0 + + old_commit = application.current_commit + + report_dict = dict(report) + + for ruleset_name, ruleset_dict in report_dict.items(): + select_ruleset_stmt = select(SQLRuleset).where( + SQLRuleset.ruleset_name == ruleset_name + ) + + ruleset = session.scalars(select_ruleset_stmt).first() + + if ruleset is None: + ruleset = SQLRuleset( + ruleset_name=ruleset_name, + tags=ruleset_dict.get("tags", []), + ) + session.add(ruleset) + session.commit() + + for violation_name, violation_dict in ruleset_dict.get( + "violations", {} + ).items(): + select_violation_stmt = ( + select(SQLViolation) + .where(SQLViolation.violation_name == violation_name) + .where(SQLViolation.ruleset_name == ruleset.ruleset_name) + ) + + violation = session.scalars(select_violation_stmt).first() + + if violation is None: + violation = SQLViolation( + violation_name=violation_name, + ruleset_name=ruleset.ruleset_name, + category=violation_dict.get("category", "potential"), + labels=violation_dict.get("labels", []), + ) + session.add(violation) + session.commit() + + for incident in violation_dict.get("incidents", []): + incidents_temp.append( + SQLIncident( + violation_name=violation.violation_name, + ruleset_name=ruleset.ruleset_name, + application_name=application.application_name, + incident_uri=incident.get("uri", ""), + incident_snip=incident.get("codeSnip", ""), + incident_line=incident.get("lineNumber", 0), + incident_variables=deep_sort( + incident.get("variables", {}) + ), + ) + ) + + # incidents_temp - incidents + new_incidents = set(incidents_temp) - set(application.incidents) + number_new_incidents = len(new_incidents) + + for new_incident in new_incidents: + session.add(new_incident) + + session.commit() + + # incidents `intersect` incidents_temp + unsolved_incidents = set(application.incidents).intersection(incidents_temp) + number_unsolved_incidents = len(unsolved_incidents) + + # incidents - incidents_temp + solved_incidents = set(application.incidents) - set(incidents_temp) + number_solved_incidents = len(solved_incidents) + KAI_LOG.debug(f"Number of solved incidents: {len(solved_incidents)}") + # KAI_LOG.debug(f"{solved_incidents=}") + + for solved_incident in solved_incidents: + file_path = os.path.join( + repo_path, + # NOTE: When retrieving uris from the report, some of them + # had "/tmp/source-code/" as their root path. Unsure where + # it originates from. 
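+                    # Stripping it lets the remainder resolve inside the
+                    # local clone at repo_path.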
+ unquote(urlparse(solved_incident.incident_uri).path).removeprefix( + "/tmp/source-code/" # trunk-ignore(bandit/B108) + ), + ) + + # NOTE: The `big_diff` functionality is currently disabled + + # big_diff: str = repo.git.diff(old_commit, new_commit).encode('utf-8', errors="ignore").decode() + big_diff = "" + + # TODO: Some of the sample repos have invalid utf-8 characters, + # thus the encode-then-decode hack. Not very performant, there's + # probably a better way to handle this. + + try: + original_code = ( + repo.git.show(f"{old_commit}:{file_path}") + .encode("utf-8", errors="ignore") + .decode() + ) + except Exception: + original_code = "" + + try: + updated_code = ( + repo.git.show(f"{new_commit}:{file_path}") + .encode("utf-8", errors="ignore") + .decode() + ) + except Exception: + updated_code = "" + + small_diff = ( + repo.git.diff(old_commit, new_commit, "--", file_path) + .encode("utf-8", errors="ignore") + .decode() + ) + + # TODO: Strings must be utf-8 encodable, so I'm removing the `big_diff` functionality for now + solved_incident.solution = SQLAcceptedSolution( + generated_at=app.generated_at, + solution_big_diff=big_diff, + solution_small_diff=small_diff, + solution_original_code=original_code, + solution_updated_code=updated_code, + ) + + session.commit() + + application.repo_uri_origin = app.repo_uri_origin + application.repo_uri_local = app.repo_uri_local + application.current_branch = app.current_branch + application.current_commit = app.current_commit + application.generated_at = app.generated_at + + session.commit() + + unmodified_report = SQLUnmodifiedReport( + application_name=app.application_name, + report_id=report.report_id, + generated_at=app.generated_at, + report=report_dict, + ) + + # Upsert the unmodified report + session.merge(unmodified_report) + session.commit() + + return number_new_incidents, number_unsolved_incidents, number_solved_incidents + + def create_tables(self): + """ + Create tables in the incident store. + """ + SQLBase.metadata.create_all(self.engine) + def delete_store(self): """ Clears all data within the incident store. Non-reversible! """ - pass + SQLBase.metadata.drop_all(self.engine) + self.create_tables() - @abstractmethod def find_solutions( self, ruleset_name: str, @@ -225,4 +613,97 @@ def find_solutions( """ Returns a list of solutions for the given incident. Exact matches only. 
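+        Exact matching of incident_variables is delegated to the
+        dialect-specific json_exactly_equal() hook implemented by each
+        concrete store.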
""" + incident_variables = deep_sort(filter_incident_vars(incident_variables)) + + if incident_snip is None: + incident_snip = "" + + with Session(self.engine) as session: + select_violation_stmt = ( + select(SQLViolation) + .where(SQLViolation.violation_name == violation_name) + .where(SQLViolation.ruleset_name == ruleset_name) + ) + + violation = session.scalars(select_violation_stmt).first() + + if violation is None: + return [] + + select_incidents_with_solutions_stmt = ( + select(SQLIncident) + .where(SQLIncident.violation_name == violation.violation_name) + .where(SQLIncident.ruleset_name == violation.ruleset_name) + .where(SQLIncident.solution_id.isnot(None)) + .where(self.json_exactly_equal(incident_variables)) + ) + + result: list[Solution] = [] + for incident in session.execute( + select_incidents_with_solutions_stmt + ).scalars(): + select_accepted_solution_stmt = select(SQLAcceptedSolution).where( + SQLAcceptedSolution.solution_id == incident.solution_id + ) + + accepted_solution = session.scalars( + select_accepted_solution_stmt + ).first() + + result.append( + Solution( + uri=incident.incident_uri, + file_diff=accepted_solution.solution_small_diff, + repo_diff=accepted_solution.solution_big_diff, + ) + ) + + return result + + @abstractmethod + def json_exactly_equal(self, json_dict: dict): + """ + Each incident store must implement this method as JSON is handled + slightly differently between SQLite and PostgreSQL. + """ pass + + +def cmd(provider: str = None): + KAI_LOG.setLevel("debug".upper()) + + parser = argparse.ArgumentParser(description="Process some parameters.") + parser.add_argument( + "--config_filepath", + type=str, + default="../../config.toml", + help="Path to the config file.", + ) + parser.add_argument( + "--drop_tables", type=str, default="False", help="Whether to drop tables." 
+ ) + parser.add_argument( + "--analysis_dir_path", + type=str, + default="../../samples/analysis_reports", + help="Path to analysis reports folder", + ) + + args = parser.parse_args() + + config = KaiConfig.model_validate_filepath(args.config_filepath) + + if provider is not None and config.incident_store.provider != provider: + raise Exception(f"This script only works with {provider} incident store.") + + model_provider = ModelProvider(config.models) + incident_store = IncidentStore.from_config(config.incident_store, model_provider) + + if args.drop_tables: + incident_store.delete_store() + + load_reports_from_directory(incident_store, args.analysis_dir_path) + + +if __name__ == "__main__": + cmd() diff --git a/kai/service/incident_store/psql.py b/kai/service/incident_store/psql.py index 4e9ae6ef..5744d693 100644 --- a/kai/service/incident_store/psql.py +++ b/kai/service/incident_store/psql.py @@ -1,475 +1,30 @@ -import argparse -import datetime -import enum -import os -from typing import Any, Optional -from urllib.parse import unquote, urlparse +from sqlalchemy import and_, create_engine -from git import Repo -from sqlalchemy import ( - Column, - DateTime, - ForeignKey, - ForeignKeyConstraint, - String, - create_engine, - func, - select, -) -from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship -from sqlalchemy.types import JSON - -from kai.kai_logging import KAI_LOG -from kai.models.kai_config import KaiConfig, KaiConfigIncidentStorePostgreSQLArgs -from kai.report import Report -from kai.service.incident_store.incident_store import ( - Application, - IncidentStore, - Solution, - load_reports_from_directory, -) - -# have a logical category of un-mutated reports and one with the mutations - - -class Base(DeclarativeBase): - type_annotation_map = { - dict[str, Any]: JSON().with_variant(JSONB(), "postgresql"), - list[str]: JSON().with_variant(JSONB(), "postgresql"), - } - - -class PSQLUnmodifiedReport(Base): - __tablename__ = "unmodified_reports" - - application_name: Mapped[str] = mapped_column(primary_key=True) - generated_at: Mapped[datetime.datetime] = mapped_column( - DateTime(), server_default=func.now(), primary_key=True - ) - report: Mapped[dict[str, Any]] - - -class ViolationCategory(enum.Enum): - potential = "potential" - optional = "optional" - mandatory = "mandatory" - - -class PSQLApplication(Base): - __tablename__ = "applications" - - application_name: Mapped[str] = mapped_column(primary_key=True) - - repo_uri_origin: Mapped[str] - repo_uri_local: Mapped[str] - current_branch: Mapped[str] - current_commit: Mapped[str] - generated_at: Mapped[datetime.datetime] - - incidents: Mapped[list["PSQLIncident"]] = relationship( - back_populates="application", cascade="all, delete-orphan" - ) - - -class PSQLRuleset(Base): - __tablename__ = "rulesets" - - ruleset_name: Mapped[str] = mapped_column(primary_key=True) - - tags: Mapped[list[str]] - - violations: Mapped[list["PSQLViolation"]] = relationship( - back_populates="ruleset", cascade="all, delete-orphan" - ) - - -class PSQLViolation(Base): - __tablename__ = "violations" - - violation_name: Mapped[str] = mapped_column(primary_key=True) - ruleset_name: Mapped[int] = mapped_column( - ForeignKey("rulesets.ruleset_name"), primary_key=True - ) - - category: Mapped[ViolationCategory] - labels: Mapped[list[str]] - - ruleset: Mapped[PSQLRuleset] = relationship(back_populates="violations") - incidents: Mapped[list["PSQLIncident"]] = relationship( - 
back_populates="violation", cascade="all, delete-orphan" - ) - - -class PSQLAcceptedSolution(Base): - __tablename__ = "accepted_solutions" - - solution_id: Mapped[int] = mapped_column(primary_key=True) - - generated_at: Mapped[datetime.datetime] = mapped_column( - DateTime(), server_default=func.now() - ) - solution_big_diff: Mapped[str] - solution_small_diff: Mapped[str] - solution_original_code: Mapped[str] - solution_updated_code: Mapped[str] - - incidents: Mapped[list["PSQLIncident"]] = relationship( - back_populates="solution", cascade="all, delete-orphan" - ) - - def __repr__(self): - return f"PSQLAcceptedSolution(solution_id={self.solution_id}, generated_at={self.generated_at}, solution_big_diff={self.solution_big_diff:.10}, solution_small_diff={self.solution_small_diff:.10}, solution_original_code={self.solution_original_code:.10}, solution_updated_code={self.solution_updated_code:.10})" - - -class PSQLIncident(Base): - __tablename__ = "incidents" - - incident_id: Mapped[int] = mapped_column(primary_key=True) - - violation_name = Column(String) - ruleset_name = Column(String) - application_name: Mapped[str] = mapped_column( - ForeignKey("applications.application_name") - ) - incident_uri: Mapped[str] - incident_snip: Mapped[str] - incident_line: Mapped[int] - incident_variables: Mapped[dict[str, Any]] - solution_id: Mapped[Optional[str]] = mapped_column( - ForeignKey("accepted_solutions.solution_id") - ) - - __table_args__ = ( - ForeignKeyConstraint( - [violation_name, ruleset_name], - [PSQLViolation.violation_name, PSQLViolation.ruleset_name], - ), - {}, - ) - - violation: Mapped[PSQLViolation] = relationship(back_populates="incidents") - application: Mapped[PSQLApplication] = relationship(back_populates="incidents") - solution: Mapped[PSQLAcceptedSolution] = relationship(back_populates="incidents") - - def __repr__(self) -> str: - return f"PSQLIncident(violation_name={self.violation_name}, ruleset_name={self.ruleset_name}, application_name={self.application_name}, incident_uri={self.incident_uri}, incident_snip={self.incident_snip:.10}, incident_line={self.incident_line}, incident_variables={self.incident_variables}, solution_id={self.solution_id})" - - -# def dump(sql, *multiparams, **params): -# print(sql.compile(dialect=engine.dialect)) - -# engine = create_engine('postgresql://', strategy='mock', executor=dump) -# Base.metadata.create_all(engine, checkfirst=False) - -# exit() +from kai.model_provider import ModelProvider +from kai.models.kai_config import KaiConfigIncidentStorePostgreSQLArgs +from kai.service.incident_store.incident_store import IncidentStore, SQLIncident, cmd class PSQLIncidentStore(IncidentStore): - def __init__(self, args: KaiConfigIncidentStorePostgreSQLArgs): - self.engine = create_engine( - f"postgresql://{args.user}:{args.password}@{args.host}:5432/{args.database}", - client_encoding="utf8", - ) - - def create_tables(self): - Base.metadata.create_all(self.engine) - - def delete_store(self): - Base.metadata.drop_all(self.engine) - self.create_tables() - - def load_report(self, app: Application, report: Report) -> tuple[int, int, int]: - """ - Returns: (number_new_incidents, number_unsolved_incidents, - number_solved_incidents): tuple[int, int, int] - """ - # FIXME: Only does stuff within the same application. Maybe fixed? - - # NEW: Store whole report in table - # - if we get the same report again, we should skip adding it. Have some identifier - # - But should still check incidents. 
- - # - have two tables `unsolved_incidents` and `solved_incidents` - - # Iterate through all incidents in the report - # - change so theres an identified like "commit application ruleset violation" - - # create entries if not exists - # reference the old-new matrix - # old - # | NO | YES - # --------|--------+----------------------------- - # new NO | - | update (SOLVED, embeddings) - # YES | insert | update (line number, etc...) - - repo_path = unquote(urlparse(app.repo_uri_local).path) - repo = Repo(repo_path) - old_commit: str - new_commit = app.current_commit - - number_new_incidents = 0 - number_unsolved_incidents = 0 - number_solved_incidents = 0 - - with Session(self.engine) as session: - incidents_temp: list[PSQLIncident] = [] - - select_application_stmt = select(PSQLApplication).where( - PSQLApplication.application_name == app.application_name - ) - - application = session.scalars(select_application_stmt).first() - - if application is None: - application = PSQLApplication( - application_name=app.application_name, - repo_uri_origin=app.repo_uri_origin, - repo_uri_local=app.repo_uri_local, - current_branch=app.current_branch, - current_commit=app.current_commit, - generated_at=app.generated_at, - ) - session.add(application) - session.commit() - - # TODO: Determine if we want to have this check - # if application.generated_at >= app.generated_at: - # return 0, 0, 0 - - old_commit = application.current_commit - - report_dict = dict(report) - - for ruleset_name, ruleset_dict in report_dict.items(): - select_ruleset_stmt = select(PSQLRuleset).where( - PSQLRuleset.ruleset_name == ruleset_name - ) - - ruleset = session.scalars(select_ruleset_stmt).first() - - if ruleset is None: - ruleset = PSQLRuleset( - ruleset_name=ruleset_name, - tags=ruleset_dict.get("tags", []), - ) - session.add(ruleset) - session.commit() - - for violation_name, violation_dict in ruleset_dict.get( - "violations", {} - ).items(): - select_violation_stmt = ( - select(PSQLViolation) - .where(PSQLViolation.violation_name == violation_name) - .where(PSQLViolation.ruleset_name == ruleset.ruleset_name) - ) - - violation = session.scalars(select_violation_stmt).first() - - if violation is None: - violation = PSQLViolation( - violation_name=violation_name, - ruleset_name=ruleset.ruleset_name, - category=violation_dict.get("category", "potential"), - labels=violation_dict.get("labels", []), - ) - session.add(violation) - session.commit() - - for incident in violation_dict.get("incidents", []): - incidents_temp.append( - PSQLIncident( - violation_name=violation.violation_name, - ruleset_name=ruleset.ruleset_name, - application_name=application.application_name, - incident_uri=incident.get("uri", ""), - incident_snip=incident.get("codeSnip", ""), - incident_line=incident.get("lineNumber", 0), - incident_variables=incident.get("variables", {}), - ) - ) - - # incidents_temp - incidents - new_incidents = set(incidents_temp) - set(application.incidents) - number_new_incidents = len(new_incidents) - - for new_incident in new_incidents: - session.add(new_incident) - - session.commit() - - # incidents `intersect` incidents_temp - unsolved_incidents = set(application.incidents).intersection(incidents_temp) - number_unsolved_incidents = len(unsolved_incidents) - - # incidents - incidents_temp - solved_incidents = set(application.incidents) - set(incidents_temp) - number_solved_incidents = len(solved_incidents) - KAI_LOG.debug(f"Number of solved incidents: {len(solved_incidents)}") - # KAI_LOG.debug(f"{solved_incidents=}") - - for 
solved_incident in solved_incidents: - file_path = os.path.join( - repo_path, - # NOTE: When retrieving uris from the report, some of them - # had "/tmp/source-code/" as their root path. Unsure where - # it originates from. - unquote(urlparse(solved_incident.incident_uri).path).removeprefix( - "/tmp/source-code/" # trunk-ignore(bandit/B108) - ), - ) - - # NOTE: The `big_diff` functionality is currently disabled - - # big_diff: str = repo.git.diff(old_commit, new_commit).encode('utf-8', errors="ignore").decode() - big_diff = "" - - # TODO: Some of the sample repos have invalid utf-8 characters, - # thus the encode-then-decode hack. Not very performant, there's - # probably a better way to handle this. - - try: - original_code = ( - repo.git.show(f"{old_commit}:{file_path}") - .encode("utf-8", errors="ignore") - .decode() - ) - except Exception: - original_code = "" - - try: - updated_code = ( - repo.git.show(f"{new_commit}:{file_path}") - .encode("utf-8", errors="ignore") - .decode() - ) - except Exception: - updated_code = "" - - small_diff = ( - repo.git.diff(old_commit, new_commit, "--", file_path) - .encode("utf-8", errors="ignore") - .decode() - ) - - # TODO: Strings must be utf-8 encodable, so I'm removing the `big_diff` functionality for now - solved_incident.solution = PSQLAcceptedSolution( - generated_at=app.generated_at, - solution_big_diff=big_diff, - solution_small_diff=small_diff, - solution_original_code=original_code, - solution_updated_code=updated_code, - ) - - session.commit() - - application.repo_uri_origin = app.repo_uri_origin - application.repo_uri_local = app.repo_uri_local - application.current_branch = app.current_branch - application.current_commit = app.current_commit - application.generated_at = app.generated_at - - session.commit() - - unmodified_report = PSQLUnmodifiedReport( - application_name=app.application_name, - generated_at=app.generated_at, - report=report_dict, - ) - - # Upsert the unmodified report - session.merge(unmodified_report) - session.commit() - - return number_new_incidents, number_unsolved_incidents, number_solved_incidents - - def find_solutions( - self, - ruleset_name: str, - violation_name: str, - incident_variables: dict, - incident_snip: str | None = None, - ) -> list[Solution]: - if incident_snip is None: - incident_snip = "" - - with Session(self.engine) as session: - select_violation_stmt = ( - select(PSQLViolation) - .where(PSQLViolation.violation_name == violation_name) - .where(PSQLViolation.ruleset_name == ruleset_name) - ) - - violation = session.scalars(select_violation_stmt).first() - - if violation is None: - return [] - - select_incidents_with_solutions_stmt = ( - select(PSQLIncident) - .where(PSQLIncident.violation_name == violation.violation_name) - .where(PSQLIncident.ruleset_name == violation.ruleset_name) - .where(PSQLIncident.solution_id.isnot(None)) - .where(PSQLIncident.incident_variables.op("<@")(incident_variables)) - .where(PSQLIncident.incident_variables.op("@>")(incident_variables)) + def __init__( + self, args: KaiConfigIncidentStorePostgreSQLArgs, model_provider: ModelProvider + ): + if args.connection_string: + self.engine = create_engine(args.connection_string) + else: + self.engine = create_engine( + f"postgresql://{args.user}:{args.password}@{args.host}:5432/{args.database}", + client_encoding="utf8", ) - result: list[Solution] = [] - for incident in session.execute(select_incidents_with_solutions_stmt): - select_accepted_solution_stmt = select(PSQLAcceptedSolution).where( - 
PSQLAcceptedSolution.solution_id == incident.solution_id
-                )
-
-                accepted_solution = session.scalars(
-                    select_accepted_solution_stmt
-                ).first()
-
-                result.append(
-                    Solution(
-                        uri=incident.incident_uri,
-                        file_diff=accepted_solution.solution_small_diff,
-                        repo_diff=accepted_solution.solution_big_diff,
-                    )
-                )
-
-        return result
-
-
-def main():
-    KAI_LOG.setLevel("debug".upper())
-
-    parser = argparse.ArgumentParser(description="Process some parameters.")
-    parser.add_argument(
-        "--config_filepath",
-        type=str,
-        default="../../config.toml",
-        help="Path to the config file.",
-    )
-    parser.add_argument(
-        "--drop_tables", type=str, default="False", help="Whether to drop tables."
-    )
-    parser.add_argument(
-        "--analysis_dir_path",
-        type=str,
-        default="../../samples/analysis_reports",
-        help="Path to analysis reports folder",
-    )
+        self.model_provider = model_provider
 
-    args = parser.parse_args()
-
-    config = KaiConfig.model_validate_filepath(args.config_filepath)
-
-    if config.incident_store.provider != "postgresql":
-        raise Exception("This script only works with PostgreSQL incident store.")
-
-    incident_store = PSQLIncidentStore(config.incident_store.args)
-
-    if args.drop_tables:
-        incident_store.delete_store()
-
-    load_reports_from_directory(incident_store, args.analysis_dir_path)
+    def json_exactly_equal(self, json_dict: dict):
+        return and_(
+            SQLIncident.incident_variables.op("<@")(json_dict),
+            SQLIncident.incident_variables.op("@>")(json_dict),
+        )
 
 if __name__ == "__main__":
-    main()
+    cmd("postgresql")
diff --git a/kai/service/incident_store/sqlite.py b/kai/service/incident_store/sqlite.py
new file mode 100644
index 00000000..cf9aa931
--- /dev/null
+++ b/kai/service/incident_store/sqlite.py
@@ -0,0 +1,51 @@
+import json
+
+from sqlalchemy import bindparam, create_engine, text
+
+from kai.model_provider import ModelProvider
+from kai.models.kai_config import KaiConfigIncidentStoreSQLiteArgs
+from kai.service.incident_store.incident_store import IncidentStore
+
+
+class SQLiteIncidentStore(IncidentStore):
+    def __init__(
+        self, args: KaiConfigIncidentStoreSQLiteArgs, model_provider: ModelProvider
+    ):
+        if args.connection_string:
+            self.engine = create_engine(args.connection_string)
+        else:
+            # SQLite has no host/user/password and create_engine() accepts no
+            # client_encoding argument here; `database` is the file path.
+            self.engine = create_engine(f"sqlite:///{args.database}")
+
+        self.model_provider = model_provider
+
+    def json_exactly_equal(self, json_dict: dict):
+        # json_tree() requires JSON text (hence json.dumps) and flattens it
+        # into key/value rows; the raw SQL refers to the `incidents` table.
+        return text(
+            """
+            (
+                SELECT key, value
+                FROM json_tree(incidents.incident_variables)
+                WHERE type != 'object'
+                ORDER BY key, value
+            ) = (
+                SELECT key, value
+                FROM json_tree(:json_dict)
+                WHERE type != 'object'
+                ORDER BY key, value
+            )
+            """
+        ).bindparams(bindparam("json_dict", json.dumps(json_dict)))
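+
+
+# A minimal usage sketch (hypothetical), mirroring the commented-out SQLite
+# example in kai/config.toml; kai/evaluation.py constructs the store the same
+# way for its in-memory runs:
+#
+#     store = SQLiteIncidentStore(
+#         KaiConfigIncidentStoreSQLiteArgs(connection_string="sqlite:///:memory:"),
+#         None,
+#     )
+#     store.create_tables()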