Skip to content

Commit

Permalink
feat: add support for custom snapshot names, close #555 (#563)
Browse files Browse the repository at this point in the history
The snapshot_name_suffix will hold an optional suffix to be used instead
of index in the snapshot_name. The suffix will by default be formatted to
be between brackets: "[<snapshot_name_suffix>]"

Co-authored-by: Ouail Bendidi <ouail.bendidi@gmail.com>
  • Loading branch information
Noah and obendidi committed Nov 3, 2021
1 parent 21cfed8 commit 81a8a45
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 22 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,18 @@ Syrupy comes with a few built-in preset configurations for you to choose from. Y
- **`PNGSnapshotExtension`**: An extension of single file, this should be used to produce `.png` files from a byte string.
- **`SVGSnapshotExtension`**: Another extension of single file. This produces `.svg` files from an svg string.

#### `name`

By default, if you make multiple snapshot assertions within a single test case, an auto-increment identifier will be used to index the snapshots. You can override this behaviour by specifying a custom snapshot name to use in place of the auto-increment number.

```py
def test_case(snapshot):
assert "actual" == snapshot(name="case_a")
assert "other" == snapshot(name="case_b")
```

> _Warning_: If you use a custom name, you must make sure the name is not re-used within a test case.
### Advanced Usage

By overriding the provided [`AbstractSnapshotExtension`](https://github.com/tophat/syrupy/tree/master/src/syrupy/extensions/base.py) you can implement varied custom behaviours.
Expand Down
24 changes: 18 additions & 6 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
List,
Optional,
Type,
Union,
)

import attr
Expand Down Expand Up @@ -54,6 +55,7 @@ class SnapshotAssertion:
_exclude: Optional["PropertyFilter"] = attr.ib(
init=False, default=None, kw_only=True
)
_custom_index: Optional[str] = attr.ib(init=False, default=None, kw_only=True)
_extension: Optional["AbstractSyrupyExtension"] = attr.ib(
init=False, default=None, kw_only=True
)
Expand Down Expand Up @@ -90,6 +92,12 @@ def num_executions(self) -> int:
def executions(self) -> Dict[int, AssertionResult]:
return self._execution_results

@property
def index(self) -> Union[str, int]:
if self._custom_index:
return self._custom_index
return self.num_executions

def use_extension(
self, extension_class: Optional[Type["AbstractSyrupyExtension"]] = None
) -> "SnapshotAssertion":
Expand Down Expand Up @@ -149,6 +157,7 @@ def __call__(
exclude: Optional["PropertyFilter"] = None,
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
matcher: Optional["PropertyMatcher"] = None,
name: Optional[str] = None,
) -> "SnapshotAssertion":
"""
Modifies assertion instance options
Expand All @@ -159,6 +168,8 @@ def __call__(
self.__with_prop("_extension", self.__init_extension(extension_class))
if matcher:
self.__with_prop("_matcher", matcher)
if name:
self.__with_prop("_custom_index", name)
return self

def __dir__(self) -> List[str]:
Expand All @@ -168,21 +179,22 @@ def __eq__(self, other: "SerializableData") -> bool:
return self._assert(other)

def _assert(self, data: "SerializableData") -> bool:
snapshot_location = self.extension.get_location(index=self.num_executions)
snapshot_name = self.extension.get_snapshot_name(index=self.num_executions)
snapshot_location = self.extension.get_location(index=self.index)
snapshot_name = self.extension.get_snapshot_name(index=self.index)
snapshot_data: Optional["SerializedData"] = None
serialized_data: Optional["SerializedData"] = None
matches = False
assertion_success = False
assertion_exception = None
try:
snapshot_data = self._recall_data(index=self.num_executions)
snapshot_data = self._recall_data()
serialized_data = self._serialize(data)
matches = snapshot_data is not None and serialized_data == snapshot_data
assertion_success = matches
if not matches and self._update_snapshots:
self.extension.write_snapshot(
data=serialized_data, index=self.num_executions
data=serialized_data,
index=self.index,
)
assertion_success = True
return assertion_success
Expand Down Expand Up @@ -212,8 +224,8 @@ def _post_assert(self) -> None:
while self._post_assert_actions:
self._post_assert_actions.pop()()

def _recall_data(self, index: int) -> Optional["SerializableData"]:
def _recall_data(self) -> Optional["SerializableData"]:
try:
return self.extension.read_snapshot(index=index)
return self.extension.read_snapshot(index=self.index)
except SnapshotDoesNotExist:
return None
29 changes: 18 additions & 11 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
List,
Optional,
Set,
Union,
)

from syrupy.constants import (
Expand Down Expand Up @@ -73,12 +74,16 @@ class SnapshotFossilizer(ABC):
def test_location(self) -> "PyTestLocation":
raise NotImplementedError

def get_snapshot_name(self, *, index: int = 0) -> str:
def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str:
"""Get the snapshot name for the assertion index in a test location"""
index_suffix = f".{index}" if index > 0 else ""
index_suffix = ""
if isinstance(index, (str,)):
index_suffix = f"[{index}]"
elif index:
index_suffix = f".{index}"
return f"{self.test_location.snapshot_name}{index_suffix}"

def get_location(self, *, index: int) -> str:
def get_location(self, *, index: Union[str, int]) -> str:
"""Returns full location where snapshot data is stored."""
basename = self._get_file_basename(index=index)
fileext = f".{self._file_extension}" if self._file_extension else ""
Expand All @@ -105,7 +110,7 @@ def discover_snapshots(self) -> "SnapshotFossils":

return discovered

def read_snapshot(self, *, index: int) -> "SerializedData":
def read_snapshot(self, *, index: Union[str, int]) -> "SerializedData":
"""
Utility method for reading the contents of a snapshot assertion.
Will call `_pre_read`, then perform `read` and finally `post_read`,
Expand All @@ -127,7 +132,7 @@ def read_snapshot(self, *, index: int) -> "SerializedData":
finally:
self._post_read(index=index)

def write_snapshot(self, *, data: "SerializedData", index: int) -> None:
def write_snapshot(self, *, data: "SerializedData", index: Union[str, int]) -> None:
"""
Utility method for writing the contents of a snapshot assertion.
Will call `_pre_write`, then perform `write` and finally `_post_write`.
Expand Down Expand Up @@ -173,16 +178,18 @@ def delete_snapshots(
"""
raise NotImplementedError

def _pre_read(self, *, index: int = 0) -> None:
def _pre_read(self, *, index: Union[str, int] = 0) -> None:
pass

def _post_read(self, *, index: int = 0) -> None:
def _post_read(self, *, index: Union[str, int] = 0) -> None:
pass

def _pre_write(self, *, data: "SerializedData", index: int = 0) -> None:
def _pre_write(self, *, data: "SerializedData", index: Union[str, int] = 0) -> None:
self.__ensure_snapshot_dir(index=index)

def _post_write(self, *, data: "SerializedData", index: int = 0) -> None:
def _post_write(
self, *, data: "SerializedData", index: Union[str, int] = 0
) -> None:
pass

@abstractmethod
Expand Down Expand Up @@ -218,11 +225,11 @@ def _dirname(self) -> str:
def _file_extension(self) -> str:
raise NotImplementedError

def _get_file_basename(self, *, index: int) -> str:
def _get_file_basename(self, *, index: Union[str, int]) -> str:
"""Returns file basename without extension. Used to create full filepath."""
return self.test_location.filename

def __ensure_snapshot_dir(self, *, index: int) -> None:
def __ensure_snapshot_dir(self, *, index: Union[str, int]) -> None:
"""
Ensures the folder path for the snapshot file exists.
"""
Expand Down
5 changes: 3 additions & 2 deletions src/syrupy/extensions/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
TYPE_CHECKING,
Optional,
Set,
Union,
)
from unicodedata import category

Expand Down Expand Up @@ -33,7 +34,7 @@ def serialize(
) -> "SerializedData":
return bytes(data)

def get_snapshot_name(self, *, index: int = 0) -> str:
def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str:
return self.__clean_filename(
super(SingleFileSnapshotExtension, self).get_snapshot_name(index=index)
)
Expand All @@ -47,7 +48,7 @@ def delete_snapshots(
def _file_extension(self) -> str:
return "raw"

def _get_file_basename(self, *, index: int) -> str:
def _get_file_basename(self, *, index: Union[str, int]) -> str:
return self.get_snapshot_name(index=index)

@property
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# name: test_snapshot_custom_snapshot_name_suffix[test_is_amazing]
'Syrupy is amazing!'
---
# name: test_snapshot_custom_snapshot_name_suffix[test_is_awesome]
'Syrupy is awesome!'
---
25 changes: 25 additions & 0 deletions tests/examples/test_custom_image_name_suffix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import base64

import pytest

from syrupy.extensions.image import PNGImageSnapshotExtension


@pytest.fixture
def snapshot(snapshot):
return snapshot.use_extension(PNGImageSnapshotExtension)


def test_png_image_with_custom_name_suffix(snapshot):
reddish_square = base64.b64decode(
b"iVBORw0KGgoAAAANSUhEUgAAAAQAAAAECAIAAAAmkwkpAAAAIUlEQVQIHTXB"
b"MQEAAAABQUYtvpD+dUzu3KBzg84NOjfoBjmmAd3WpSsrAAAAAElFTkSuQmCC"
)

blueish_square = base64.b64decode(
b"iVBORw0KGgoAAAANSUhEUgAAAAQAAAAECAIAAAAmkwkpAAAAIUlEQVQIHTXB"
b"MQEAAAABQUYtvpD+dUzuTKozqc6kOpPqBjg+Ad2g/BLMAAAAAElFTkSuQmCC"
)

assert blueish_square == snapshot(name="blueish")
assert reddish_square == snapshot(name="reddish")
4 changes: 3 additions & 1 deletion tests/examples/test_custom_snapshot_name.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""
Example: Custom Snapshot Name
"""
from typing import Union

import pytest

from syrupy.extensions.amber import AmberSnapshotExtension


class CanadianNameExtension(AmberSnapshotExtension):
def get_snapshot_name(self, *, index: int = 0) -> str:
def get_snapshot_name(self, *, index: Union[str, int]) -> str:
original_name = super(CanadianNameExtension, self).get_snapshot_name(
index=index
)
Expand Down
3 changes: 3 additions & 0 deletions tests/examples/test_custom_snapshot_name_suffix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def test_snapshot_custom_snapshot_name_suffix(snapshot):
assert "Syrupy is amazing!" == snapshot(name="test_is_amazing")
assert "Syrupy is awesome!" == snapshot(name="test_is_awesome")
63 changes: 63 additions & 0 deletions tests/integration/test_snapshot_option_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import pytest


@pytest.fixture
def testcases():
return {
"base": (
"""
def test_a(snapshot):
assert snapshot(name="xyz") == "case 1"
assert snapshot(name="zyx") == "case 2"
"""
),
"modified": (
"""
def test_a(snapshot):
assert snapshot(name="xyz") == "case 1"
assert snapshot(name="zyx") == "case ??"
"""
),
}


@pytest.fixture
def run_testcases(testdir, testcases):
testdir.makepyfile(test_1=testcases["base"])
result = testdir.runpytest(
"-v",
"--snapshot-update",
)
result.stdout.re_match_lines((r"2 snapshots generated\."))
return testdir, testcases


def test_run_all(run_testcases):
testdir, testcases = run_testcases
result = testdir.runpytest(
"-v",
)
result.stdout.re_match_lines("2 snapshots passed")
assert result.ret == 0


def test_failure(run_testcases):
testdir, testcases = run_testcases
testdir.makepyfile(test_1=testcases["modified"])
result = testdir.runpytest(
"-v",
)
result.stdout.re_match_lines("1 snapshot failed. 1 snapshot passed.")
assert result.ret == 1


def test_update(run_testcases):
testdir, testcases = run_testcases
testdir.makepyfile(test_1=testcases["modified"])
result = testdir.runpytest(
"-v",
"--snapshot-update",
)
assert "Can not relate snapshot name" not in str(result.stdout)
result.stdout.re_match_lines("1 snapshot passed. 1 snapshot updated.")
assert result.ret == 0
4 changes: 2 additions & 2 deletions tests/integration/test_snapshot_use_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ def _file_extension(self):
def serialize(self, data, **kwargs):
return str(data)
def get_snapshot_name(self, *, index = 0):
def get_snapshot_name(self, *, index):
testname = self._test_location.testname[::-1]
return f"{testname}.{index}"
def _get_file_basename(self, *, index = 0):
def _get_file_basename(self, *, index):
return self.test_location.filename[::-1]
@pytest.fixture
Expand Down

0 comments on commit 81a8a45

Please sign in to comment.