Skip to content

Commit

Permalink
perf: cache session snapshot extension discovery
Browse files Browse the repository at this point in the history
  • Loading branch information
Noah Negin-Ulster committed Aug 20, 2021
1 parent c01eb54 commit 238885c
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions src/syrupy/session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from collections import defaultdict
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
DefaultDict,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -34,6 +36,10 @@ class SnapshotSession:
_assertions: List["SnapshotAssertion"] = attr.ib(factory=list)
_extensions: Dict[str, "AbstractSyrupyExtension"] = attr.ib(factory=dict)

_locations_discovered: DefaultDict[str, Set[Any]] = attr.ib(
factory=lambda: defaultdict(set)
)

@property
def update_snapshots(self) -> bool:
return bool(self._pytest_session.config.option.update_snapshots)
Expand All @@ -55,6 +61,7 @@ def start(self) -> None:
self._selected_items = {}
self._assertions = []
self._extensions = {}
self._locations_discovered = defaultdict(set)

def ran_item(self, nodeid: str) -> None:
self._selected_items[nodeid] = True
Expand All @@ -80,12 +87,17 @@ def finish(self) -> int:

def register_request(self, assertion: "SnapshotAssertion") -> None:
self._assertions.append(assertion)
discovered_extensions = {
discovered.location: assertion.extension
for discovered in assertion.extension.discover_snapshots()
if discovered.has_snapshots
}
self._extensions.update(discovered_extensions)

test_location = assertion.extension.test_location.filepath
extension_class = assertion.extension.__class__
if extension_class not in self._locations_discovered[test_location]:
self._locations_discovered[test_location].add(extension_class)
discovered_extensions = {
discovered.location: assertion.extension
for discovered in assertion.extension.discover_snapshots()
if discovered.has_snapshots
}
self._extensions.update(discovered_extensions)

def remove_unused_snapshots(
self,
Expand Down

0 comments on commit 238885c

Please sign in to comment.