Skip to content

Commit

Permalink
refactor: rename Fossil to Collection
Browse files Browse the repository at this point in the history
BREAKING CHANGE: The term 'fossil' has been replaced by the clearer term 'collection'.
  • Loading branch information
Noah Negin-Ulster committed Dec 1, 2022
1 parent bbc7ef2 commit 7413123
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 154 deletions.
4 changes: 2 additions & 2 deletions src/syrupy/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
SNAPSHOT_DIRNAME = "__snapshots__"
SNAPSHOT_EMPTY_FOSSIL_KEY = "empty snapshot fossil"
SNAPSHOT_UNKNOWN_FOSSIL_KEY = "unknown snapshot fossil"
SNAPSHOT_EMPTY_FOSSIL_KEY = "empty snapshot collection"
SNAPSHOT_UNKNOWN_FOSSIL_KEY = "unknown snapshot collection"

EXIT_STATUS_FAIL_UNUSED = 1

Expand Down
52 changes: 26 additions & 26 deletions src/syrupy/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class SnapshotUnknown(Snapshot):


@dataclass
class SnapshotFossil:
class SnapshotCollection:
"""A collection of snapshots at a save location"""

location: str
Expand All @@ -54,8 +54,8 @@ def add(self, snapshot: "Snapshot") -> None:
if snapshot.name != SNAPSHOT_EMPTY_FOSSIL_KEY:
self.remove(SNAPSHOT_EMPTY_FOSSIL_KEY)

def merge(self, snapshot_fossil: "SnapshotFossil") -> None:
for snapshot in snapshot_fossil:
def merge(self, snapshot_collection: "SnapshotCollection") -> None:
for snapshot in snapshot_collection:
self.add(snapshot)

def remove(self, snapshot_name: str) -> None:
Expand All @@ -69,8 +69,8 @@ def __iter__(self) -> Iterator["Snapshot"]:


@dataclass
class SnapshotEmptyFossil(SnapshotFossil):
"""This is a saved fossil that is known to be empty and thus can be removed"""
class SnapshotEmptyCollection(SnapshotCollection):
"""This is a saved collection that is known to be empty and thus can be removed"""

_snapshots: Dict[str, "Snapshot"] = field(
default_factory=lambda: {SnapshotEmpty().name: SnapshotEmpty()}
Expand All @@ -82,42 +82,42 @@ def has_snapshots(self) -> bool:


@dataclass
class SnapshotUnknownFossil(SnapshotFossil):
"""This is a saved fossil that is unclaimed by any extension currently in use"""
class SnapshotUnknownCollection(SnapshotCollection):
"""This is a saved collection that is unclaimed by any extension currently in use"""

_snapshots: Dict[str, "Snapshot"] = field(
default_factory=lambda: {SnapshotUnknown().name: SnapshotUnknown()}
)


@dataclass
class SnapshotFossils:
_snapshot_fossils: Dict[str, "SnapshotFossil"] = field(default_factory=dict)
class SnapshotCollections:
_snapshot_collections: Dict[str, "SnapshotCollection"] = field(default_factory=dict)

def get(self, location: str) -> Optional["SnapshotFossil"]:
return self._snapshot_fossils.get(location)
def get(self, location: str) -> Optional["SnapshotCollection"]:
return self._snapshot_collections.get(location)

def add(self, snapshot_fossil: "SnapshotFossil") -> None:
self._snapshot_fossils[snapshot_fossil.location] = snapshot_fossil
def add(self, snapshot_collection: "SnapshotCollection") -> None:
self._snapshot_collections[snapshot_collection.location] = snapshot_collection

def update(self, snapshot_fossil: "SnapshotFossil") -> None:
snapshot_fossil_to_update = self.get(snapshot_fossil.location)
if snapshot_fossil_to_update is None:
snapshot_fossil_to_update = SnapshotFossil(
location=snapshot_fossil.location
def update(self, snapshot_collection: "SnapshotCollection") -> None:
snapshot_collection_to_update = self.get(snapshot_collection.location)
if snapshot_collection_to_update is None:
snapshot_collection_to_update = SnapshotCollection(
location=snapshot_collection.location
)
self.add(snapshot_fossil_to_update)
snapshot_fossil_to_update.merge(snapshot_fossil)
self.add(snapshot_collection_to_update)
snapshot_collection_to_update.merge(snapshot_collection)

def merge(self, snapshot_fossils: "SnapshotFossils") -> None:
for snapshot_fossil in snapshot_fossils:
self.update(snapshot_fossil)
def merge(self, snapshot_collections: "SnapshotCollections") -> None:
for snapshot_collection in snapshot_collections:
self.update(snapshot_collection)

def __iter__(self) -> Iterator["SnapshotFossil"]:
return iter(self._snapshot_fossils.values())
def __iter__(self) -> Iterator["SnapshotCollection"]:
return iter(self._snapshot_collections.values())

def __contains__(self, key: str) -> bool:
return key in self._snapshot_fossils
return key in self._snapshot_collections


@dataclass
Expand Down
20 changes: 11 additions & 9 deletions src/syrupy/extensions/amber/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Set,
)

from syrupy.data import SnapshotFossil
from syrupy.data import SnapshotCollection
from syrupy.extensions.base import AbstractSyrupyExtension

from .serializer import DataSerializer
Expand All @@ -33,23 +33,23 @@ def serialize(self, data: "SerializableData", **kwargs: Any) -> str:
def delete_snapshots(
self, snapshot_location: str, snapshot_names: Set[str]
) -> None:
snapshot_fossil_to_update = DataSerializer.read_file(snapshot_location)
snapshot_collection_to_update = DataSerializer.read_file(snapshot_location)
for snapshot_name in snapshot_names:
snapshot_fossil_to_update.remove(snapshot_name)
snapshot_collection_to_update.remove(snapshot_name)

if snapshot_fossil_to_update.has_snapshots:
DataSerializer.write_file(snapshot_fossil_to_update)
if snapshot_collection_to_update.has_snapshots:
DataSerializer.write_file(snapshot_collection_to_update)
else:
Path(snapshot_location).unlink()

def _read_snapshot_fossil(self, snapshot_location: str) -> "SnapshotFossil":
def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection":
return DataSerializer.read_file(snapshot_location)

@staticmethod
@lru_cache()
def __cacheable_read_snapshot(
snapshot_location: str, cache_key: str
) -> "SnapshotFossil":
) -> "SnapshotCollection":
return DataSerializer.read_file(snapshot_location)

def _read_snapshot_data_from_location(
Expand All @@ -61,8 +61,10 @@ def _read_snapshot_data_from_location(
snapshot = snapshots.get(snapshot_name)
return snapshot.data if snapshot else None

def _write_snapshot_fossil(self, *, snapshot_fossil: "SnapshotFossil") -> None:
DataSerializer.write_file(snapshot_fossil, merge=True)
def _write_snapshot_collection(
self, *, snapshot_collection: "SnapshotCollection"
) -> None:
DataSerializer.write_file(snapshot_collection, merge=True)


__all__ = ["AmberSnapshotExtension", "DataSerializer"]
22 changes: 12 additions & 10 deletions src/syrupy/extensions/amber/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from syrupy.data import (
Snapshot,
SnapshotFossil,
SnapshotCollection,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -70,18 +70,20 @@ class DataSerializer:
_marker_crn: str = "\r\n"

@classmethod
def write_file(cls, snapshot_fossil: "SnapshotFossil", merge: bool = False) -> None:
def write_file(
cls, snapshot_collection: "SnapshotCollection", merge: bool = False
) -> None:
"""
Writes the snapshot data into the snapshot file that can be read later.
"""
filepath = snapshot_fossil.location
filepath = snapshot_collection.location
if merge:
base_snapshot = cls.read_file(filepath)
base_snapshot.merge(snapshot_fossil)
snapshot_fossil = base_snapshot
base_snapshot.merge(snapshot_collection)
snapshot_collection = base_snapshot

with open(filepath, "w", encoding=TEXT_ENCODING, newline=None) as f:
for snapshot in sorted(snapshot_fossil, key=lambda s: s.name):
for snapshot in sorted(snapshot_collection, key=lambda s: s.name):
snapshot_data = str(snapshot.data)
if snapshot_data is not None:
f.write(f"{cls._marker_name} {snapshot.name}\n")
Expand All @@ -90,15 +92,15 @@ def write_file(cls, snapshot_fossil: "SnapshotFossil", merge: bool = False) -> N
f.write(f"\n{cls._marker_divider}\n")

@classmethod
def read_file(cls, filepath: str) -> "SnapshotFossil":
def read_file(cls, filepath: str) -> "SnapshotCollection":
"""
Read the raw snapshot data (str) from the snapshot file into a dict
of snapshot name to raw data. This does not attempt any deserialization
of the snapshot data.
"""
name_marker_len = len(cls._marker_name)
indent_len = len(cls._indent)
snapshot_fossil = SnapshotFossil(location=filepath)
snapshot_collection = SnapshotCollection(location=filepath)
try:
with open(filepath, "r", encoding=TEXT_ENCODING, newline=None) as f:
test_name = None
Expand All @@ -112,7 +114,7 @@ def read_file(cls, filepath: str) -> "SnapshotFossil":
if line.startswith(cls._indent):
snapshot_data += line[indent_len:]
elif line.startswith(cls._marker_divider) and snapshot_data:
snapshot_fossil.add(
snapshot_collection.add(
Snapshot(
name=test_name,
data=snapshot_data.rstrip(os.linesep),
Expand All @@ -121,7 +123,7 @@ def read_file(cls, filepath: str) -> "SnapshotFossil":
except FileNotFoundError:
pass

return snapshot_fossil
return snapshot_collection

@classmethod
def serialize(
Expand Down
50 changes: 28 additions & 22 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
from syrupy.data import (
DiffedLine,
Snapshot,
SnapshotEmptyFossil,
SnapshotFossil,
SnapshotFossils,
SnapshotCollection,
SnapshotCollections,
SnapshotEmptyCollection,
)
from syrupy.exceptions import SnapshotDoesNotExist
from syrupy.terminal import (
Expand Down Expand Up @@ -76,7 +76,7 @@ def serialize(
raise NotImplementedError


class SnapshotFossilizer(ABC):
class SnapshotCollectionizer(ABC):
_file_extension = ""

@property
Expand Down Expand Up @@ -106,20 +106,22 @@ def is_snapshot_location(self, *, location: str) -> bool:
"""Checks if supplied location is valid for this snapshot extension"""
return location.endswith(self._file_extension)

def discover_snapshots(self) -> "SnapshotFossils":
def discover_snapshots(self) -> "SnapshotCollections":
"""
Returns all snapshot fossils in test site
Returns all snapshot collections in test site
"""
discovered: "SnapshotFossils" = SnapshotFossils()
discovered: "SnapshotCollections" = SnapshotCollections()
for filepath in walk_snapshot_dir(self._dirname):
if self.is_snapshot_location(location=filepath):
snapshot_fossil = self._read_snapshot_fossil(snapshot_location=filepath)
if not snapshot_fossil.has_snapshots:
snapshot_fossil = SnapshotEmptyFossil(location=filepath)
snapshot_collection = self._read_snapshot_collection(
snapshot_location=filepath
)
if not snapshot_collection.has_snapshots:
snapshot_collection = SnapshotEmptyCollection(location=filepath)
else:
snapshot_fossil = SnapshotFossil(location=filepath)
snapshot_collection = SnapshotCollection(location=filepath)

discovered.add(snapshot_fossil)
discovered.add(snapshot_collection)

return discovered

Expand All @@ -146,7 +148,7 @@ def read_snapshot(
def write_snapshot(self, *, data: "SerializedData", index: "SnapshotIndex") -> None:
"""
This method is _final_, do not override. You can override
`_write_snapshot_fossil` in a subclass to change behaviour.
`_write_snapshot_collection` in a subclass to change behaviour.
"""
self.write_snapshot_batch(snapshots=[(data, index)])

Expand All @@ -155,7 +157,7 @@ def write_snapshot_batch(
) -> None:
"""
This method is _final_, do not override. You can override
`_write_snapshot_fossil` in a subclass to change behaviour.
`_write_snapshot_collection` in a subclass to change behaviour.
"""
# First we group by location since it'll let us batch by file on disk.
# Not as useful for single file snapshots, but useful for the standard
Expand All @@ -171,7 +173,7 @@ def write_snapshot_batch(
self.__ensure_snapshot_dir(index=index)

for location, location_snapshots in locations.items():
snapshot_fossil = SnapshotFossil(location=location)
snapshot_collection = SnapshotCollection(location=location)

if not self.test_location.matches_snapshot_location(location):
warning_msg = gettext(
Expand All @@ -186,7 +188,7 @@ def write_snapshot_batch(
warnings.warn(warning_msg)

for snapshot in location_snapshots:
snapshot_fossil.add(snapshot)
snapshot_collection.add(snapshot)

if not self.test_location.matches_snapshot_name(snapshot.name):
warning_msg = gettext(
Expand All @@ -200,7 +202,7 @@ def write_snapshot_batch(
)
warnings.warn(warning_msg)

self._write_snapshot_fossil(snapshot_fossil=snapshot_fossil)
self._write_snapshot_collection(snapshot_collection=snapshot_collection)

@abstractmethod
def delete_snapshots(
Expand All @@ -213,9 +215,11 @@ def delete_snapshots(
raise NotImplementedError

@abstractmethod
def _read_snapshot_fossil(self, *, snapshot_location: str) -> "SnapshotFossil":
def _read_snapshot_collection(
self, *, snapshot_location: str
) -> "SnapshotCollection":
"""
Read the snapshot location and construct a snapshot fossil object
Read the snapshot location and construct a snapshot collection object
"""
raise NotImplementedError

Expand All @@ -229,9 +233,11 @@ def _read_snapshot_data_from_location(
raise NotImplementedError

@abstractmethod
def _write_snapshot_fossil(self, *, snapshot_fossil: "SnapshotFossil") -> None:
def _write_snapshot_collection(
self, *, snapshot_collection: "SnapshotCollection"
) -> None:
"""
Adds the snapshot data to the snapshots in fossil location
Adds the snapshot data to the snapshots in collection location
"""
raise NotImplementedError

Expand Down Expand Up @@ -415,7 +421,7 @@ def matches(


class AbstractSyrupyExtension(
SnapshotSerializer, SnapshotFossilizer, SnapshotReporter, SnapshotComparator
SnapshotSerializer, SnapshotCollectionizer, SnapshotReporter, SnapshotComparator
):
def __init__(self, test_location: "PyTestLocation"):
self._test_location = test_location
Expand Down
Loading

2 comments on commit 7413123

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 7413123 Previous: 23cca84 Ratio
benchmarks/test_1000x.py::test_1000x_reads 0.5723090270718078 iter/sec (stddev: 0.07204337789956015) 0.6754078596653935 iter/sec (stddev: 0.06391877117699159) 1.18
benchmarks/test_1000x.py::test_1000x_writes 0.21993982515181057 iter/sec (stddev: 0.08281832140164301) 0.6345993135561808 iter/sec (stddev: 0.23174880874067105) 2.89
benchmarks/test_standard.py::test_standard 0.5410139429475035 iter/sec (stddev: 0.0840057617906888) 0.6315599143065584 iter/sec (stddev: 0.0923523543680502) 1.17

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 7413123 Previous: 23cca84 Ratio
benchmarks/test_1000x.py::test_1000x_writes 0.21993982515181057 iter/sec (stddev: 0.08281832140164301) 0.6345993135561808 iter/sec (stddev: 0.23174880874067105) 2.89

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.