Skip to content

Commit

Permalink
perf: optimise session items data structures (#403)
Browse files Browse the repository at this point in the history
* refactor: optimise session data structures

* refactor: type hints for pytest location class

* refactor: reuse ran items generator
  • Loading branch information
iamogbz committed Oct 30, 2020
1 parent 5c77f44 commit 818d405
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 35 deletions.
14 changes: 4 additions & 10 deletions src/syrupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,15 @@ def pytest_collection_modifyitems(
After tests are collected and before any modification is performed.
https://docs.pytest.org/en/latest/reference.html#_pytest.hookspec.pytest_collection_modifyitems
"""
for item in config._syrupy.filter_valid_items(items):
config._syrupy._all_items[item] = True
config._syrupy.collect_items(items)


def pytest_collection_finish(session: Any) -> None:
"""
After collection has been performed and modified.
https://docs.pytest.org/en/latest/reference.html#_pytest.hookspec.pytest_collection_finish
"""
for item in session.config._syrupy.filter_valid_items(session.items):
session.config._syrupy._ran_items[item] = False
session.config._syrupy.select_items(session.items)


def pytest_runtest_logfinish(nodeid: str) -> None:
Expand All @@ -146,12 +144,8 @@ def pytest_runtest_logfinish(nodeid: str) -> None:
https://docs.pytest.org/en/latest/reference.html#_pytest.hookspec.pytest_runtest_logfinish
"""
global _syrupy
if not _syrupy:
return
for item in _syrupy._ran_items:
if getattr(item, "nodeid", None) == nodeid:
_syrupy._ran_items[item] = True
return
if _syrupy:
_syrupy.ran_item(nodeid)


def pytest_sessionfinish(session: Any, exitstatus: int) -> None:
Expand Down
25 changes: 18 additions & 7 deletions src/syrupy/location.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
from pathlib import Path
from typing import (
Any,
Iterator,
Optional,
)

import attr
import pytest

from syrupy.constants import PYTEST_NODE_SEP


@attr.s
class PyTestLocation:
def __init__(self, node: Any):
self._node = node
self.filepath = self._node.fspath
self.modulename = self._node.obj.__module__
self.methodname = self._node.obj.__name__
_node: "pytest.Item" = attr.ib()
nodename: str = attr.ib(init=False)
testname: str = attr.ib(init=False)
methodname: str = attr.ib(init=False)
modulename: str = attr.ib(init=False)
filepath: str = attr.ib(init=False)

def __attrs_post_init__(self) -> None:
self.filepath = getattr(self._node, "fspath", None)
obj = getattr(self._node, "obj", None)
self.modulename = obj.__module__
self.methodname = obj.__name__
self.nodename = getattr(self._node, "name", None)
self.testname = self.nodename or self.methodname

Expand All @@ -23,7 +33,8 @@ def classname(self) -> Optional[str]:
Pytest node names contain file path and module members delimited by `::`
Example tests/grouping/test_file.py::TestClass::TestSubClass::test_method
"""
return ".".join(self._node.nodeid.split(PYTEST_NODE_SEP)[1:-1]) or None
nodeid: str = getattr(self._node, "nodeid", None)
return ".".join(nodeid.split(PYTEST_NODE_SEP)[1:-1]) or None

@property
def filename(self) -> str:
Expand Down
36 changes: 24 additions & 12 deletions src/syrupy/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ class SnapshotReport:
"""

base_dir: str = attr.ib()
all_items: Dict["pytest.Item", bool] = attr.ib()
ran_items: Dict["pytest.Item", bool] = attr.ib()
collected_items: Set["pytest.Item"] = attr.ib()
selected_items: Dict[str, bool] = attr.ib()
update_snapshots: bool = attr.ib()
warn_unused_snapshots: bool = attr.ib()
assertions: List["SnapshotAssertion"] = attr.ib()
Expand All @@ -63,9 +63,15 @@ class SnapshotReport:
_invocation_args: Tuple[str, ...] = attr.ib(factory=tuple)
_provided_test_paths: Dict[str, List[str]] = attr.ib(factory=dict)
_keyword_expressions: Set["Expression"] = attr.ib(factory=set)
_collected_items_by_nodeid: Dict[str, "pytest.Item"] = attr.ib(
factory=dict, init=False
)

def __attrs_post_init__(self) -> None:
self.__parse_invocation_args()
self._collected_items_by_nodeid = {
getattr(item, "nodeid", None): item for item in self.collected_items
}
for assertion in self.assertions:
self.discovered.merge(assertion.extension.discover_snapshots())
for result in assertion.executions.values():
Expand Down Expand Up @@ -154,8 +160,16 @@ def num_unused(self) -> int:
return self._count_snapshots(self.unused)

@property
def ran_all_collected_tests(self) -> bool:
return self.all_items == self.ran_items
def selected_all_collected_items(self) -> bool:
return self._collected_items_by_nodeid.keys() == self.selected_items.keys()

@property
def ran_items(self) -> Iterator["pytest.Item"]:
return (
self._collected_items_by_nodeid[nodeid]
for nodeid in self.selected_items
if self.selected_items[nodeid]
)

@property
def unused(self) -> "SnapshotFossils":
Expand All @@ -172,15 +186,15 @@ def unused(self) -> "SnapshotFossils":
self.discovered, self.used
):
snapshot_location = unused_snapshot_fossil.location
if self._provided_test_paths and not self._selected_items_match_location(
if self._provided_test_paths and not self._ran_items_match_location(
snapshot_location
):
# Paths/Packages were provided to pytest and the snapshot location
# does not match therefore ignore this unused snapshot fossil file
# Paths/Packages were provided to pytest and the snapshot location does
# not match any of ran tests therefore ignore this unused snapshot file
continue

provided_nodes = self._get_matching_path_nodes(snapshot_location)
if self.ran_all_collected_tests and not any(provided_nodes):
if self.selected_all_collected_items and not any(provided_nodes):
# All collected tests were run and files were not filtered by ::node
# therefore the snapshot fossil file at this location can be deleted
unused_snapshots = {*unused_snapshot_fossil}
Expand Down Expand Up @@ -358,7 +372,6 @@ def _ran_items_match_name(self, snapshot_name: str) -> bool:
return any(
PyTestLocation(item).matches_snapshot_name(snapshot_name)
for item in self.ran_items
if self.ran_items[item]
)

def _selected_items_match_name(self, snapshot_name: str) -> bool:
Expand All @@ -370,16 +383,15 @@ def _selected_items_match_name(self, snapshot_name: str) -> bool:
return self._provided_keywords_match_name(snapshot_name)
return self._ran_items_match_name(snapshot_name)

def _selected_items_match_location(self, snapshot_location: str) -> bool:
def _ran_items_match_location(self, snapshot_location: str) -> bool:
"""
Check that a snapshot fossil location should is selected by the current session
Check if any test run in the current session should match the snapshot location
This being true means that if no snapshot in the fossil was used then it should
be discarded as obsolete
"""
return any(
PyTestLocation(item).matches_snapshot_location(snapshot_location)
for item in self.ran_items
if self.ran_items[item]
)


Expand Down
23 changes: 17 additions & 6 deletions src/syrupy/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Iterable,
List,
Optional,
Set,
Tuple,
)

Expand All @@ -28,25 +29,35 @@ class SnapshotSession:
_invocation_args: Tuple[str, ...] = attr.ib(factory=tuple)
report: Optional["SnapshotReport"] = attr.ib(default=None)
# All the collected test items
_all_items: Dict["pytest.Item", bool] = attr.ib(factory=dict)
_collected_items: Set["pytest.Item"] = attr.ib(factory=set)
# All the selected test items. Will be set to False until the test item is run.
_ran_items: Dict["pytest.Item", bool] = attr.ib(factory=dict)
_selected_items: Dict[str, bool] = attr.ib(factory=dict)
_assertions: List["SnapshotAssertion"] = attr.ib(factory=list)
_extensions: Dict[str, "AbstractSyrupyExtension"] = attr.ib(factory=dict)

def collect_items(self, items: List["pytest.Item"]) -> None:
self._collected_items.update(self.filter_valid_items(items))

def select_items(self, items: List["pytest.Item"]) -> None:
for item in self.filter_valid_items(items):
self._selected_items[getattr(item, "nodeid", None)] = False

def start(self) -> None:
self.report = None
self._all_items = {}
self._ran_items = {}
self._collected_items = set()
self._selected_items = {}
self._assertions = []
self._extensions = {}

def ran_item(self, nodeid: str) -> None:
self._selected_items[nodeid] = True

def finish(self) -> int:
exitstatus = 0
self.report = SnapshotReport(
base_dir=self.base_dir,
all_items=self._all_items,
ran_items=self._ran_items,
collected_items=self._collected_items,
selected_items=self._selected_items,
assertions=self._assertions,
update_snapshots=self.update_snapshots,
warn_unused_snapshots=self.warn_unused_snapshots,
Expand Down

0 comments on commit 818d405

Please sign in to comment.