Skip to content

Commit

Permalink
feat: add snapshot diffing support (#526)
Browse files Browse the repository at this point in the history
* feat: add snapshot diffing support

* docs: add readme section for diff assertion option

* refactor: replace with SnapshotIndex type alias
  • Loading branch information
iamogbz committed May 11, 2022
1 parent 3f32f4b commit e424f31
Show file tree
Hide file tree
Showing 10 changed files with 265 additions and 32 deletions.
20 changes: 18 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,23 @@ def test_foo(snapshot):
assert actual_svg == snapshot(extension_class=SVGImageSnapshotExtension)
```

#### `diff`

This is an option to snapshot only the diff between the actual object and a previous snapshot, with the `diff` argument being the previous snapshot `index`/`name`.

```py
def test_diff(snapshot):
actual0 = [1,2,3,4]
actual1 = [0,1,3,4]

assert actual0 == snapshot
assert actual1 == snapshot(diff=0)
# This is equivalent to the lines above
# Must use the index name to diff when given
assert actual0 == snapshot(name='snap_name')
assert actual1 == snapshot(diff='snap_name')
```

##### Built-In Extensions

Syrupy comes with a few built-in preset configurations for you to choose from. You should also feel free to extend the `AbstractSyrupyExtension` if your project has a need not captured by one our built-ins.
Expand Down Expand Up @@ -295,7 +312,7 @@ from syrupy.extensions.json import JSONSnapshotExtension

@pytest.fixture
def snapshot_json(snapshot):
return snapshot.use_extension(JSONSnapshotExtension)
return snapshot.use_extension(JSONSnapshotExtension)


def test_api_call(client, snapshot_json):
Expand Down Expand Up @@ -400,5 +417,4 @@ This section is automatically generated via tagging the all-contributors bot in

## License


Syrupy is licensed under [Apache License Version 2.0](https://github.com/tophat/syrupy/tree/master/LICENSE).
24 changes: 18 additions & 6 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
List,
Optional,
Type,
Union,
)

from .exceptions import SnapshotDoesNotExist
Expand All @@ -26,6 +25,7 @@
PropertyMatcher,
SerializableData,
SerializedData,
SnapshotIndex,
)


Expand Down Expand Up @@ -108,7 +108,7 @@ def executions(self) -> Dict[int, AssertionResult]:
return self._execution_results

@property
def index(self) -> Union[str, int]:
def index(self) -> "SnapshotIndex":
if self._custom_index:
return self._custom_index
return self.num_executions
Expand Down Expand Up @@ -169,10 +169,11 @@ def __with_prop(self, prop_name: str, prop_value: Any) -> None:
def __call__(
self,
*,
diff: Optional["SnapshotIndex"] = None,
exclude: Optional["PropertyFilter"] = None,
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
matcher: Optional["PropertyMatcher"] = None,
name: Optional[str] = None,
name: Optional["SnapshotIndex"] = None,
) -> "SnapshotAssertion":
"""
Modifies assertion instance options
Expand All @@ -185,6 +186,8 @@ def __call__(
self.__with_prop("_matcher", matcher)
if name:
self.__with_prop("_custom_index", name)
if diff is not None:
self.__with_prop("_snapshot_diff", diff)
return self

def __dir__(self) -> List[str]:
Expand All @@ -202,8 +205,17 @@ def _assert(self, data: "SerializableData") -> bool:
assertion_success = False
assertion_exception = None
try:
snapshot_data = self._recall_data()
snapshot_data = self._recall_data(index=self.index)
serialized_data = self._serialize(data)
snapshot_diff = getattr(self, "_snapshot_diff", None)
if snapshot_diff is not None:
snapshot_data_diff = self._recall_data(index=snapshot_diff)
if snapshot_data_diff is None:
raise SnapshotDoesNotExist()
serialized_data = self.extension.diff_snapshots(
serialized_data=serialized_data,
snapshot_data=snapshot_data_diff,
)
matches = snapshot_data is not None and self.extension.matches(
serialized_data=serialized_data, snapshot_data=snapshot_data
)
Expand Down Expand Up @@ -241,8 +253,8 @@ def _post_assert(self) -> None:
while self._post_assert_actions:
self._post_assert_actions.pop()()

def _recall_data(self) -> Optional["SerializableData"]:
def _recall_data(self, index: "SnapshotIndex") -> Optional["SerializableData"]:
try:
return self.extension.read_snapshot(index=self.index)
return self.extension.read_snapshot(index=index)
except SnapshotDoesNotExist:
return None
7 changes: 5 additions & 2 deletions src/syrupy/extensions/amber/serializer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import functools
import os
from types import GeneratorType
from types import (
GeneratorType,
MappingProxyType,
)
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -163,7 +166,7 @@ def _serialize(
serialize_method = cls.serialize_number
elif isinstance(data, (set, frozenset)):
serialize_method = cls.serialize_set
elif isinstance(data, dict):
elif isinstance(data, (dict, MappingProxyType)):
serialize_method = cls.serialize_dict
elif cls.__is_namedtuple(data):
serialize_method = cls.serialize_namedtuple
Expand Down
43 changes: 27 additions & 16 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
List,
Optional,
Set,
Union,
)

from syrupy.constants import (
DISABLE_COLOR_ENV_VAR,
SNAPSHOT_DIRNAME,
SYMBOL_CARRIAGE,
SYMBOL_ELLIPSIS,
Expand All @@ -40,7 +40,11 @@
snapshot_diff_style,
snapshot_style,
)
from syrupy.utils import walk_snapshot_dir
from syrupy.utils import (
env_context,
obj_attrs,
walk_snapshot_dir,
)

if TYPE_CHECKING:
from syrupy.location import PyTestLocation
Expand All @@ -49,6 +53,7 @@
PropertyMatcher,
SerializableData,
SerializedData,
SnapshotIndex,
)


Expand All @@ -74,7 +79,7 @@ class SnapshotFossilizer(ABC):
def test_location(self) -> "PyTestLocation":
raise NotImplementedError

def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str:
def get_snapshot_name(self, *, index: "SnapshotIndex" = 0) -> str:
"""Get the snapshot name for the assertion index in a test location"""
index_suffix = ""
if isinstance(index, (str,)):
Expand All @@ -83,7 +88,7 @@ def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str:
index_suffix = f".{index}"
return f"{self.test_location.snapshot_name}{index_suffix}"

def get_location(self, *, index: Union[str, int]) -> str:
def get_location(self, *, index: "SnapshotIndex") -> str:
"""Returns full location where snapshot data is stored."""
basename = self._get_file_basename(index=index)
fileext = f".{self._file_extension}" if self._file_extension else ""
Expand All @@ -110,7 +115,7 @@ def discover_snapshots(self) -> "SnapshotFossils":

return discovered

def read_snapshot(self, *, index: Union[str, int]) -> "SerializedData":
def read_snapshot(self, *, index: "SnapshotIndex") -> "SerializedData":
"""
Utility method for reading the contents of a snapshot assertion.
Will call `_pre_read`, then perform `read` and finally `post_read`,
Expand All @@ -132,7 +137,7 @@ def read_snapshot(self, *, index: Union[str, int]) -> "SerializedData":
finally:
self._post_read(index=index)

def write_snapshot(self, *, data: "SerializedData", index: Union[str, int]) -> None:
def write_snapshot(self, *, data: "SerializedData", index: "SnapshotIndex") -> None:
"""
Utility method for writing the contents of a snapshot assertion.
Will call `_pre_write`, then perform `write` and finally `_post_write`.
Expand Down Expand Up @@ -178,17 +183,17 @@ def delete_snapshots(
"""
raise NotImplementedError

def _pre_read(self, *, index: Union[str, int] = 0) -> None:
def _pre_read(self, *, index: "SnapshotIndex" = 0) -> None:
pass

def _post_read(self, *, index: Union[str, int] = 0) -> None:
def _post_read(self, *, index: "SnapshotIndex" = 0) -> None:
pass

def _pre_write(self, *, data: "SerializedData", index: Union[str, int] = 0) -> None:
def _pre_write(self, *, data: "SerializedData", index: "SnapshotIndex" = 0) -> None:
self.__ensure_snapshot_dir(index=index)

def _post_write(
self, *, data: "SerializedData", index: Union[str, int] = 0
self, *, data: "SerializedData", index: "SnapshotIndex" = 0
) -> None:
pass

Expand Down Expand Up @@ -225,11 +230,11 @@ def _dirname(self) -> str:
def _file_extension(self) -> str:
raise NotImplementedError

def _get_file_basename(self, *, index: Union[str, int]) -> str:
def _get_file_basename(self, *, index: "SnapshotIndex") -> str:
"""Returns file basename without extension. Used to create full filepath."""
return self.test_location.filename

def __ensure_snapshot_dir(self, *, index: Union[str, int]) -> None:
def __ensure_snapshot_dir(self, *, index: "SnapshotIndex") -> None:
"""
Ensures the folder path for the snapshot file exists.
"""
Expand All @@ -240,6 +245,16 @@ def __ensure_snapshot_dir(self, *, index: Union[str, int]) -> None:


class SnapshotReporter(ABC):
_context_line_count = 1

def diff_snapshots(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
) -> "SerializedData":
env = {DISABLE_COLOR_ENV_VAR: "true"}
attrs = {"_context_line_count": 0}
with env_context(**env), obj_attrs(self, attrs):
return "\n".join(self.diff_lines(serialized_data, snapshot_data))

def diff_lines(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
) -> Iterator[str]:
Expand All @@ -250,10 +265,6 @@ def diff_lines(
def _ends(self) -> Dict[str, str]:
return {"\n": self._marker_new_line, "\r": self._marker_carriage}

@property
def _context_line_count(self) -> int:
return 1

@property
def _context_line_max(self) -> int:
return self._context_line_count * 2
Expand Down
6 changes: 3 additions & 3 deletions src/syrupy/extensions/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
TYPE_CHECKING,
Optional,
Set,
Union,
)
from unicodedata import category

Expand All @@ -21,6 +20,7 @@
PropertyMatcher,
SerializableData,
SerializedData,
SnapshotIndex,
)


Expand All @@ -34,7 +34,7 @@ def serialize(
) -> "SerializedData":
return bytes(data)

def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str:
def get_snapshot_name(self, *, index: "SnapshotIndex" = 0) -> str:
return self.__clean_filename(
super(SingleFileSnapshotExtension, self).get_snapshot_name(index=index)
)
Expand All @@ -48,7 +48,7 @@ def delete_snapshots(
def _file_extension(self) -> str:
return "raw"

def _get_file_basename(self, *, index: Union[str, int]) -> str:
def _get_file_basename(self, *, index: "SnapshotIndex") -> str:
return self.get_snapshot_name(index=index)

@property
Expand Down
1 change: 1 addition & 0 deletions src/syrupy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Union,
)

SnapshotIndex = Union[int, str]
SerializableData = Any
SerializedData = Union[str, bytes]
PropertyName = Hashable
Expand Down
15 changes: 15 additions & 0 deletions src/syrupy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from typing import (
Any,
Dict,
Iterator,
)

Expand Down Expand Up @@ -66,3 +67,17 @@ def env_context(**kwargs: str) -> Iterator[None]:
finally:
os.environ.clear()
os.environ.update(prev_env)


def set_attrs(obj: Any, attrs: Dict[str, Any]) -> Any:
for k in attrs:
setattr(obj, k, attrs[k])


@contextmanager
def obj_attrs(obj: Any, attrs: Dict[str, Any]) -> Iterator[None]:
prev_attrs = {k: getattr(obj, k, None) for k in attrs}
try:
yield set_attrs(obj, attrs)
finally:
set_attrs(obj, prev_attrs)
5 changes: 2 additions & 3 deletions tests/examples/test_custom_snapshot_name.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
"""
Example: Custom Snapshot Name
"""
from typing import Union

import pytest

from syrupy.extensions.amber import AmberSnapshotExtension
from syrupy.types import SnapshotIndex


class CanadianNameExtension(AmberSnapshotExtension):
def get_snapshot_name(self, *, index: Union[str, int]) -> str:
def get_snapshot_name(self, *, index: "SnapshotIndex") -> str:
original_name = super(CanadianNameExtension, self).get_snapshot_name(
index=index
)
Expand Down
Loading

0 comments on commit e424f31

Please sign in to comment.