Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support overriding the amber serializer class #683

Merged
merged 1 commit into from
Feb 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/syrupy/extensions/amber/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
Any,
Optional,
Set,
Type,
)

from syrupy.data import SnapshotCollection
from syrupy.exceptions import TaintedSnapshotError
from syrupy.extensions.base import AbstractSyrupyExtension

from .serializer import AmberDataSerializerSorted # noqa: F401 # re-exported
from .serializer import AmberDataSerializer

if TYPE_CHECKING:
Expand All @@ -24,12 +26,14 @@ class AmberSnapshotExtension(AbstractSyrupyExtension):

_file_extension = "ambr"

serializer_class: Type["AmberDataSerializer"] = AmberDataSerializer

def serialize(self, data: "SerializableData", **kwargs: Any) -> str:
"""
Returns the serialized form of 'data' to be compared
with the snapshot data written to disk.
"""
return AmberDataSerializer.serialize(data, **kwargs)
return self.serializer_class.serialize(data, **kwargs)

def delete_snapshots(
self, snapshot_location: str, snapshot_names: Set[str]
Expand All @@ -39,19 +43,19 @@ def delete_snapshots(
snapshot_collection_to_update.remove(snapshot_name)

if snapshot_collection_to_update.has_snapshots:
AmberDataSerializer.write_file(snapshot_collection_to_update)
self.serializer_class.write_file(snapshot_collection_to_update)
else:
Path(snapshot_location).unlink()

def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection":
return AmberDataSerializer.read_file(snapshot_location)
return self.serializer_class.read_file(snapshot_location)

@staticmethod
@classmethod
@lru_cache()
def __cacheable_read_snapshot(
snapshot_location: str, cache_key: str
cls, snapshot_location: str, cache_key: str
) -> "SnapshotCollection":
return AmberDataSerializer.read_file(snapshot_location)
return cls.serializer_class.read_file(snapshot_location)

def _read_snapshot_data_from_location(
self, snapshot_location: str, snapshot_name: str, session_id: str
Expand All @@ -70,7 +74,7 @@ def _read_snapshot_data_from_location(
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
) -> None:
AmberDataSerializer.write_file(snapshot_collection, merge=True)
cls.serializer_class.write_file(snapshot_collection, merge=True)


__all__ = ["AmberSnapshotExtension", "AmberDataSerializer"]
59 changes: 37 additions & 22 deletions src/syrupy/extensions/amber/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Dict,
Generator,
Iterable,
List,
NamedTuple,
Optional,
Set,
Expand Down Expand Up @@ -77,7 +76,13 @@ class MissingVersionError(Exception):


class AmberDataSerializer:
VERSION = 1
"""
If extending the serializer, change the VERSION property to some unique value
for your iteration of the serializer so as to force invalidation of existing
snapshots.
"""

VERSION = "1"

_indent: str = " "
_max_depth: int = 99
Expand All @@ -89,23 +94,8 @@ class Marker:
Divider = "---"

@classmethod
def __maybe_int(cls, part: str) -> Tuple[int, Union[str, int]]:
try:
# cast to int only if the string is the exact representation of the int
# for example, '012' != str(int('012'))
i = int(part)
if str(i) == part:
return (1, i)
return (0, part)
except ValueError:
# the nested tuple is to prevent comparing a str to an int
return (0, part)

@classmethod
def __snapshot_sort_key(
cls, snapshot: "Snapshot"
) -> List[Tuple[int, Union[str, int]]]:
return [cls.__maybe_int(part) for part in snapshot.name.split(".")]
def _snapshot_sort_key(cls, snapshot: "Snapshot") -> Any:
return snapshot.name

@classmethod
def write_file(
Expand All @@ -123,7 +113,7 @@ def write_file(
with open(filepath, "w", encoding=TEXT_ENCODING, newline=None) as f:
f.write(f"{cls._marker_prefix}{cls.Marker.Version}: {cls.VERSION}\n")
for snapshot in sorted(
snapshot_collection, key=lambda s: cls.__snapshot_sort_key(s)
snapshot_collection, key=lambda s: cls._snapshot_sort_key(s) # type: ignore # noqa: E501
):
snapshot_data = str(snapshot.data)
if snapshot_data is not None:
Expand Down Expand Up @@ -152,14 +142,14 @@ def __read_file_with_markers(
":", maxsplit=1
)
marker_key = marker_key.rstrip(" \r\n")
marker_value = marker_rest[0] if marker_rest else None
marker_value = marker_rest[0].strip() if marker_rest else None

if marker_key == cls.Marker.Version:
if line_no:
raise MalformedAmberFile(
"Version must be specified at the top of the file."
)
if not marker_value or int(marker_value) != cls.VERSION:
if not marker_value or marker_value != cls.VERSION:
tainted = True
continue
missing_version = False
Expand Down Expand Up @@ -457,3 +447,28 @@ def __serialize_lines(
formatted_open_tag = cls.with_indent(f"{maybe_obj_type}{open_tag}", depth)
formatted_close_tag = cls.with_indent(close_tag, depth)
return f"{formatted_open_tag}\n{lines}{lines_end}{formatted_close_tag}"


class AmberDataSerializerSorted(AmberDataSerializer):
"""
This is an experimental serializer with known performance issues.
"""

VERSION = f"{AmberDataSerializer.VERSION}-sorted"

@classmethod
def __maybe_int(cls, part: str) -> Tuple[int, Union[str, int]]:
try:
# cast to int only if the string is the exact representation of the int
# for example, '012' != str(int('012'))
i = int(part)
if str(i) == part:
return (1, i)
return (0, part)
except ValueError:
# the nested tuple is to prevent comparing a str to an int
return (0, part)

@classmethod
def _snapshot_sort_key(cls, snapshot: "Snapshot") -> Any:
return [cls.__maybe_int(part) for part in snapshot.name.split(".")]
10 changes: 5 additions & 5 deletions tests/syrupy/__snapshots__/test_doctest.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
obj_attr='test class attr',
)
# ---
# name: DocTestClass.1
DocTestClass(
obj_attr='test class attr',
)
# ---
# name: DocTestClass.NestedDocTestClass
NestedDocTestClass(
nested_obj_attr='nested doc test class attr',
Expand All @@ -15,11 +20,6 @@
# name: DocTestClass.doctest_method
'doc test method return value'
# ---
# name: DocTestClass.1
DocTestClass(
obj_attr='test class attr',
)
# ---
# name: doctest_fn
'doc test fn return value'
# ---
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,30 +253,6 @@
# name: test_many_sorted.1
1
# ---
# name: test_many_sorted.2
2
# ---
# name: test_many_sorted.3
3
# ---
# name: test_many_sorted.4
4
# ---
# name: test_many_sorted.5
5
# ---
# name: test_many_sorted.6
6
# ---
# name: test_many_sorted.7
7
# ---
# name: test_many_sorted.8
8
# ---
# name: test_many_sorted.9
9
# ---
# name: test_many_sorted.10
10
# ---
Expand Down Expand Up @@ -307,6 +283,9 @@
# name: test_many_sorted.19
19
# ---
# name: test_many_sorted.2
2
# ---
# name: test_many_sorted.20
20
# ---
Expand All @@ -322,6 +301,27 @@
# name: test_many_sorted.24
24
# ---
# name: test_many_sorted.3
3
# ---
# name: test_many_sorted.4
4
# ---
# name: test_many_sorted.5
5
# ---
# name: test_many_sorted.6
6
# ---
# name: test_many_sorted.7
7
# ---
# name: test_many_sorted.8
8
# ---
# name: test_many_sorted.9
9
# ---
# name: test_multiline_string_in_dict
dict({
'value': '''
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# serializer version: 1-sorted
# name: test_many_sorted
0
# ---
# name: test_many_sorted.1
1
# ---
# name: test_many_sorted.2
2
# ---
# name: test_many_sorted.3
3
# ---
# name: test_many_sorted.4
4
# ---
# name: test_many_sorted.5
5
# ---
# name: test_many_sorted.6
6
# ---
# name: test_many_sorted.7
7
# ---
# name: test_many_sorted.8
8
# ---
# name: test_many_sorted.9
9
# ---
# name: test_many_sorted.10
10
# ---
# name: test_many_sorted.11
11
# ---
# name: test_many_sorted.12
12
# ---
# name: test_many_sorted.13
13
# ---
# name: test_many_sorted.14
14
# ---
# name: test_many_sorted.15
15
# ---
# name: test_many_sorted.16
16
# ---
# name: test_many_sorted.17
17
# ---
# name: test_many_sorted.18
18
# ---
# name: test_many_sorted.19
19
# ---
# name: test_many_sorted.20
20
# ---
# name: test_many_sorted.21
21
# ---
# name: test_many_sorted.22
22
# ---
# name: test_many_sorted.23
23
# ---
# name: test_many_sorted.24
24
# ---
20 changes: 20 additions & 0 deletions tests/syrupy/extensions/amber_sorted/test_amber_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest

from syrupy.extensions.amber import (
AmberDataSerializerSorted,
AmberSnapshotExtension,
)


class AmberSortedSnapshotExtension(AmberSnapshotExtension):
serializer_class = AmberDataSerializerSorted


@pytest.fixture
def snapshot(snapshot):
return snapshot.use_extension(AmberSortedSnapshotExtension)


def test_many_sorted(snapshot):
for i in range(25):
assert i == snapshot