From f416e9a408bc252fdeab1888cb8b7f0cd8fdef92 Mon Sep 17 00:00:00 2001 From: Noah Negin-Ulster Date: Thu, 2 Feb 2023 12:43:38 -0500 Subject: [PATCH] feat: support overriding the amber serializer class --- src/syrupy/extensions/amber/__init__.py | 18 +++-- src/syrupy/extensions/amber/serializer.py | 59 ++++++++------ tests/syrupy/__snapshots__/test_doctest.ambr | 10 +-- .../__snapshots__/test_amber_serializer.ambr | 48 ++++++------ .../__snapshots__/test_amber_sort.ambr | 76 +++++++++++++++++++ .../amber_sorted/test_amber_sort.py | 20 +++++ 6 files changed, 173 insertions(+), 58 deletions(-) create mode 100644 tests/syrupy/extensions/amber_sorted/__snapshots__/test_amber_sort.ambr create mode 100644 tests/syrupy/extensions/amber_sorted/test_amber_sort.py diff --git a/src/syrupy/extensions/amber/__init__.py b/src/syrupy/extensions/amber/__init__.py index 4b54ee2d..74dbc33b 100644 --- a/src/syrupy/extensions/amber/__init__.py +++ b/src/syrupy/extensions/amber/__init__.py @@ -5,12 +5,14 @@ Any, Optional, Set, + Type, ) from syrupy.data import SnapshotCollection from syrupy.exceptions import TaintedSnapshotError from syrupy.extensions.base import AbstractSyrupyExtension +from .serializer import AmberDataSerializerSorted # noqa: F401 # re-exported from .serializer import AmberDataSerializer if TYPE_CHECKING: @@ -24,12 +26,14 @@ class AmberSnapshotExtension(AbstractSyrupyExtension): _file_extension = "ambr" + serializer_class: Type["AmberDataSerializer"] = AmberDataSerializer + def serialize(self, data: "SerializableData", **kwargs: Any) -> str: """ Returns the serialized form of 'data' to be compared with the snapshot data written to disk. """ - return AmberDataSerializer.serialize(data, **kwargs) + return self.serializer_class.serialize(data, **kwargs) def delete_snapshots( self, snapshot_location: str, snapshot_names: Set[str] @@ -39,19 +43,19 @@ def delete_snapshots( snapshot_collection_to_update.remove(snapshot_name) if snapshot_collection_to_update.has_snapshots: - AmberDataSerializer.write_file(snapshot_collection_to_update) + self.serializer_class.write_file(snapshot_collection_to_update) else: Path(snapshot_location).unlink() def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection": - return AmberDataSerializer.read_file(snapshot_location) + return self.serializer_class.read_file(snapshot_location) - @staticmethod + @classmethod @lru_cache() def __cacheable_read_snapshot( - snapshot_location: str, cache_key: str + cls, snapshot_location: str, cache_key: str ) -> "SnapshotCollection": - return AmberDataSerializer.read_file(snapshot_location) + return cls.serializer_class.read_file(snapshot_location) def _read_snapshot_data_from_location( self, snapshot_location: str, snapshot_name: str, session_id: str @@ -70,7 +74,7 @@ def _read_snapshot_data_from_location( def _write_snapshot_collection( cls, *, snapshot_collection: "SnapshotCollection" ) -> None: - AmberDataSerializer.write_file(snapshot_collection, merge=True) + cls.serializer_class.write_file(snapshot_collection, merge=True) __all__ = ["AmberSnapshotExtension", "AmberDataSerializer"] diff --git a/src/syrupy/extensions/amber/serializer.py b/src/syrupy/extensions/amber/serializer.py index 10a82648..4a0c7035 100644 --- a/src/syrupy/extensions/amber/serializer.py +++ b/src/syrupy/extensions/amber/serializer.py @@ -11,7 +11,6 @@ Dict, Generator, Iterable, - List, NamedTuple, Optional, Set, @@ -77,7 +76,13 @@ class MissingVersionError(Exception): class AmberDataSerializer: - VERSION = 1 + """ + If extending the serializer, change the VERSION property to some unique value + for your iteration of the serializer so as to force invalidation of existing + snapshots. + """ + + VERSION = "1" _indent: str = " " _max_depth: int = 99 @@ -89,23 +94,8 @@ class Marker: Divider = "---" @classmethod - def __maybe_int(cls, part: str) -> Tuple[int, Union[str, int]]: - try: - # cast to int only if the string is the exact representation of the int - # for example, '012' != str(int('012')) - i = int(part) - if str(i) == part: - return (1, i) - return (0, part) - except ValueError: - # the nested tuple is to prevent comparing a str to an int - return (0, part) - - @classmethod - def __snapshot_sort_key( - cls, snapshot: "Snapshot" - ) -> List[Tuple[int, Union[str, int]]]: - return [cls.__maybe_int(part) for part in snapshot.name.split(".")] + def _snapshot_sort_key(cls, snapshot: "Snapshot") -> Any: + return snapshot.name @classmethod def write_file( @@ -123,7 +113,7 @@ def write_file( with open(filepath, "w", encoding=TEXT_ENCODING, newline=None) as f: f.write(f"{cls._marker_prefix}{cls.Marker.Version}: {cls.VERSION}\n") for snapshot in sorted( - snapshot_collection, key=lambda s: cls.__snapshot_sort_key(s) + snapshot_collection, key=lambda s: cls._snapshot_sort_key(s) # type: ignore # noqa: E501 ): snapshot_data = str(snapshot.data) if snapshot_data is not None: @@ -152,14 +142,14 @@ def __read_file_with_markers( ":", maxsplit=1 ) marker_key = marker_key.rstrip(" \r\n") - marker_value = marker_rest[0] if marker_rest else None + marker_value = marker_rest[0].strip() if marker_rest else None if marker_key == cls.Marker.Version: if line_no: raise MalformedAmberFile( "Version must be specified at the top of the file." ) - if not marker_value or int(marker_value) != cls.VERSION: + if not marker_value or marker_value != cls.VERSION: tainted = True continue missing_version = False @@ -457,3 +447,28 @@ def __serialize_lines( formatted_open_tag = cls.with_indent(f"{maybe_obj_type}{open_tag}", depth) formatted_close_tag = cls.with_indent(close_tag, depth) return f"{formatted_open_tag}\n{lines}{lines_end}{formatted_close_tag}" + + +class AmberDataSerializerSorted(AmberDataSerializer): + """ + This is an experimental serializer with known performance issues. + """ + + VERSION = f"{AmberDataSerializer.VERSION}-sorted" + + @classmethod + def __maybe_int(cls, part: str) -> Tuple[int, Union[str, int]]: + try: + # cast to int only if the string is the exact representation of the int + # for example, '012' != str(int('012')) + i = int(part) + if str(i) == part: + return (1, i) + return (0, part) + except ValueError: + # the nested tuple is to prevent comparing a str to an int + return (0, part) + + @classmethod + def _snapshot_sort_key(cls, snapshot: "Snapshot") -> Any: + return [cls.__maybe_int(part) for part in snapshot.name.split(".")] diff --git a/tests/syrupy/__snapshots__/test_doctest.ambr b/tests/syrupy/__snapshots__/test_doctest.ambr index c6229fd3..0b3e9696 100644 --- a/tests/syrupy/__snapshots__/test_doctest.ambr +++ b/tests/syrupy/__snapshots__/test_doctest.ambr @@ -4,6 +4,11 @@ obj_attr='test class attr', ) # --- +# name: DocTestClass.1 + DocTestClass( + obj_attr='test class attr', + ) +# --- # name: DocTestClass.NestedDocTestClass NestedDocTestClass( nested_obj_attr='nested doc test class attr', @@ -15,11 +20,6 @@ # name: DocTestClass.doctest_method 'doc test method return value' # --- -# name: DocTestClass.1 - DocTestClass( - obj_attr='test class attr', - ) -# --- # name: doctest_fn 'doc test fn return value' # --- diff --git a/tests/syrupy/extensions/amber/__snapshots__/test_amber_serializer.ambr b/tests/syrupy/extensions/amber/__snapshots__/test_amber_serializer.ambr index 8956e365..72f06006 100644 --- a/tests/syrupy/extensions/amber/__snapshots__/test_amber_serializer.ambr +++ b/tests/syrupy/extensions/amber/__snapshots__/test_amber_serializer.ambr @@ -253,30 +253,6 @@ # name: test_many_sorted.1 1 # --- -# name: test_many_sorted.2 - 2 -# --- -# name: test_many_sorted.3 - 3 -# --- -# name: test_many_sorted.4 - 4 -# --- -# name: test_many_sorted.5 - 5 -# --- -# name: test_many_sorted.6 - 6 -# --- -# name: test_many_sorted.7 - 7 -# --- -# name: test_many_sorted.8 - 8 -# --- -# name: test_many_sorted.9 - 9 -# --- # name: test_many_sorted.10 10 # --- @@ -307,6 +283,9 @@ # name: test_many_sorted.19 19 # --- +# name: test_many_sorted.2 + 2 +# --- # name: test_many_sorted.20 20 # --- @@ -322,6 +301,27 @@ # name: test_many_sorted.24 24 # --- +# name: test_many_sorted.3 + 3 +# --- +# name: test_many_sorted.4 + 4 +# --- +# name: test_many_sorted.5 + 5 +# --- +# name: test_many_sorted.6 + 6 +# --- +# name: test_many_sorted.7 + 7 +# --- +# name: test_many_sorted.8 + 8 +# --- +# name: test_many_sorted.9 + 9 +# --- # name: test_multiline_string_in_dict dict({ 'value': ''' diff --git a/tests/syrupy/extensions/amber_sorted/__snapshots__/test_amber_sort.ambr b/tests/syrupy/extensions/amber_sorted/__snapshots__/test_amber_sort.ambr new file mode 100644 index 00000000..29c217df --- /dev/null +++ b/tests/syrupy/extensions/amber_sorted/__snapshots__/test_amber_sort.ambr @@ -0,0 +1,76 @@ +# serializer version: 1-sorted +# name: test_many_sorted + 0 +# --- +# name: test_many_sorted.1 + 1 +# --- +# name: test_many_sorted.2 + 2 +# --- +# name: test_many_sorted.3 + 3 +# --- +# name: test_many_sorted.4 + 4 +# --- +# name: test_many_sorted.5 + 5 +# --- +# name: test_many_sorted.6 + 6 +# --- +# name: test_many_sorted.7 + 7 +# --- +# name: test_many_sorted.8 + 8 +# --- +# name: test_many_sorted.9 + 9 +# --- +# name: test_many_sorted.10 + 10 +# --- +# name: test_many_sorted.11 + 11 +# --- +# name: test_many_sorted.12 + 12 +# --- +# name: test_many_sorted.13 + 13 +# --- +# name: test_many_sorted.14 + 14 +# --- +# name: test_many_sorted.15 + 15 +# --- +# name: test_many_sorted.16 + 16 +# --- +# name: test_many_sorted.17 + 17 +# --- +# name: test_many_sorted.18 + 18 +# --- +# name: test_many_sorted.19 + 19 +# --- +# name: test_many_sorted.20 + 20 +# --- +# name: test_many_sorted.21 + 21 +# --- +# name: test_many_sorted.22 + 22 +# --- +# name: test_many_sorted.23 + 23 +# --- +# name: test_many_sorted.24 + 24 +# --- diff --git a/tests/syrupy/extensions/amber_sorted/test_amber_sort.py b/tests/syrupy/extensions/amber_sorted/test_amber_sort.py new file mode 100644 index 00000000..3a2007a9 --- /dev/null +++ b/tests/syrupy/extensions/amber_sorted/test_amber_sort.py @@ -0,0 +1,20 @@ +import pytest + +from syrupy.extensions.amber import ( + AmberDataSerializerSorted, + AmberSnapshotExtension, +) + + +class AmberSortedSnapshotExtension(AmberSnapshotExtension): + serializer_class = AmberDataSerializerSorted + + +@pytest.fixture +def snapshot(snapshot): + return snapshot.use_extension(AmberSortedSnapshotExtension) + + +def test_many_sorted(snapshot): + for i in range(25): + assert i == snapshot