From 818d405a85c2f1f5db9d673e632677c10cb52ad9 Mon Sep 17 00:00:00 2001 From: Emmanuel Ogbizi Date: Fri, 30 Oct 2020 16:03:28 -0400 Subject: [PATCH] perf: optimise session items data structures (#403) * refactor: optimise session data structures * refactor: type hints for pytest location class * refactor: reuse ran items generator --- src/syrupy/__init__.py | 14 ++++---------- src/syrupy/location.py | 25 ++++++++++++++++++------- src/syrupy/report.py | 36 ++++++++++++++++++++++++------------ src/syrupy/session.py | 23 +++++++++++++++++------ 4 files changed, 63 insertions(+), 35 deletions(-) diff --git a/src/syrupy/__init__.py b/src/syrupy/__init__.py index c80d7eb4..9a3346b2 100644 --- a/src/syrupy/__init__.py +++ b/src/syrupy/__init__.py @@ -127,8 +127,7 @@ def pytest_collection_modifyitems( After tests are collected and before any modification is performed. https://docs.pytest.org/en/latest/reference.html#_pytest.hookspec.pytest_collection_modifyitems """ - for item in config._syrupy.filter_valid_items(items): - config._syrupy._all_items[item] = True + config._syrupy.collect_items(items) def pytest_collection_finish(session: Any) -> None: @@ -136,8 +135,7 @@ def pytest_collection_finish(session: Any) -> None: After collection has been performed and modified. https://docs.pytest.org/en/latest/reference.html#_pytest.hookspec.pytest_collection_finish """ - for item in session.config._syrupy.filter_valid_items(session.items): - session.config._syrupy._ran_items[item] = False + session.config._syrupy.select_items(session.items) def pytest_runtest_logfinish(nodeid: str) -> None: @@ -146,12 +144,8 @@ def pytest_runtest_logfinish(nodeid: str) -> None: https://docs.pytest.org/en/latest/reference.html#_pytest.hookspec.pytest_runtest_logfinish """ global _syrupy - if not _syrupy: - return - for item in _syrupy._ran_items: - if getattr(item, "nodeid", None) == nodeid: - _syrupy._ran_items[item] = True - return + if _syrupy: + _syrupy.ran_item(nodeid) def pytest_sessionfinish(session: Any, exitstatus: int) -> None: diff --git a/src/syrupy/location.py b/src/syrupy/location.py index 4102d841..edb22e6c 100644 --- a/src/syrupy/location.py +++ b/src/syrupy/location.py @@ -1,19 +1,29 @@ from pathlib import Path from typing import ( - Any, Iterator, Optional, ) +import attr +import pytest + from syrupy.constants import PYTEST_NODE_SEP +@attr.s class PyTestLocation: - def __init__(self, node: Any): - self._node = node - self.filepath = self._node.fspath - self.modulename = self._node.obj.__module__ - self.methodname = self._node.obj.__name__ + _node: "pytest.Item" = attr.ib() + nodename: str = attr.ib(init=False) + testname: str = attr.ib(init=False) + methodname: str = attr.ib(init=False) + modulename: str = attr.ib(init=False) + filepath: str = attr.ib(init=False) + + def __attrs_post_init__(self) -> None: + self.filepath = getattr(self._node, "fspath", None) + obj = getattr(self._node, "obj", None) + self.modulename = obj.__module__ + self.methodname = obj.__name__ self.nodename = getattr(self._node, "name", None) self.testname = self.nodename or self.methodname @@ -23,7 +33,8 @@ def classname(self) -> Optional[str]: Pytest node names contain file path and module members delimited by `::` Example tests/grouping/test_file.py::TestClass::TestSubClass::test_method """ - return ".".join(self._node.nodeid.split(PYTEST_NODE_SEP)[1:-1]) or None + nodeid: str = getattr(self._node, "nodeid", None) + return ".".join(nodeid.split(PYTEST_NODE_SEP)[1:-1]) or None @property def filename(self) -> str: diff --git a/src/syrupy/report.py b/src/syrupy/report.py index 9586efe9..0a78ae03 100644 --- a/src/syrupy/report.py +++ b/src/syrupy/report.py @@ -49,8 +49,8 @@ class SnapshotReport: """ base_dir: str = attr.ib() - all_items: Dict["pytest.Item", bool] = attr.ib() - ran_items: Dict["pytest.Item", bool] = attr.ib() + collected_items: Set["pytest.Item"] = attr.ib() + selected_items: Dict[str, bool] = attr.ib() update_snapshots: bool = attr.ib() warn_unused_snapshots: bool = attr.ib() assertions: List["SnapshotAssertion"] = attr.ib() @@ -63,9 +63,15 @@ class SnapshotReport: _invocation_args: Tuple[str, ...] = attr.ib(factory=tuple) _provided_test_paths: Dict[str, List[str]] = attr.ib(factory=dict) _keyword_expressions: Set["Expression"] = attr.ib(factory=set) + _collected_items_by_nodeid: Dict[str, "pytest.Item"] = attr.ib( + factory=dict, init=False + ) def __attrs_post_init__(self) -> None: self.__parse_invocation_args() + self._collected_items_by_nodeid = { + getattr(item, "nodeid", None): item for item in self.collected_items + } for assertion in self.assertions: self.discovered.merge(assertion.extension.discover_snapshots()) for result in assertion.executions.values(): @@ -154,8 +160,16 @@ def num_unused(self) -> int: return self._count_snapshots(self.unused) @property - def ran_all_collected_tests(self) -> bool: - return self.all_items == self.ran_items + def selected_all_collected_items(self) -> bool: + return self._collected_items_by_nodeid.keys() == self.selected_items.keys() + + @property + def ran_items(self) -> Iterator["pytest.Item"]: + return ( + self._collected_items_by_nodeid[nodeid] + for nodeid in self.selected_items + if self.selected_items[nodeid] + ) @property def unused(self) -> "SnapshotFossils": @@ -172,15 +186,15 @@ def unused(self) -> "SnapshotFossils": self.discovered, self.used ): snapshot_location = unused_snapshot_fossil.location - if self._provided_test_paths and not self._selected_items_match_location( + if self._provided_test_paths and not self._ran_items_match_location( snapshot_location ): - # Paths/Packages were provided to pytest and the snapshot location - # does not match therefore ignore this unused snapshot fossil file + # Paths/Packages were provided to pytest and the snapshot location does + # not match any of ran tests therefore ignore this unused snapshot file continue provided_nodes = self._get_matching_path_nodes(snapshot_location) - if self.ran_all_collected_tests and not any(provided_nodes): + if self.selected_all_collected_items and not any(provided_nodes): # All collected tests were run and files were not filtered by ::node # therefore the snapshot fossil file at this location can be deleted unused_snapshots = {*unused_snapshot_fossil} @@ -358,7 +372,6 @@ def _ran_items_match_name(self, snapshot_name: str) -> bool: return any( PyTestLocation(item).matches_snapshot_name(snapshot_name) for item in self.ran_items - if self.ran_items[item] ) def _selected_items_match_name(self, snapshot_name: str) -> bool: @@ -370,16 +383,15 @@ def _selected_items_match_name(self, snapshot_name: str) -> bool: return self._provided_keywords_match_name(snapshot_name) return self._ran_items_match_name(snapshot_name) - def _selected_items_match_location(self, snapshot_location: str) -> bool: + def _ran_items_match_location(self, snapshot_location: str) -> bool: """ - Check that a snapshot fossil location should is selected by the current session + Check if any test run in the current session should match the snapshot location This being true means that if no snapshot in the fossil was used then it should be discarded as obsolete """ return any( PyTestLocation(item).matches_snapshot_location(snapshot_location) for item in self.ran_items - if self.ran_items[item] ) diff --git a/src/syrupy/session.py b/src/syrupy/session.py index a8cde5d3..af4ad649 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -5,6 +5,7 @@ Iterable, List, Optional, + Set, Tuple, ) @@ -28,25 +29,35 @@ class SnapshotSession: _invocation_args: Tuple[str, ...] = attr.ib(factory=tuple) report: Optional["SnapshotReport"] = attr.ib(default=None) # All the collected test items - _all_items: Dict["pytest.Item", bool] = attr.ib(factory=dict) + _collected_items: Set["pytest.Item"] = attr.ib(factory=set) # All the selected test items. Will be set to False until the test item is run. - _ran_items: Dict["pytest.Item", bool] = attr.ib(factory=dict) + _selected_items: Dict[str, bool] = attr.ib(factory=dict) _assertions: List["SnapshotAssertion"] = attr.ib(factory=list) _extensions: Dict[str, "AbstractSyrupyExtension"] = attr.ib(factory=dict) + def collect_items(self, items: List["pytest.Item"]) -> None: + self._collected_items.update(self.filter_valid_items(items)) + + def select_items(self, items: List["pytest.Item"]) -> None: + for item in self.filter_valid_items(items): + self._selected_items[getattr(item, "nodeid", None)] = False + def start(self) -> None: self.report = None - self._all_items = {} - self._ran_items = {} + self._collected_items = set() + self._selected_items = {} self._assertions = [] self._extensions = {} + def ran_item(self, nodeid: str) -> None: + self._selected_items[nodeid] = True + def finish(self) -> int: exitstatus = 0 self.report = SnapshotReport( base_dir=self.base_dir, - all_items=self._all_items, - ran_items=self._ran_items, + collected_items=self._collected_items, + selected_items=self._selected_items, assertions=self._assertions, update_snapshots=self.update_snapshots, warn_unused_snapshots=self.warn_unused_snapshots,