diff --git a/.prettierrc.toml b/.prettierrc.toml new file mode 100644 index 00000000..3d99dbf3 --- /dev/null +++ b/.prettierrc.toml @@ -0,0 +1 @@ +tabWidth = 2 diff --git a/README.md b/README.md index d09f9249..a399787c 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ def __repr__(self) -> str: return "MyCustomClass(...)" ``` -### Options +### CLI Options These are the cli options exposed to `pytest` by the plugin. @@ -80,7 +80,75 @@ These are the cli options exposed to `pytest` by the plugin. | `--snapshot-warn-unused` | Prints a warning on unused snapshots rather than fail the test suite. | `False` | | `--snapshot-default-extension` | Use to change the default snapshot extension class. | `syrupy.extensions.amber.AmberSnapshotExtension` | -### Built-In Extensions +### Assertion Options + +These are the options available on the `snapshot` assertion fixture. +Use of these options are one shot and do not persist across assertions. +For more persistent options see [advanced usage](#advanced-usage). + +#### `matcher` + +This allows you to match on a property path and value to control how specific object shapes are serialized. + +The matcher is a function that takes two keyword arguments. +It should return the replacement value to be serialized or the original unmutated value. + +| Argument | Description | +| -------- | ------------------------------------------------------------------------------------------------------------------ | +| `data` | Current serializable value being matched on | +| `path` | Ordered path traversed to the current value e.g. `(("a", dict), ("b", dict))` from `{ "a": { "b": { "c": 1 } } }`} | + +**NOTE:** Do not mutate the value received as it could cause unintended side effects. + +##### Built-In Matchers + +Syrupy comes with built-in helpers that can be used to make easy work of using property matchers. + +###### `path_type(mapping=None, *, types=(), strict=True)` + +Easy way to build a matcher that uses the path and value type to replace serialized data. +When strict, this will raise a `ValueError` if the types specified are not matched. + +| Argument | Description | +| --------- | ---------------------------------------------------------------------------------------------------------------------------------- | +| `mapping` | Dict of path string to tuples of class types, including primitives e.g. (MyClass, UUID, datetime, int, str) | +| `types` | Tuple of class types used if none of the path strings from the mapping are matched | +| `strict` | If a path is matched but the value at the path does not match one of the class types in the tuple then a `PathTypeError` is raised | + +```py +from syrupy.matchers import path_type + +def test_bar(snapshot): + actual = { + "date_created": datetime.now(), + "value": "Some computed value!", + } + assert actual == snapshot(matcher=path_type({ + "date_created": (datetime,), + "nested.path.id": (int,), + })) +``` + +```ambr +# name: test_bar + { + 'date_created': , + 'value': 'Some computed value!', + } +--- +``` + +#### `extension_class` + +This is a way to modify how the snapshot matches and serializes your data in a single assertion. + +```py +def test_foo(snapshot): + actual_svg = "" + assert actual_svg = snapshot(extension_class=SVGImageSnapshotExtension) +``` + +##### Built-In Extensions Syrupy comes with a few built-in preset configurations for you to choose from. You should also feel free to extend the `AbstractSyrupyExtension` if your project has a need not captured by one our built-ins. @@ -141,7 +209,9 @@ To develop locally, clone this repository and run `. script/bootstrap` to instal + + This section is automatically generated via tagging the all-contributors bot in a PR: ```text diff --git a/src/syrupy/__init__.py b/src/syrupy/__init__.py index 2f85926a..106b70d0 100644 --- a/src/syrupy/__init__.py +++ b/src/syrupy/__init__.py @@ -67,12 +67,12 @@ def pytest_assertrepr_compare(op: str, left: Any, right: Any) -> Optional[List[s assert_msg = reset( f"{snapshot_style(left.name)} {op} {received_style('received')}" ) - return [assert_msg] + left.get_assert_diff(right) + return [assert_msg] + left.get_assert_diff() elif isinstance(right, SnapshotAssertion): assert_msg = reset( f"{received_style('received')} {op} {snapshot_style(right.name)}" ) - return [assert_msg] + right.get_assert_diff(left) + return [assert_msg] + right.get_assert_diff() return None diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index 49690cb0..de911e8b 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -17,7 +17,7 @@ from .location import TestLocation from .extensions.base import AbstractSyrupyExtension from .session import SnapshotSession - from .types import SerializableData, SerializedData # noqa: F401 + from .types import PropertyMatcher, SerializableData, SerializedData @attr.s @@ -45,6 +45,7 @@ class SnapshotAssertion: _test_location: "TestLocation" = attr.ib(kw_only=True) _update_snapshots: bool = attr.ib(kw_only=True) _extension: Optional["AbstractSyrupyExtension"] = attr.ib(init=False, default=None) + _matcher: Optional["PropertyMatcher"] = attr.ib(init=False, default=None) _executions: int = attr.ib(init=False, default=0) _execution_results: Dict[int, "AssertionResult"] = attr.ib(init=False, factory=dict) _post_assert_actions: List[Callable[..., None]] = attr.ib(init=False, factory=list) @@ -88,10 +89,13 @@ def use_extension( def assert_match(self, data: "SerializableData") -> None: assert self == data - def get_assert_diff(self, data: "SerializableData") -> List[str]: + def _serialize(self, data: "SerializableData") -> "SerializedData": + return self.extension.serialize(data, matcher=self._matcher) + + def get_assert_diff(self) -> List[str]: assertion_result = self._execution_results[self.num_executions - 1] snapshot_data = assertion_result.recalled_data - serialized_data = self.extension.serialize(data) + serialized_data = assertion_result.asserted_data or "" diff: List[str] = [] if snapshot_data is None: diff.append(gettext("Snapshot does not exist!")) @@ -100,7 +104,10 @@ def get_assert_diff(self, data: "SerializableData") -> List[str]: return diff def __call__( - self, *, extension_class: Optional[Type["AbstractSyrupyExtension"]] + self, + *, + extension_class: Optional[Type["AbstractSyrupyExtension"]] = None, + matcher: Optional["PropertyMatcher"] = None, ) -> "SnapshotAssertion": """ Modifies assertion instance options @@ -112,6 +119,13 @@ def clear_extension() -> None: self._extension = None self._post_assert_actions.append(clear_extension) + if matcher: + self._matcher = matcher + + def clear_matcher() -> None: + self._matcher = None + + self._post_assert_actions.append(clear_matcher) return self def __repr__(self) -> str: @@ -131,7 +145,7 @@ def _assert(self, data: "SerializableData") -> bool: assertion_success = False try: snapshot_data = self._recall_data(index=self.num_executions) - serialized_data = self.extension.serialize(data) + serialized_data = self._serialize(data) matches = snapshot_data is not None and serialized_data == snapshot_data assertion_success = matches if not matches and self._update_snapshots: diff --git a/src/syrupy/data.py b/src/syrupy/data.py index 1ddcd8b4..049208c8 100644 --- a/src/syrupy/data.py +++ b/src/syrupy/data.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: - from .types import SerializedData # noqa: F401 + from .types import SerializedData @attr.s(frozen=True) diff --git a/src/syrupy/extensions/amber.py b/src/syrupy/extensions/amber.py deleted file mode 100644 index 1c25de7e..00000000 --- a/src/syrupy/extensions/amber.py +++ /dev/null @@ -1,325 +0,0 @@ -from pathlib import Path -from types import GeneratorType -from typing import ( - TYPE_CHECKING, - Any, - Iterable, - Optional, - Set, -) - -from syrupy.constants import SYMBOL_ELLIPSIS -from syrupy.data import ( - Snapshot, - SnapshotFossil, -) - -from .base import AbstractSyrupyExtension - - -if TYPE_CHECKING: - from syrupy.types import SerializableData - - -class DataSerializer: - _indent: str = " " - _max_depth: int = 99 - _marker_divider: str = "---" - _marker_name: str = "# name:" - - class MarkerDepthMax: - def __repr__(self) -> str: - return SYMBOL_ELLIPSIS - - @classmethod - def write_file(cls, snapshot_fossil: "SnapshotFossil") -> None: - """ - Writes the snapshot data into the snapshot file that be read later. - """ - filepath = snapshot_fossil.location - with open(filepath, "w", encoding="utf-8", newline="") as f: - for snapshot in sorted(snapshot_fossil, key=lambda s: s.name): - snapshot_data = str(snapshot.data) - if snapshot_data is not None: - f.write(f"{cls._marker_name} {snapshot.name}\n") - for data_line in snapshot_data.splitlines(keepends=True): - f.write(f"{cls._indent}{data_line}") - f.write(f"\n{cls._marker_divider}\n") - - @classmethod - def read_file(cls, filepath: str) -> "SnapshotFossil": - """ - Read the raw snapshot data (str) from the snapshot file into a dict - of snapshot name to raw data. This does not attempt any deserialization - of the snapshot data. - """ - name_marker_len = len(cls._marker_name) - indent_len = len(cls._indent) - snapshot_fossil = SnapshotFossil(location=filepath) - try: - with open(filepath, "r", encoding="utf-8", newline="") as f: - test_name = None - snapshot_data = "" - for line in f: - if line.startswith(cls._marker_name): - test_name = line[name_marker_len:-1].strip(" \r\n") - snapshot_data = "" - continue - elif test_name is not None: - if line.startswith(cls._indent): - snapshot_data += line[indent_len:] - elif line.startswith(cls._marker_divider) and snapshot_data: - snapshot_fossil.add( - Snapshot(name=test_name, data=snapshot_data[:-1]) - ) - except FileNotFoundError: - pass - - return snapshot_fossil - - @classmethod - def sort(cls, iterable: Iterable[Any]) -> Iterable[Any]: - try: - return sorted(iterable) - except TypeError: - return sorted(iterable, key=cls.serialize) - - @classmethod - def with_indent(cls, string: str, depth: int) -> str: - return f"{cls._indent * depth}{string}" - - @classmethod - def object_type(cls, data: "SerializableData") -> str: - return f"" - - @classmethod - def serialize_string( - cls, - data: "SerializableData", - *, - depth: int = 0, - visited: Optional[Set[Any]] = None, - ) -> str: - if "\n" in data: - return ( - cls.with_indent("'\n", depth) - + "".join( - cls.with_indent(line, depth + 1 if depth else depth) - for line in str(data).splitlines(keepends=True) - ) - + "\n" - + cls.with_indent("'", depth) - ) - return cls.with_indent(repr(data), depth) - - @classmethod - def serialize_number( - cls, - data: "SerializableData", - *, - depth: int = 0, - visited: Optional[Set[Any]] = None, - ) -> str: - return cls.with_indent(repr(data), depth) - - @classmethod - def serialize_set( - cls, - data: "SerializableData", - *, - depth: int = 0, - visited: Optional[Set[Any]] = None, - ) -> str: - return ( - cls.with_indent(f"{cls.object_type(data)} {{\n", depth) - + "".join( - f"{cls.serialize(d, depth=depth + 1, visited=visited)},\n" - for d in cls.sort(data) - ) - + cls.with_indent("}", depth) - ) - - @classmethod - def serialize_dict( - cls, - data: "SerializableData", - *, - depth: int = 0, - visited: Optional[Set[Any]] = None, - ) -> str: - kwargs = {"depth": depth + 1, "visited": visited} - return ( - cls.with_indent(f"{cls.object_type(data)} {{\n", depth) - + "".join( - f"{serialized_key}: {serialized_value.lstrip(cls._indent)},\n" - for serialized_key, serialized_value in ( - ( - cls.serialize(**{"data": key, **kwargs}), - cls.serialize(**{"data": data[key], **kwargs}), - ) - for key in cls.sort(data.keys()) - ) - ) - + cls.with_indent("}", depth) - ) - - @classmethod - def __is_namedtuple(cls, obj: Any) -> bool: - return isinstance(obj, tuple) and all( - type(n) == str for n in getattr(obj, "_fields", [None]) - ) - - @classmethod - def serialize_namedtuple( - cls, data: Any, *, depth: int = 0, visited: Optional[Set[Any]] = None - ) -> str: - return ( - cls.with_indent(f"{cls.object_type(data)} (\n", depth) - + "".join( - f"{serialized_key}={serialized_value.lstrip(cls._indent)},\n" - for serialized_key, serialized_value in ( - ( - cls.with_indent(name, depth=depth + 1), - cls.serialize( - data=getattr(data, name), depth=depth + 1, visited=visited - ), - ) - for name in cls.sort(data._fields) - ) - ) - + cls.with_indent(")", depth) - ) - - @classmethod - def serialize_iterable( - cls, - data: "SerializableData", - *, - depth: int = 0, - visited: Optional[Set[Any]] = None, - ) -> str: - open_paren, close_paren = next( - parens - for iter_type, parens in { - GeneratorType: ("(", ")"), - list: ("[", "]"), - tuple: ("(", ")"), - }.items() - if isinstance(data, iter_type) - ) - return ( - cls.with_indent(f"{cls.object_type(data)} {open_paren}\n", depth) - + "".join( - f"{cls.serialize(d, depth=depth + 1, visited=visited)},\n" for d in data - ) - + cls.with_indent(close_paren, depth) - ) - - @classmethod - def serialize_unknown( - cls, data: Any, *, depth: int = 0, visited: Optional[Set[Any]] = None - ) -> str: - if data.__class__.__repr__ != object.__repr__: - return cls.with_indent(repr(data), depth) - - return ( - cls.with_indent(f"{cls.object_type(data)} {{\n", depth) - + "".join( - f"{serialized_key}={serialized_value.lstrip(cls._indent)},\n" - for serialized_key, serialized_value in ( - ( - cls.with_indent(name, depth=depth + 1), - cls.serialize( - data=getattr(data, name), depth=depth + 1, visited=visited - ), - ) - for name in cls.sort(dir(data)) - if not name.startswith("_") and not callable(getattr(data, name)) - ) - ) - + cls.with_indent("}", depth) - ) - - @classmethod - def serialize( - cls, - data: "SerializableData", - *, - depth: int = 0, - visited: Optional[Set[Any]] = None, - ) -> str: - visited = visited if visited is not None else set() - data_id = id(data) - if depth > cls._max_depth or data_id in visited: - data = cls.MarkerDepthMax() - - serialize_kwargs = { - "data": data, - "depth": depth, - "visited": {*visited, data_id}, - } - serialize_method = cls.serialize_unknown - if isinstance(data, str): - serialize_method = cls.serialize_string - elif isinstance(data, (int, float)): - serialize_method = cls.serialize_number - elif isinstance(data, (set, frozenset)): - serialize_method = cls.serialize_set - elif isinstance(data, dict): - serialize_method = cls.serialize_dict - elif cls.__is_namedtuple(data): - serialize_method = cls.serialize_namedtuple - elif isinstance(data, (list, tuple, GeneratorType)): - serialize_method = cls.serialize_iterable - return serialize_method(**serialize_kwargs) - - -class AmberSnapshotExtension(AbstractSyrupyExtension): - """ - An amber snapshot file stores data in the following format: - - ``` - # name: test_name_1 - data - --- - # name: test_name_2 - data - ``` - """ - - def serialize(self, data: "SerializableData") -> str: - """ - Returns the serialized form of 'data' to be compared - with the snapshot data written to disk. - """ - return DataSerializer.serialize(data) - - def delete_snapshots( - self, snapshot_location: str, snapshot_names: Set[str] - ) -> None: - snapshot_fossil_to_update = DataSerializer.read_file(snapshot_location) - for snapshot_name in snapshot_names: - snapshot_fossil_to_update.remove(snapshot_name) - - if snapshot_fossil_to_update.has_snapshots: - DataSerializer.write_file(snapshot_fossil_to_update) - else: - Path(snapshot_location).unlink() - - @property - def _file_extension(self) -> str: - return "ambr" - - def _read_snapshot_fossil(self, snapshot_location: str) -> "SnapshotFossil": - return DataSerializer.read_file(snapshot_location) - - def _read_snapshot_data_from_location( - self, snapshot_location: str, snapshot_name: str - ) -> Optional["SerializableData"]: - snapshot = self._read_snapshot_fossil(snapshot_location).get(snapshot_name) - return snapshot.data if snapshot else None - - def _write_snapshot_fossil(self, *, snapshot_fossil: "SnapshotFossil") -> None: - snapshot_fossil_to_update = DataSerializer.read_file(snapshot_fossil.location) - snapshot_fossil_to_update.merge(snapshot_fossil) - DataSerializer.write_file(snapshot_fossil_to_update) diff --git a/src/syrupy/extensions/amber/__init__.py b/src/syrupy/extensions/amber/__init__.py new file mode 100644 index 00000000..a5646252 --- /dev/null +++ b/src/syrupy/extensions/amber/__init__.py @@ -0,0 +1,71 @@ +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Optional, + Set, +) + +from syrupy.data import SnapshotFossil +from syrupy.extensions.base import AbstractSyrupyExtension + +from .serializer import DataSerializer + + +if TYPE_CHECKING: + from syrupy.types import PropertyMatcher, SerializableData + + +class AmberSnapshotExtension(AbstractSyrupyExtension): + """ + An amber snapshot file stores data in the following format: + + ``` + # name: test_name_1 + data + --- + # name: test_name_2 + data + ``` + """ + + def serialize( + self, data: "SerializableData", *, matcher: Optional["PropertyMatcher"] = None + ) -> str: + """ + Returns the serialized form of 'data' to be compared + with the snapshot data written to disk. + """ + return DataSerializer.serialize(data, matcher=matcher) + + def delete_snapshots( + self, snapshot_location: str, snapshot_names: Set[str] + ) -> None: + snapshot_fossil_to_update = DataSerializer.read_file(snapshot_location) + for snapshot_name in snapshot_names: + snapshot_fossil_to_update.remove(snapshot_name) + + if snapshot_fossil_to_update.has_snapshots: + DataSerializer.write_file(snapshot_fossil_to_update) + else: + Path(snapshot_location).unlink() + + @property + def _file_extension(self) -> str: + return "ambr" + + def _read_snapshot_fossil(self, snapshot_location: str) -> "SnapshotFossil": + return DataSerializer.read_file(snapshot_location) + + def _read_snapshot_data_from_location( + self, snapshot_location: str, snapshot_name: str + ) -> Optional["SerializableData"]: + snapshot = self._read_snapshot_fossil(snapshot_location).get(snapshot_name) + return snapshot.data if snapshot else None + + def _write_snapshot_fossil(self, *, snapshot_fossil: "SnapshotFossil") -> None: + snapshot_fossil_to_update = DataSerializer.read_file(snapshot_fossil.location) + snapshot_fossil_to_update.merge(snapshot_fossil) + DataSerializer.write_file(snapshot_fossil_to_update) + + +__all__ = ["AmberSnapshotExtension", "DataSerializer"] diff --git a/src/syrupy/extensions/amber/serializer.py b/src/syrupy/extensions/amber/serializer.py new file mode 100644 index 00000000..f18efab2 --- /dev/null +++ b/src/syrupy/extensions/amber/serializer.py @@ -0,0 +1,294 @@ +from types import GeneratorType +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + Optional, + Set, + Tuple, +) + +from syrupy.constants import SYMBOL_ELLIPSIS +from syrupy.data import ( + Snapshot, + SnapshotFossil, +) + + +if TYPE_CHECKING: + from syrupy.types import ( + PropertyMatcher, + PropertyName, + PropertyPath, + SerializableData, + ) + + +class Repr: + def __init__(self, _repr: str): + self._repr = _repr + + def __repr__(self) -> str: + return self._repr + + +class DataSerializer: + _indent: str = " " + _max_depth: int = 99 + _marker_divider: str = "---" + _marker_name: str = "# name:" + + @classmethod + def write_file(cls, snapshot_fossil: "SnapshotFossil") -> None: + """ + Writes the snapshot data into the snapshot file that be read later. + """ + filepath = snapshot_fossil.location + with open(filepath, "w", encoding="utf-8", newline="") as f: + for snapshot in sorted(snapshot_fossil, key=lambda s: s.name): + snapshot_data = str(snapshot.data) + if snapshot_data is not None: + f.write(f"{cls._marker_name} {snapshot.name}\n") + for data_line in snapshot_data.splitlines(keepends=True): + f.write(f"{cls._indent}{data_line}") + f.write(f"\n{cls._marker_divider}\n") + + @classmethod + def read_file(cls, filepath: str) -> "SnapshotFossil": + """ + Read the raw snapshot data (str) from the snapshot file into a dict + of snapshot name to raw data. This does not attempt any deserialization + of the snapshot data. + """ + name_marker_len = len(cls._marker_name) + indent_len = len(cls._indent) + snapshot_fossil = SnapshotFossil(location=filepath) + try: + with open(filepath, "r", encoding="utf-8", newline="") as f: + test_name = None + snapshot_data = "" + for line in f: + if line.startswith(cls._marker_name): + test_name = line[name_marker_len:-1].strip(" \r\n") + snapshot_data = "" + continue + elif test_name is not None: + if line.startswith(cls._indent): + snapshot_data += line[indent_len:] + elif line.startswith(cls._marker_divider) and snapshot_data: + snapshot_fossil.add( + Snapshot(name=test_name, data=snapshot_data[:-1]) + ) + except FileNotFoundError: + pass + + return snapshot_fossil + + @classmethod + def serialize( + cls, + data: "SerializableData", + *, + depth: int = 0, + matcher: Optional["PropertyMatcher"] = None, + path: "PropertyPath" = (), + visited: Optional[Set[Any]] = None, + ) -> str: + visited = set() if visited is None else visited + data_id = id(data) + if depth > cls._max_depth or data_id in visited: + data = Repr(SYMBOL_ELLIPSIS) + elif matcher: + data = matcher(data=data, path=path) + serialize_kwargs = { + "data": data, + "depth": depth, + "matcher": matcher, + "path": path, + "visited": {*visited, data_id}, + } + serialize_method = cls.serialize_unknown + if isinstance(data, str): + serialize_method = cls.serialize_string + elif isinstance(data, (int, float)): + serialize_method = cls.serialize_number + elif isinstance(data, (set, frozenset)): + serialize_method = cls.serialize_set + elif isinstance(data, dict): + serialize_method = cls.serialize_dict + elif cls.__is_namedtuple(data): + serialize_method = cls.serialize_namedtuple + elif isinstance(data, (list, tuple, GeneratorType)): + serialize_method = cls.serialize_iterable + return serialize_method(**serialize_kwargs) + + @classmethod + def serialize_number( + cls, data: "SerializableData", *, depth: int = 0, **kwargs: Any + ) -> str: + return cls.with_indent(repr(data), depth) + + @classmethod + def serialize_string( + cls, data: "SerializableData", *, depth: int = 0, **kwargs: Any + ) -> str: + if "\n" in data: + return cls.__serialize_lines( + data=data, + lines=( + cls.with_indent(line, depth + 1 if depth else depth) + for line in str(data).splitlines(keepends=True) + ), + depth=depth, + open_tag="'", + close_tag="'", + include_type=False, + ends="", + ) + return cls.with_indent(repr(data), depth) + + @classmethod + def serialize_iterable(cls, data: "SerializableData", **kwargs: Any) -> str: + open_paren, close_paren = next( + parens + for iter_type, parens in { + GeneratorType: ("(", ")"), + list: ("[", "]"), + tuple: ("(", ")"), + }.items() + if isinstance(data, iter_type) + ) + return cls.__serialize_iterable( + data=data, + entries=enumerate(data), + open_tag=open_paren, + close_tag=close_paren, + **kwargs, + ) + + @classmethod + def serialize_set(cls, data: "SerializableData", **kwargs: Any) -> str: + return cls.__serialize_iterable( + data=data, + entries=((d, d) for d in cls.sort(data)), + open_tag="{", + close_tag="}", + **kwargs, + ) + + @classmethod + def serialize_namedtuple(cls, data: "SerializableData", **kwargs: Any) -> str: + return cls.__serialize_iterable( + data=data, + entries=((name, getattr(data, name)) for name in cls.sort(data._fields)), + open_tag="(", + close_tag=")", + separator="=", + **kwargs, + ) + + @classmethod + def serialize_dict(cls, data: "SerializableData", **kwargs: Any) -> str: + return cls.__serialize_iterable( + data=data, + entries=((key, data[key]) for key in cls.sort(data.keys())), + open_tag="{", + close_tag="}", + separator=": ", + serialize_key=True, + **kwargs, + ) + + @classmethod + def serialize_unknown(cls, data: Any, *, depth: int = 0, **kwargs: Any) -> str: + if data.__class__.__repr__ != object.__repr__: + return cls.with_indent(repr(data), depth) + + return cls.__serialize_iterable( + data=data, + entries=( + (name, getattr(data, name)) + for name in cls.sort(dir(data)) + if not name.startswith("_") and not callable(getattr(data, name)) + ), + depth=depth, + open_tag="{", + close_tag="}", + separator="=", + **kwargs, + ) + + @classmethod + def with_indent(cls, string: str, depth: int) -> str: + return f"{cls._indent * depth}{string}" + + @classmethod + def sort(cls, iterable: Iterable[Any]) -> Iterable[Any]: + try: + return sorted(iterable) + except TypeError: + return sorted(iterable, key=cls.serialize) + + @classmethod + def object_type(cls, data: "SerializableData") -> str: + return f"" + + @classmethod + def __is_namedtuple(cls, obj: Any) -> bool: + return isinstance(obj, tuple) and all( + type(n) == str for n in getattr(obj, "_fields", [None]) + ) + + @classmethod + def __serialize_lines( + cls, + *, + data: "SerializableData", + lines: Iterable[str], + open_tag: str, + close_tag: str, + depth: int = 0, + include_type: bool = True, + ends: str = "\n", + ) -> str: + return ( + f"{cls.with_indent(cls.object_type(data), depth)} " if include_type else "" + ) + f"{open_tag}\n{ends.join(lines)}\n{cls.with_indent(close_tag, depth)}" + + @classmethod + def __serialize_iterable( + cls, + *, + data: "SerializableData", + entries: Iterable[Tuple["PropertyName", "SerializableData"]], + open_tag: str, + close_tag: str, + depth: int = 0, + path: "PropertyPath" = (), + separator: Optional[str] = None, + serialize_key: bool = False, + **kwargs: Any, + ) -> str: + kwargs["depth"] = depth + 1 + + def key_str(key: "PropertyName") -> str: + if separator is None: + return "" + return ( + cls.serialize(data=key, **kwargs) + if serialize_key + else cls.with_indent(str(key), depth=depth + 1) + ) + separator + + def value_str(key: "PropertyName", value: "SerializableData") -> str: + _path = (*path, (key, type(value))) + serialized = cls.serialize(data=value, path=_path, **kwargs) + return serialized if separator is None else serialized.lstrip(cls._indent) + + return cls.__serialize_lines( + data=data, + lines=(f"{key_str(key)}{value_str(key, value)}," for key, value in entries), + depth=depth, + open_tag=open_tag, + close_tag=close_tag, + ) diff --git a/src/syrupy/extensions/base.py b/src/syrupy/extensions/base.py index 59225a5e..79528423 100644 --- a/src/syrupy/extensions/base.py +++ b/src/syrupy/extensions/base.py @@ -46,12 +46,14 @@ if TYPE_CHECKING: from syrupy.location import TestLocation - from syrupy.types import SerializableData, SerializedData + from syrupy.types import PropertyMatcher, SerializableData, SerializedData class SnapshotSerializer(ABC): @abstractmethod - def serialize(self, data: "SerializableData") -> "SerializedData": + def serialize( + self, data: "SerializableData", *, matcher: Optional["PropertyMatcher"] = None, + ) -> "SerializedData": """ Serializes a python object / data structure into a string to be used for comparison with snapshot data from disk. @@ -337,7 +339,7 @@ def __limit_context(self, lines: List[str]) -> Iterator[str]: if num_lines: if num_lines > self._context_line_max: count_leading_whitespace: Callable[[str], int] = ( - lambda s: len(s) - len(s.lstrip()) # noqa: E731 + lambda s: len(s) - len(s.lstrip()) ) if self._context_line_count: num_space = ( diff --git a/src/syrupy/extensions/image.py b/src/syrupy/extensions/image.py index 72002fb3..f8ff381f 100644 --- a/src/syrupy/extensions/image.py +++ b/src/syrupy/extensions/image.py @@ -1,10 +1,13 @@ -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Optional, +) from .single_file import SingleFileSnapshotExtension if TYPE_CHECKING: - from syrupy.types import SerializableData + from syrupy.types import PropertyMatcher, SerializableData class PNGImageSnapshotExtension(SingleFileSnapshotExtension): @@ -18,5 +21,7 @@ class SVGImageSnapshotExtension(SingleFileSnapshotExtension): def _file_extension(self) -> str: return "svg" - def serialize(self, data: "SerializableData") -> bytes: + def serialize( + self, data: "SerializableData", *, matcher: Optional["PropertyMatcher"] = None, + ) -> bytes: return str(data).encode("utf-8") diff --git a/src/syrupy/extensions/single_file.py b/src/syrupy/extensions/single_file.py index 11decd76..fd6a7028 100644 --- a/src/syrupy/extensions/single_file.py +++ b/src/syrupy/extensions/single_file.py @@ -16,11 +16,13 @@ if TYPE_CHECKING: - from syrupy.types import SerializableData, SerializedData # noqa: F401 + from syrupy.types import PropertyMatcher, SerializableData class SingleFileSnapshotExtension(AbstractSyrupyExtension): - def serialize(self, data: "SerializableData") -> bytes: + def serialize( + self, data: "SerializableData", *, matcher: Optional["PropertyMatcher"] = None, + ) -> bytes: return bytes(data) def get_snapshot_name(self, *, index: int = 0) -> str: diff --git a/src/syrupy/matchers.py b/src/syrupy/matchers.py new file mode 100644 index 00000000..62741d69 --- /dev/null +++ b/src/syrupy/matchers.py @@ -0,0 +1,62 @@ +from gettext import gettext +from typing import ( + TYPE_CHECKING, + Dict, + Optional, + Tuple, +) + +from syrupy.extensions.amber.serializer import ( + DataSerializer, + Repr, +) + + +if TYPE_CHECKING: + from syrupy.types import ( + PropertyMatcher, + PropertyPath, + PropertyValueType, + SerializableData, + ) + + +class PathTypeError(TypeError): + pass + + +def path_type( + mapping: Optional[Dict[str, Tuple["PropertyValueType", ...]]] = None, + *, + types: Tuple["PropertyValueType", ...] = (), + strict: bool = True, +) -> "PropertyMatcher": + """ + Factory to create a matcher using path and type mapping + """ + if not mapping and not types: + raise PathTypeError(gettext("Both mapping and types argument cannot be empty")) + + def path_type_matcher( + data: "SerializableData", path: "PropertyPath" + ) -> Optional["SerializableData"]: + path_str = ".".join(str(p) for p, _ in path) + if mapping: + for path_to_match in mapping: + if path_to_match == path_str: + for type_to_match in mapping[path_to_match]: + if isinstance(data, type_to_match): + return Repr(DataSerializer.object_type(data)) + if strict: + raise PathTypeError( + gettext( + "{} at '{}' of type {} does not " + "match any of the expected types: {}" + ).format(data, path_str, data.__class__, types) + ) + for type_to_match in types: + if isinstance(data, type_to_match): + return Repr(DataSerializer.object_type(data)) + return data + + return path_type_matcher diff --git a/src/syrupy/report.py b/src/syrupy/report.py index 2e750578..997da140 100644 --- a/src/syrupy/report.py +++ b/src/syrupy/report.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: - from .assertion import SnapshotAssertion # noqa: F401 + from .assertion import SnapshotAssertion @attr.s diff --git a/src/syrupy/session.py b/src/syrupy/session.py index 6c33b919..45f4b287 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from .assertion import SnapshotAssertion - from .extensions.base import AbstractSyrupyExtension # noqa: F401 + from .extensions.base import AbstractSyrupyExtension @attr.s diff --git a/src/syrupy/types.py b/src/syrupy/types.py index 1803cf43..ae84a929 100644 --- a/src/syrupy/types.py +++ b/src/syrupy/types.py @@ -1,8 +1,18 @@ from typing import ( Any, + Callable, + Hashable, + Optional, + Tuple, + Type, Union, ) SerializableData = Any SerializedData = Union[str, bytes] +PropertyName = Hashable +PropertyValueType = Type[SerializableData] +PropertyPathEntry = Tuple[PropertyName, PropertyValueType] +PropertyPath = Tuple[PropertyPathEntry, ...] +PropertyMatcher = Callable[..., Optional[SerializableData]] diff --git a/src/syrupy/utils.py b/src/syrupy/utils.py index 8010f37c..1d2bae23 100644 --- a/src/syrupy/utils.py +++ b/src/syrupy/utils.py @@ -1,3 +1,4 @@ +from gettext import gettext from importlib import import_module from pathlib import Path from typing import ( @@ -28,13 +29,19 @@ def import_module_member(path: str) -> Any: if not module_name: raise FailedToLoadModuleMember( - f"Cannot load member '{module_member_name}' without module path" + gettext("Cannot load member '{}' without module path").format( + module_member_name, + ) ) try: return getattr(import_module(module_name), module_member_name) except ModuleNotFoundError: - raise FailedToLoadModuleMember(f"Module '{module_name}' does not exist.") + raise FailedToLoadModuleMember( + gettext("Module '{}' does not exist.").format(module_name) + ) except AttributeError: raise FailedToLoadModuleMember( - f"Member '{module_member_name}' not found in module '{module_name}'." + gettext("Member '{}' not found in module '{}'.").format( + module_member_name, module_name, + ) ) diff --git a/tests/__snapshots__/test_extension_amber.ambr b/tests/__snapshots__/test_extension_amber.ambr index c57e173b..014a7673 100644 --- a/tests/__snapshots__/test_extension_amber.ambr +++ b/tests/__snapshots__/test_extension_amber.ambr @@ -213,6 +213,32 @@ ' --- +# name: test_non_deterministic_snapshots + { + 'a': UUID(...), + 'b': { + 'b_1': 'This is deterministic', + 'b_2': datetime.datetime(...), + }, + 'c': [ + 'Your wish is my command', + 'Do not replace this one', + ], + } +--- +# name: test_non_deterministic_snapshots.1 + { + 'a': UUID('06335e84-2872-4914-8c5d-3ed07d2a2f16'), + 'b': { + 'b_1': 'This is deterministic', + 'b_2': datetime.datetime(2020, 5, 31, 0, 0), + }, + 'c': [ + 'Replace this one', + 'Do not replace this one', + ], + } +--- # name: test_numbers 3.5 --- diff --git a/tests/__snapshots__/test_matchers.ambr b/tests/__snapshots__/test_matchers.ambr new file mode 100644 index 00000000..0ef5a4f5 --- /dev/null +++ b/tests/__snapshots__/test_matchers.ambr @@ -0,0 +1,19 @@ +# name: test_matches_expected_type + { + 'date_created': , + 'nested': { + 'id': , + }, + 'some_uuid': , + } +--- +# name: test_raises_unexpected_type + { + 'date_created': , + 'date_updated': datetime.date(2020, 6, 1), + 'nested': { + 'id': , + }, + 'some_uuid': , + } +--- diff --git a/tests/test_extension_amber.py b/tests/test_extension_amber.py index b38a2aea..c772de3b 100644 --- a/tests/test_extension_amber.py +++ b/tests/test_extension_amber.py @@ -1,7 +1,11 @@ +import uuid from collections import namedtuple +from datetime import datetime import pytest +from syrupy.extensions.amber.serializer import Repr + def test_non_snapshots(snapshot): with pytest.raises(AssertionError): @@ -181,3 +185,28 @@ def test_parameter_with_dot(parameter_with_dot, snapshot): def test_doubly_parametrized(parameter_1, parameter_2, snapshot): assert parameter_1 == snapshot assert parameter_2 == snapshot + + +def test_non_deterministic_snapshots(snapshot): + def matcher(data, path): + if isinstance(data, uuid.UUID): + return Repr("UUID(...)") + if isinstance(data, datetime): + return Repr("datetime.datetime(...)") + if tuple(p for p, _ in path[-2:]) == ("c", 0): + return "Your wish is my command" + return data + + assert { + "a": uuid.uuid4(), + "b": {"b_1": "This is deterministic", "b_2": datetime.now()}, + "c": ["Replace this one", "Do not replace this one"], + } == snapshot(matcher=matcher) + assert { + "a": uuid.UUID("06335e84-2872-4914-8c5d-3ed07d2a2f16"), + "b": { + "b_1": "This is deterministic", + "b_2": datetime(year=2020, month=5, day=31), + }, + "c": ["Replace this one", "Do not replace this one"], + } == snapshot diff --git a/tests/test_extension_image.py b/tests/test_extension_image.py index f538103c..38c496cd 100644 --- a/tests/test_extension_image.py +++ b/tests/test_extension_image.py @@ -50,7 +50,7 @@ def test_multiple_snapshot_extensions(snapshot): These should be indexed in order of assertion. """ assert actual_svg == snapshot(extension_class=SVGImageSnapshotExtension) - assert actual_svg == snapshot # uses initial extension class + assert actual_svg == snapshot() # uses initial extension class assert snapshot._extension is not None assert actual_png == snapshot(extension_class=PNGImageSnapshotExtension) assert actual_svg == snapshot(extension_class=SVGImageSnapshotExtension) diff --git a/tests/test_integration_custom.py b/tests/test_integration_custom.py index 4dd310e1..243180f6 100644 --- a/tests/test_integration_custom.py +++ b/tests/test_integration_custom.py @@ -15,7 +15,7 @@ class CustomSnapshotExtension(AbstractSyrupyExtension): def _file_extension(self): return "" - def serialize(self, data): + def serialize(self, data, **kwargs): return str(data) def get_snapshot_name(self, *, index = 0): diff --git a/tests/test_matchers.py b/tests/test_matchers.py new file mode 100644 index 00000000..8ada2e6a --- /dev/null +++ b/tests/test_matchers.py @@ -0,0 +1,46 @@ +import datetime +import uuid + +import pytest + +from syrupy.matchers import ( + PathTypeError, + path_type, +) + + +def test_matcher_path_type_noop(snapshot): + with pytest.raises(PathTypeError, match="argument cannot be empty"): + path_type() + + +def test_matches_expected_type(snapshot): + my_matcher = path_type( + {"date_created": (datetime.datetime,), "nested.id": (int,)}, types=(uuid.UUID,) + ) + actual = { + "date_created": datetime.datetime.now(), + "nested": {"id": 4}, + "some_uuid": uuid.uuid4(), + } + assert actual == snapshot(matcher=my_matcher) + + +def test_raises_unexpected_type(snapshot): + kwargs = { + "mapping": { + "date_created": (datetime.datetime,), + "date_updated": (datetime.datetime,), + "nested.id": (str,), + }, + "types": (uuid.UUID, int), + } + actual = { + "date_created": datetime.datetime.now(), + "date_updated": datetime.date(2020, 6, 1), + "nested": {"id": 4}, + "some_uuid": uuid.uuid4(), + } + assert actual == snapshot(matcher=path_type(**kwargs, strict=False)) + with pytest.raises(PathTypeError, match="does not match any of the expected"): + assert actual == snapshot(matcher=path_type(**kwargs))