diff --git a/kai/data/templates/solution_handling/before_and_after.jinja b/kai/data/templates/solution_handling/before_and_after.jinja index 2b7b34bc..cd5f5911 100644 --- a/kai/data/templates/solution_handling/before_and_after.jinja +++ b/kai/data/templates/solution_handling/before_and_after.jinja @@ -1,9 +1,9 @@ -Solution before changes: +Solved example before changes: ``` {{ solution.original_code }} ``` -Solution after changes: +Solved example after changes: ``` {{ solution.updated_code }} ``` diff --git a/kai/data/templates/solution_handling/diff_only.jinja b/kai/data/templates/solution_handling/diff_only.jinja index 72b5ddce..d3d47b72 100644 --- a/kai/data/templates/solution_handling/diff_only.jinja +++ b/kai/data/templates/solution_handling/diff_only.jinja @@ -1,3 +1,3 @@ -Solution diff: +Solved example diff: ```diff {{ solution.file_diff }} diff --git a/kai/data/templates/solution_handling/llm_summary.jinja b/kai/data/templates/solution_handling/llm_summary.jinja index 170e7afd..1afcb8de 100644 --- a/kai/data/templates/solution_handling/llm_summary.jinja +++ b/kai/data/templates/solution_handling/llm_summary.jinja @@ -1,3 +1,3 @@ -Summary of changes for solution: +Summary of changes for solved example: {{ solution.llm_summary }} diff --git a/kai/models/util.py b/kai/models/util.py index 05475e7e..39b8cf5c 100644 --- a/kai/models/util.py +++ b/kai/models/util.py @@ -12,7 +12,11 @@ # These are known unique variables that can be included by incidents # They would prevent matches that we actually want, so we filter them # before adding to the database or searching -FILTERED_INCIDENT_VARS = ("file", "package") +FILTERED_INCIDENT_VARS = [ + "file", # Java, URI of the offending file + "package", # Java, shows the package + "name", # Java, shows the name of the method that caused the incident +] def remove_known_prefixes(path: str) -> str: diff --git a/kai/routes/load_analysis_report.py b/kai/routes/load_analysis_report.py index 62bc45b3..256cfab7 100644 --- a/kai/routes/load_analysis_report.py +++ b/kai/routes/load_analysis_report.py @@ -21,8 +21,9 @@ class PostLoadAnalysisReportApplication(BaseModel): class PostLoadAnalysisReportParams(BaseModel): - path_to_report: str application: PostLoadAnalysisReportApplication + report_data: dict | list[dict] + report_id: str @to_route("post", "/load_analysis_report") @@ -30,7 +31,7 @@ async def post_load_analysis_report(request: Request): params = PostLoadAnalysisReportParams.model_validate(await request.json()) application = Application(**params.application.model_dump()) - report = Report.load_report_from_file(params.path_to_report) + report = Report(params.report_data, params.report_id) count = request.app["kai_application"].incident_store.load_report( application, report diff --git a/kai/service/incident_store/incident_store.py b/kai/service/incident_store/incident_store.py index 4ce0feb0..6cde43d0 100644 --- a/kai/service/incident_store/incident_store.py +++ b/kai/service/incident_store/incident_store.py @@ -181,6 +181,8 @@ def __init__( self.solution_detector = solution_detector self.solution_producer = solution_producer + self.create_tables() # This is a no-op if the tables already exist + def load_report(self, app: Application, report: Report) -> tuple[int, int, int]: """ Load incidents from a report and given application object. Returns a @@ -273,6 +275,7 @@ def load_report(self, app: Application, report: Report) -> tuple[int, int, int]: session.commit() for incident in violation_obj.incidents: + filtered_vars = filter_incident_vars(incident.variables) report_incidents.append( SQLIncident( violation_name=violation.violation_name, @@ -281,7 +284,7 @@ def load_report(self, app: Application, report: Report) -> tuple[int, int, int]: incident_uri=incident.uri, incident_snip=incident.code_snip, incident_line=incident.line_number, - incident_variables=deep_sort(incident.variables), + incident_variables=deep_sort(filtered_vars), incident_message=incident.message, ) ) @@ -401,9 +404,7 @@ def find_solutions( ) result: list[Solution] = [] - for incident in session.execute( - select_incidents_with_solutions_stmt - ).scalars(): + for incident in session.scalars(select_incidents_with_solutions_stmt).all(): select_accepted_solution_stmt = select(SQLAcceptedSolution).where( SQLAcceptedSolution.solution_id == incident.solution_id ) @@ -415,9 +416,18 @@ def find_solutions( processed_solution = self.solution_producer.post_process_one( incident, accepted_solution.solution ) + + # TODO: This first line doesn't work for some reason. The second + # line is a hack to get around it. accepted_solution.solution = processed_solution + session.query(SQLAcceptedSolution).filter( + SQLAcceptedSolution.solution_id == incident.solution_id + ).update({"solution": processed_solution}) + result.append(processed_solution) + session.commit() + session.commit() return result diff --git a/kai/service/incident_store/sql_types.py b/kai/service/incident_store/sql_types.py index c04d76cf..bb1bf7ac 100644 --- a/kai/service/incident_store/sql_types.py +++ b/kai/service/incident_store/sql_types.py @@ -23,7 +23,7 @@ class SQLSolutionType(TypeDecorator): impl = VARCHAR - cache_ok = True + cache_ok = False def process_bind_param(self, value: Optional[Solution], dialect: Dialect): # Into the db @@ -148,7 +148,7 @@ class SQLIncident(SQLBase): incident_uri: Mapped[str] incident_message: Mapped[str] incident_snip: Mapped[str] - incident_line: Mapped[int] + incident_line: Mapped[int] # 0-indexed! incident_variables: Mapped[dict[str, Any]] solution_id: Mapped[Optional[str]] = mapped_column( ForeignKey("accepted_solutions.solution_id") diff --git a/kai/service/kai_application/kai_application.py b/kai/service/kai_application/kai_application.py index 2815f0d4..eee9e2ef 100644 --- a/kai/service/kai_application/kai_application.py +++ b/kai/service/kai_application/kai_application.py @@ -148,9 +148,10 @@ def get_incident_solutions_for_file( ) if len(solutions) != 0: - pb_incident["solution_str"] = self.solution_consumer( - solutions[0] - ) + solution_str = self.solution_consumer(solutions[0]) + + if len(solution_str) != 0: + pb_incident["solution_str"] = solution_str pb_vars = { "src_file_name": file_name, diff --git a/kai/service/llm_interfacing/model_provider.py b/kai/service/llm_interfacing/model_provider.py index be69fc07..31fd4438 100644 --- a/kai/service/llm_interfacing/model_provider.py +++ b/kai/service/llm_interfacing/model_provider.py @@ -100,7 +100,22 @@ def __init__(self, config: KaiConfigModels): model_class = FakeListChatModel defaults = { - "responses": ["Default LLM response."], + "responses": [ + "## Reasoning\n" + "\n" + "Default reasoning.\n" + "\n" + "## Updated File\n" + "\n" + "```\n" + "Default updated file.\n" + "```\n" + "\n" + "## Additional Information\n" + "\n" + "Default additional information.\n" + "\n" + ], "sleep": None, } diff --git a/kai/service/solution_handling/consumption.py b/kai/service/solution_handling/consumption.py index 48e63482..4569767d 100644 --- a/kai/service/solution_handling/consumption.py +++ b/kai/service/solution_handling/consumption.py @@ -42,6 +42,9 @@ def solution_consumer_before_and_after(solution: Solution) -> str: def solution_consumer_llm_summary(solution: Solution) -> str: + if solution.llm_summary is None: + return "" + return ( __create_jinja_env().get_template("llm_summary.jinja").render(solution=solution) ) diff --git a/kai/service/solution_handling/detection.py b/kai/service/solution_handling/detection.py index 00a068af..e2726a59 100644 --- a/kai/service/solution_handling/detection.py +++ b/kai/service/solution_handling/detection.py @@ -1,10 +1,19 @@ import json +import os +from collections import defaultdict from dataclasses import dataclass -from typing import Callable +from typing import Callable, cast +from urllib.parse import unquote, urlparse +import tree_sitter as ts +import tree_sitter_java from git import Repo +from sequoia_diff import loaders +from sequoia_diff.matching import generate_mappings +from sequoia_diff.models import Node from kai.models.kai_config import SolutionDetectorKind +from kai.models.util import remove_known_prefixes from kai.service.incident_store.sql_types import SQLIncident @@ -35,6 +44,9 @@ class SolutionDetectorResult: def naive_hash(x: SQLIncident) -> int: + """ + Returns a hash of the incident that is used for naive equality checking. + """ return hash( ( x.violation_name, @@ -48,6 +60,13 @@ def naive_hash(x: SQLIncident) -> int: def solution_detection_naive(ctx: SolutionDetectorContext) -> SolutionDetectorResult: + """ + The naive solution detection algorithm is the simplest one. It just checks + if a new incident is exactly the same as an old incident. If it is, then the + incident is *unsolved*. Otherwise, it's new. Any old incidents that are not + matched are considered solved. + """ + result = SolutionDetectorResult([], [], []) updating_set: dict[int, SQLIncident] = {naive_hash(x): x for x in ctx.old_incidents} @@ -65,12 +84,185 @@ def solution_detection_naive(ctx: SolutionDetectorContext) -> SolutionDetectorRe return result +def line_match_hash(x: SQLIncident) -> int: + """ + Returns a hash of the incident that is used for line matching. + """ + return hash( + ( + x.violation_name, + x.ruleset_name, + x.application_name, + x.incident_uri, + json.dumps(x.incident_variables, sort_keys=True), + ) + ) + + +def node_with_tightest_bounds(node: Node, start_byte: int, end_byte: int) -> Node: + """ + Find the node with the tightest bounds that still contains the given byte + range. + """ + + best = node + while True: + best.orig_node = cast(ts.Node, best.orig_node) + another_iteration = False + + for child in best.children: + ts_node = cast(ts.Node, child.orig_node) + + if ( + ts_node.start_byte > start_byte + or ts_node.end_byte < end_byte + or ts_node.start_byte < best.orig_node.start_byte + or ts_node.end_byte > best.orig_node.end_byte + ): + continue + best = child + another_iteration = True + + if not another_iteration: + break + + return best + + def solution_detection_line_match( ctx: SolutionDetectorContext, ) -> SolutionDetectorResult: - # TODO: Implement line matching solution detection with sequoia-diff or - # gumtree - return solution_detection_naive(ctx) + """ + The line match algorithm is trying to find incidents that still exist after + a change by making the assumption that if the code exists in the changed + file somewhere, then the incident is not solved. Note that this should be + irrelevant of line number. + + 1. Filter out the exact matches. + 2. Get a mapping between the two ASTs. (A each item in a mapping is a node + in the source tree and the corresponding node in the destination tree. + The algorithm for that is in the sequoia-diff library). + 3. Get the smallest node that still contains the line under question. + 4. Check if the mapping contains the node. If it does, the incident is + unsolved. If not, it's solved + """ + # TODO: Support multiple languages + ts_language = ts.Language(tree_sitter_java.language()) + parser = ts.Parser(ts_language) + + result = SolutionDetectorResult([], [], []) + # new_incidents = ctx.new_incidents.copy() + new_incidents = [x for x in ctx.new_incidents] + + # Map the old incidents to their hashes for quick equality lookup. + + naive_old_incidents: dict[int, SQLIncident] = { + naive_hash(x): x for x in ctx.old_incidents + } + + # Filter the exact matches + + i = 0 + while i < len(new_incidents): + incident = new_incidents[i] + incident_hash = naive_hash(incident) + + if incident_hash in naive_old_incidents: + result.unsolved.append(naive_old_incidents.pop(incident_hash)) + new_incidents.pop(i) + else: + i += 1 + + # result.unsolved now contains all exact matches and naive_old_incidents + # contains all non-exact matches. + + # Filter incidents whose line numbers match and whose line contents may have + # just been moved. + + # NOTE: We use a multi-map here because there may be more than one match. We + # may also need to switch the naive matching to use a multi-map in the + # future. + + line_match_old_incidents: dict[int, set[SQLIncident]] = defaultdict(set) + for x in naive_old_incidents.values(): + line_match_old_incidents[line_match_hash(x)].add(x) + + # Check each remaining incident in new_incidents + + for incident in new_incidents: + incident_line_match_hash = line_match_hash(incident) + + # Check if the incident is in the remaining old incidents. If not, then + # it's a new incident. + + if incident_line_match_hash not in line_match_old_incidents: + result.new.append(incident) + continue + if len(line_match_old_incidents[incident_line_match_hash]) == 0: + result.new.append(incident) + del line_match_old_incidents[incident_line_match_hash] + continue + + # Construct the trees for the old and new files + + # NOTE: Both file paths should be the same, but just in case we might + # want to use the old file path. + + file_path = os.path.join( + cast(str, ctx.repo.working_tree_dir), + remove_known_prefixes(unquote(urlparse(incident.incident_uri).path)), + ) + + # TODO: See if we should cache these files/trees for performance + + old_file: str = ctx.repo.git.show(f"{ctx.old_commit}:{file_path}") + new_file: str = ctx.repo.git.show(f"{ctx.new_commit}:{file_path}") + + old_tree = parser.parse(bytes(old_file, "utf-8")) + new_tree = parser.parse(bytes(new_file, "utf-8")) + + old_node = loaders.from_tree_sitter_tree(old_tree, "java") + new_node = loaders.from_tree_sitter_tree(new_tree, "java") + + # Get the byte offsets for the incident line + + old_line_start_byte = 0 + old_line_end_byte = 0 + for _ in range(incident.incident_line + 1): + old_line_start_byte = old_file.find("\n", old_line_start_byte) + 1 + old_line_end_byte = old_file.find("\n", old_line_start_byte) + + # Get the node with the tightest bounds + + best = node_with_tightest_bounds( + old_node, old_line_start_byte, old_line_end_byte + ) + + # Get the mappings from old_tree to new_tree + + mappings = generate_mappings(old_node, new_node) + + if best not in mappings.src_to_dst: + result.new.append(incident) + continue + + # NOTE: Right now we're just assuming that if the mapping algorithm + # successfully finds a mapping, then the incident is unsolved. This is a + # very naive approach and should be improved in the future. Some static + # analysis may be required. + + old_incident = line_match_old_incidents[incident_line_match_hash].pop() + old_incident.incident_line = incident.incident_line + result.unsolved.append(old_incident) + + # These are the incidents that weren't matched to any incident in + # new_incidents, meaning they were solved. + + for incident_set in line_match_old_incidents.values(): + if len(incident_set) > 0: + result.solved.extend(incident_set) + + return result def solution_detection_factory( diff --git a/kai/service/solution_handling/production.py b/kai/service/solution_handling/production.py index b2453cdf..34bee2f2 100644 --- a/kai/service/solution_handling/production.py +++ b/kai/service/solution_handling/production.py @@ -65,9 +65,12 @@ class SolutionProducerTextOnly(SolutionProducer): def produce_one( self, incident: SQLIncident, repo: Repo, old_commit: str, new_commit: str ) -> Solution: + local_file_path = remove_known_prefixes( + unquote(urlparse(incident.incident_uri).path) + ) file_path = os.path.join( repo.working_tree_dir, - remove_known_prefixes(unquote(urlparse(incident.incident_uri).path)), + local_file_path, ) # NOTE: `repo_diff` functionality is not implemented @@ -77,7 +80,7 @@ def produce_one( # probably a better way to handle this. try: original_code = ( - repo.git.show(f"{new_commit}:{file_path}") + repo.git.show(f"{old_commit}:{local_file_path}") .encode("utf-8") .decode("utf-8") ) @@ -86,7 +89,7 @@ def produce_one( try: updated_code = ( - repo.git.show(f"{new_commit}:{file_path}") + repo.git.show(f"{new_commit}:{local_file_path}") .encode("utf-8") .decode("utf-8") ) @@ -113,13 +116,12 @@ def post_process_one(self, incident: SQLIncident, solution: Solution) -> Solutio class SolutionProducerLLMLazy(SolutionProducer): def __init__(self, model_provider: ModelProvider): self.model_provider = model_provider + self.text_only = SolutionProducerTextOnly() def produce_one( self, incident: SQLIncident, repo: Repo, old_commit: str, new_commit: str ) -> Solution: - solution = SolutionProducerTextOnly().produce_one( - incident, repo, old_commit, new_commit - ) + solution = self.text_only.produce_one(incident, repo, old_commit, new_commit) solution.llm_summary_generated = False diff --git a/kai/service/solution_handling/test_detection.py b/kai/service/solution_handling/test_detection.py index 7c7975f8..54c36356 100644 --- a/kai/service/solution_handling/test_detection.py +++ b/kai/service/solution_handling/test_detection.py @@ -1,8 +1,17 @@ +import os import unittest +from unittest.mock import MagicMock, create_autospec +import tree_sitter as ts +import yaml +from sequoia_diff.models import Node + +from kai.constants import PATH_TEST_DATA from kai.service.incident_store.sql_types import SQLIncident from kai.service.solution_handling.detection import ( SolutionDetectorContext, + node_with_tightest_bounds, + solution_detection_line_match, solution_detection_naive, ) @@ -68,7 +77,7 @@ def test_naive_simple(self): ] result = solution_detection_naive( - SolutionDetectorContext(db_incidents, report_incidents, None, None, None) + SolutionDetectorContext(db_incidents, report_incidents, MagicMock(), "", "") ) self.assertTrue(len(result.new) == 1) @@ -80,6 +89,160 @@ def test_naive_simple(self): self.assertTrue(report_incidents[0] != result.unsolved[0]) self.assertTrue(report_incidents[1] == result.new[0]) - @unittest.skip("Not implemented") - def test_line_match(self): - pass + def test_line_match_simple(self): + def local_read(*args) -> str: + sections = args[0].split(":") + stuff = sections[0] + file_path = sections[1] + dirname = os.path.dirname(os.path.realpath(file_path)) + basename = os.path.basename(file_path) + + file_name = os.path.join(dirname, f"{stuff}_{basename}") + + with open(file_name, "r") as file: + return file.read() + + def local_yaml(file_path: str) -> dict | list: + with open( + os.path.join( + PATH_TEST_DATA, + "test_detection", + "test_line_match_simple", + file_path, + ), + "r", + ) as file: + return yaml.safe_load(file) + + mock_repo = MagicMock() + mock_repo.git.show.side_effect = local_read + mock_repo.working_tree_dir = os.path.join( + PATH_TEST_DATA, "test_detection", "test_line_match_simple" + ) + + old_commit = "old" + new_commit = "new" + + # No incidents, no changes + + result = solution_detection_line_match( + SolutionDetectorContext([], [], mock_repo, old_commit, new_commit) + ) + + self.assertEqual(result.new, [], "Failed no incidents no changes") + self.assertEqual(result.unsolved, [], "Failed no incidents no changes") + self.assertEqual(result.solved, [], "Failed no incidents no changes") + + # Exact matches of incidents + + old_incidents = [SQLIncident(**x) for x in local_yaml("exact_matches.yaml")] + new_incidents = [SQLIncident(**x) for x in local_yaml("exact_matches.yaml")] + + result = solution_detection_line_match( + SolutionDetectorContext( + old_incidents, new_incidents, mock_repo, old_commit, new_commit + ) + ) + + self.assertEqual(result.new, [], "Failed exact matches") + self.assertEqual(result.unsolved, old_incidents, "Failed exact matches") + self.assertEqual(result.solved, [], "Failed exact matches") + + # Exact matches of incidents, with some new incidents + + old_incidents = [ + SQLIncident(**x) for x in local_yaml("old_added_incidents.yaml") + ] + new_incidents = [ + SQLIncident(**x) for x in local_yaml("new_added_incidents.yaml") + ] + + result = solution_detection_line_match( + SolutionDetectorContext( + old_incidents, new_incidents, mock_repo, old_commit, new_commit + ) + ) + + self.assertEqual(result.new, [new_incidents[1]], "Failed added incidents") + self.assertEqual(result.unsolved, [old_incidents[0]], "Failed added incidents") + self.assertEqual(result.solved, [], "Failed added incidents") + + # Adding whitespace + + old_incidents = [ + SQLIncident(**x) for x in local_yaml("old_added_whitespace.yaml") + ] + new_incidents = [ + SQLIncident(**x) for x in local_yaml("new_added_whitespace.yaml") + ] + + result = solution_detection_line_match( + SolutionDetectorContext( + old_incidents, new_incidents, mock_repo, old_commit, new_commit + ) + ) + + self.assertEqual(result.new, [], "Failed added whitespace") + self.assertEqual(result.unsolved, old_incidents, "Failed added whitespace") + self.assertEqual(result.solved, [], "Failed added whitespace") + + +class TestNodeWithTightestBounds(unittest.TestCase): + def setUp(self): + # Mocking ts.Node + self.mock_ts_node = create_autospec(ts.Node) + + def create_node(self, start_byte, end_byte, children=None) -> Node: + orig_node = self.mock_ts_node() + orig_node.start_byte = start_byte + orig_node.end_byte = end_byte + node = Node( + type="mock_type", label="mock_label", orig_node=orig_node, children=children + ) + return node + + def test_single_node_within_bounds(self): + node = self.create_node(10, 20) + result = node_with_tightest_bounds(node, 12, 18) + self.assertEqual(result, node) + + def test_child_node_within_bounds(self): + parent_node = self.create_node(10, 30) + child_node = self.create_node(15, 25) + parent_node.children_append(child_node) + + result = node_with_tightest_bounds(parent_node, 16, 24) + self.assertEqual(result, child_node) + + def test_no_child_node_within_bounds(self): + parent_node = self.create_node(5, 35) + child_node = self.create_node(10, 30) + parent_node.children_append(child_node) + + result = node_with_tightest_bounds(parent_node, 16, 24) + self.assertEqual(result, child_node) + + def test_multiple_children_one_within_bounds(self): + parent_node = self.create_node(10, 50) + child_node1 = self.create_node(15, 35) + child_node2 = self.create_node(20, 30) + parent_node.children_append(child_node1) + parent_node.children_append(child_node2) + + result = node_with_tightest_bounds(parent_node, 21, 29) + self.assertEqual(result, child_node2) + + def test_multiple_nested_children(self): + parent_node = self.create_node(10, 60) + child_node1 = self.create_node(15, 55) + child_node2 = self.create_node(20, 50) + child_node3 = self.create_node(25, 45) + child_node4 = self.create_node(30, 40) + + parent_node.children_append(child_node1) + child_node1.children_append(child_node2) + child_node2.children_append(child_node3) + child_node3.children_append(child_node4) + + result = node_with_tightest_bounds(parent_node, 31, 39) + self.assertEqual(result, child_node4) diff --git a/requirements.in b/requirements.in index f167bb60..36c871fd 100644 --- a/requirements.in +++ b/requirements.in @@ -1,6 +1,6 @@ # To generate a new requirements.txt: # $ pip install pip-tools -# $ pip-compile --allow-unsafe +# $ pip-compile --allow-unsafe > requirements.txt # To view requirements.txt's dependencies in a tree format: # $ pip install pipdeptree @@ -25,9 +25,11 @@ Jinja2==3.1.4 langchain==0.2.11 langchain-community==0.2.10 langchain-openai==0.1.17 -langchain-experimental==0.0.63 +langchain-experimental==0.0.64 gunicorn==22.0.0 tree-sitter==0.22.3 +tree-sitter-java==0.21.0 +sequoia-diff==0.0.8 # Fabian's fork has changes that fix some async issues in the real vcrpy that # are yet to be accepted vcrpy @ git+https://github.com/fabianvf/vcrpy.git@httpx-async-threadpool diff --git a/requirements.txt b/requirements.txt index 8e723f24..eb11df37 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,8 +20,6 @@ anyio==4.4.0 # httpx # jupyter-server # openai -appnope==0.1.4 - # via ipykernel argon2-cffi==23.1.0 # via jupyter-server argon2-cffi-bindings==21.2.0 @@ -40,12 +38,12 @@ async-lru==2.0.4 # via jupyterlab async-timeout==4.0.3 # via -r requirements.in -attrs==23.2.0 +attrs==24.2.0 # via # aiohttp # jsonschema # referencing -babel==2.15.0 +babel==2.16.0 # via jupyterlab-server beautifulsoup4==4.12.3 # via nbconvert @@ -53,7 +51,7 @@ bleach==6.1.0 # via nbconvert boto3==1.34.157 # via -r requirements.in -botocore==1.34.157 +botocore==1.34.162 # via # boto3 # s3transfer @@ -62,7 +60,7 @@ certifi==2024.7.4 # httpcore # httpx # requests -cffi==1.16.0 +cffi==1.17.0 # via argon2-cffi-bindings charset-normalizer==3.3.2 # via requests @@ -76,7 +74,7 @@ coverage==7.6.0 # via -r requirements.in dataclasses-json==0.6.7 # via langchain-community -debugpy==1.8.2 +debugpy==1.8.5 # via ipykernel decorator==5.1.1 # via ipython @@ -88,7 +86,7 @@ execnb==0.1.6 # via nbdev executing==2.0.1 # via stack-data -fastcore==1.5.54 +fastcore==1.7.1 # via # execnb # ghapi @@ -107,6 +105,8 @@ gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via -r requirements.in +greenlet==3.0.3 + # via sqlalchemy gunicorn==22.0.0 # via -r requirements.in h11==0.14.0 @@ -117,6 +117,7 @@ httpx==0.26.0 # via # ibm-generative-ai # jupyterlab + # langsmith # openai httpx-sse==0.3.1 # via ibm-generative-ai @@ -157,6 +158,8 @@ jinja2==3.1.4 # jupyterlab # jupyterlab-server # nbconvert +jiter==0.5.0 + # via openai jmespath==1.0.1 # via # boto3 @@ -231,20 +234,20 @@ langchain-community==0.2.10 # via # -r requirements.in # langchain-experimental -langchain-core==0.2.23 +langchain-core==0.2.33 # via # langchain # langchain-community # langchain-experimental # langchain-openai # langchain-text-splitters -langchain-experimental==0.0.63 +langchain-experimental==0.0.64 # via -r requirements.in langchain-openai==0.1.17 # via -r requirements.in langchain-text-splitters==0.2.2 # via langchain -langsmith==0.1.93 +langsmith==0.1.100 # via # langchain # langchain-community @@ -255,7 +258,7 @@ markupsafe==2.1.5 # via # jinja2 # nbconvert -marshmallow==3.21.3 +marshmallow==3.22.0 # via dataclasses-json matplotlib-inline==0.1.7 # via @@ -294,9 +297,9 @@ numpy==1.26.4 # via # langchain # langchain-community -openai==1.37.0 +openai==1.41.1 # via langchain-openai -orjson==3.10.6 +orjson==3.10.7 # via langsmith overrides==7.7.0 # via jupyter-server @@ -349,6 +352,7 @@ pydantic==2.8.2 # langsmith # openai # pydantic-settings + # sequoia-diff pydantic-core==2.20.1 # via pydantic pydantic-settings==2.3.4 @@ -381,7 +385,7 @@ pyyaml==6.0.1 # langchain-core # nbdev # vcrpy -pyzmq==26.0.3 +pyzmq==26.1.1 # via # ipykernel # jupyter-client @@ -415,7 +419,7 @@ rfc3986-validator==0.1.1 # via # jsonschema # jupyter-events -rpds-py==0.19.1 +rpds-py==0.20.0 # via # jsonschema # referencing @@ -423,6 +427,8 @@ s3transfer==0.10.2 # via boto3 send2trash==1.8.3 # via jupyter-server +sequoia-diff==0.0.8 + # via -r requirements.in six==1.16.0 # via # asttokens @@ -437,7 +443,7 @@ sniffio==1.3.1 # anyio # httpx # openai -soupsieve==2.5 +soupsieve==2.6 # via beautifulsoup4 sqlalchemy==2.0.22 # via @@ -467,7 +473,7 @@ tornado==6.4.1 # jupyterlab # notebook # terminado -tqdm==4.66.4 +tqdm==4.66.5 # via openai traitlets==5.14.3 # via @@ -487,6 +493,10 @@ traitlets==5.14.3 # nbformat # qtconsole tree-sitter==0.22.3 + # via + # -r requirements.in + # sequoia-diff +tree-sitter-java==0.21.0 # via -r requirements.in typer==0.9.0 # via -r requirements.in @@ -494,6 +504,7 @@ types-python-dateutil==2.9.0.20240316 # via arrow typing-extensions==4.12.2 # via + # langchain-core # openai # pydantic # pydantic-core @@ -513,11 +524,11 @@ urllib3==2.2.2 # vcrpy vcrpy @ git+https://github.com/fabianvf/vcrpy.git@httpx-async-threadpool # via -r requirements.in -watchdog==4.0.1 +watchdog==4.0.2 # via nbdev wcwidth==0.2.13 # via prompt-toolkit -webcolors==24.6.0 +webcolors==24.8.0 # via jsonschema webencodings==0.5.1 # via @@ -525,7 +536,7 @@ webencodings==0.5.1 # tinycss2 websocket-client==1.8.0 # via jupyter-server -wheel==0.43.0 +wheel==0.44.0 # via astunparse widgetsnbextension==4.0.11 # via ipywidgets @@ -539,5 +550,5 @@ yarl==1.9.4 # The following packages are considered to be unsafe in a requirements file: pip==24.2 # via ghapi -setuptools==72.1.0 +setuptools==73.0.1 # via jupyterlab diff --git a/tests/test_data/test_detection/test_line_match_simple/exact_matches.yaml b/tests/test_data/test_detection/test_line_match_simple/exact_matches.yaml new file mode 100644 index 00000000..920c06f5 --- /dev/null +++ b/tests/test_data/test_detection/test_line_match_simple/exact_matches.yaml @@ -0,0 +1,8 @@ +- violation_name: test_violation + ruleset_name: test_ruleset + application_name: test_application + incident_uri: exact_matches.java + incident_message: Test incident message + incident_snip: "" + incident_line: 4 + incident_variables: {} diff --git a/tests/test_data/test_detection/test_line_match_simple/new_added_incidents.java b/tests/test_data/test_detection/test_line_match_simple/new_added_incidents.java new file mode 100644 index 00000000..e6f8823b --- /dev/null +++ b/tests/test_data/test_detection/test_line_match_simple/new_added_incidents.java @@ -0,0 +1,11 @@ +package kai.example; + +class AddedIncidents { + public AddedIncidents() { + System.out.println("No change"); + } + + public void newMethod() { + System.out.println("New method"); + } +} \ No newline at end of file diff --git a/tests/test_data/test_detection/test_line_match_simple/new_added_incidents.yaml b/tests/test_data/test_detection/test_line_match_simple/new_added_incidents.yaml new file mode 100644 index 00000000..b7753d05 --- /dev/null +++ b/tests/test_data/test_detection/test_line_match_simple/new_added_incidents.yaml @@ -0,0 +1,17 @@ +- violation_name: test_violation + ruleset_name: test_ruleset + application_name: test_application + incident_uri: added_incidents.java + incident_message: Test incident message + incident_snip: "" + incident_line: 4 + incident_variables: {} + +- violation_name: test_violation + ruleset_name: test_ruleset + application_name: test_application + incident_uri: added_incidents.java + incident_message: Test incident message + incident_snip: "" + incident_line: 8 + incident_variables: {} diff --git a/tests/test_data/test_detection/test_line_match_simple/new_added_whitespace.java b/tests/test_data/test_detection/test_line_match_simple/new_added_whitespace.java new file mode 100644 index 00000000..04e2f391 --- /dev/null +++ b/tests/test_data/test_detection/test_line_match_simple/new_added_whitespace.java @@ -0,0 +1,18 @@ +package kai.example; + +class NoChange { + + + public NoChange() { + + System.out.println("No change"); + + + + + + } + + + +} \ No newline at end of file diff --git a/tests/test_data/test_detection/test_line_match_simple/new_added_whitespace.yaml b/tests/test_data/test_detection/test_line_match_simple/new_added_whitespace.yaml new file mode 100644 index 00000000..8d2061f4 --- /dev/null +++ b/tests/test_data/test_detection/test_line_match_simple/new_added_whitespace.yaml @@ -0,0 +1,8 @@ +- violation_name: test_violation + ruleset_name: test_ruleset + application_name: test_application + incident_uri: added_whitespace.java + incident_message: Test incident message + incident_snip: "" + incident_line: 8 + incident_variables: {} diff --git a/tests/test_data/test_detection/test_line_match_simple/new_exact_matches.java b/tests/test_data/test_detection/test_line_match_simple/new_exact_matches.java new file mode 100644 index 00000000..f9901034 --- /dev/null +++ b/tests/test_data/test_detection/test_line_match_simple/new_exact_matches.java @@ -0,0 +1,7 @@ +package kai.example; + +class NoChange { + public NoChange() { + System.out.println("No change"); + } +} \ No newline at end of file diff --git a/tests/test_data/test_detection/test_line_match_simple/old_added_incidents.java b/tests/test_data/test_detection/test_line_match_simple/old_added_incidents.java new file mode 100644 index 00000000..4509df33 --- /dev/null +++ b/tests/test_data/test_detection/test_line_match_simple/old_added_incidents.java @@ -0,0 +1,7 @@ +package kai.example; + +class AddedIncidents { + public AddedIncidents() { + System.out.println("No change"); + } +} \ No newline at end of file diff --git a/tests/test_data/test_detection/test_line_match_simple/old_added_incidents.yaml b/tests/test_data/test_detection/test_line_match_simple/old_added_incidents.yaml new file mode 100644 index 00000000..ea91aec1 --- /dev/null +++ b/tests/test_data/test_detection/test_line_match_simple/old_added_incidents.yaml @@ -0,0 +1,8 @@ +- violation_name: test_violation + ruleset_name: test_ruleset + application_name: test_application + incident_uri: added_incidents.java + incident_message: Test incident message + incident_snip: "" + incident_line: 4 + incident_variables: {} diff --git a/tests/test_data/test_detection/test_line_match_simple/old_added_whitespace.java b/tests/test_data/test_detection/test_line_match_simple/old_added_whitespace.java new file mode 100644 index 00000000..f9901034 --- /dev/null +++ b/tests/test_data/test_detection/test_line_match_simple/old_added_whitespace.java @@ -0,0 +1,7 @@ +package kai.example; + +class NoChange { + public NoChange() { + System.out.println("No change"); + } +} \ No newline at end of file diff --git a/tests/test_data/test_detection/test_line_match_simple/old_added_whitespace.yaml b/tests/test_data/test_detection/test_line_match_simple/old_added_whitespace.yaml new file mode 100644 index 00000000..7cd780a4 --- /dev/null +++ b/tests/test_data/test_detection/test_line_match_simple/old_added_whitespace.yaml @@ -0,0 +1,8 @@ +- violation_name: test_violation + ruleset_name: test_ruleset + application_name: test_application + incident_uri: added_whitespace.java + incident_message: Test incident message + incident_snip: "" + incident_line: 4 + incident_variables: {} diff --git a/tests/test_data/test_detection/test_line_match_simple/old_exact_matches.java b/tests/test_data/test_detection/test_line_match_simple/old_exact_matches.java new file mode 100644 index 00000000..f9901034 --- /dev/null +++ b/tests/test_data/test_detection/test_line_match_simple/old_exact_matches.java @@ -0,0 +1,7 @@ +package kai.example; + +class NoChange { + public NoChange() { + System.out.println("No change"); + } +} \ No newline at end of file diff --git a/tests/util.py b/tests/util.py new file mode 100644 index 00000000..e51643b8 --- /dev/null +++ b/tests/util.py @@ -0,0 +1,38 @@ +from typing import Any + + +class Wrapper: + """ + Wrapper class to intercept and log all attribute access and method calls on + an object. + """ + + def __init__(self, obj: Any): + self.obj: Any = obj + self.callable_results: list = [] + + def __getattr__(self, attr: Any): + print(f"Getting {type(self.obj).__name__}.{attr}") + + result = getattr(self.obj, attr) + if callable(result): + return self.CallableWrapper(self, result) + + return result + + class CallableWrapper: + def __init__(self, parent: "Wrapper", callable: Any): + self.parent = parent + self.callable = callable + + def __call__(self, *args, **kwargs): + print(f"Calling {type(self.parent.obj).__name__}.{self.callable.__name__}") + + for i, arg in enumerate(args): + print(f" arg {i}: {arg}") + for key, value in kwargs.items(): + print(f" {key}: {value}") + + result = self.callable(*args, **kwargs) + self.parent.callable_results.append(result) + return result