diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index 4ef914f9..d05cb38c 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -43,6 +43,7 @@ class SnapshotAssertion: _extension_class: Type["AbstractSyrupyExtension"] = attr.ib(kw_only=True) _test_location: "TestLocation" = attr.ib(kw_only=True) _update_snapshots: bool = attr.ib(kw_only=True) + _extension: Optional["AbstractSyrupyExtension"] = attr.ib(init=False, default=None) _executions: int = attr.ib(init=False, default=0, kw_only=True) _execution_results: Dict[int, "AssertionResult"] = attr.ib( init=False, factory=dict, kw_only=True @@ -51,12 +52,15 @@ class SnapshotAssertion: def __attrs_post_init__(self) -> None: self._session.register_request(self) + def __init_extension( + self, extension_class: Type["AbstractSyrupyExtension"] + ) -> "AbstractSyrupyExtension": + return extension_class(test_location=self._test_location) + @property def extension(self) -> "AbstractSyrupyExtension": - if not getattr(self, "_extension", None): - self._extension: "AbstractSyrupyExtension" = self._extension_class( - test_location=self._test_location - ) + if not self._extension: + self._extension = self.__init_extension(self._extension_class) return self._extension @property @@ -70,6 +74,10 @@ def executions(self) -> Dict[int, AssertionResult]: def use_extension( self, extension_class: Optional[Type["AbstractSyrupyExtension"]] = None, ) -> "SnapshotAssertion": + """ + Creates a new snapshot assertion fixture with the same options but using + specified extension class. This does not preserve assertion index or state. + """ return self.__class__( update_snapshots=self._update_snapshots, test_location=self._test_location, @@ -92,6 +100,16 @@ def get_assert_diff(self, data: "SerializableData") -> List[str]: diff.extend(self.extension.diff_lines(serialized_data, snapshot_data)) return diff + def __call__( + self, *, extension_class: Optional[Type["AbstractSyrupyExtension"]] + ) -> "SnapshotAssertion": + """ + Modifies assertion instance options + """ + if extension_class: + self._extension = self.__init_extension(extension_class) + return self + def __repr__(self) -> str: attrs_to_repr = ["name", "num_executions"] attrs_repr = ", ".join(f"{a}={repr(getattr(self, a))}" for a in attrs_to_repr) @@ -131,6 +149,13 @@ def _assert(self, data: "SerializableData") -> bool: updated=snapshot_updated, ) self._executions += 1 + self._post_assert() + + def _post_assert(self) -> None: + """ + Restores assertion instance options + """ + self._extension = None def _recall_data(self, index: int) -> Optional["SerializableData"]: try: diff --git a/tests/__snapshots__/test_extension_image.ambr b/tests/__snapshots__/test_extension_image.ambr new file mode 100644 index 00000000..6a118b13 --- /dev/null +++ b/tests/__snapshots__/test_extension_image.ambr @@ -0,0 +1,3 @@ +# name: test_multiple_snapshot_extensions.1 + '50 x 50' +--- diff --git a/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.2.png b/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.2.png new file mode 100644 index 00000000..7eb2b9ad Binary files /dev/null and b/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.2.png differ diff --git a/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.3.svg b/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.3.svg new file mode 100644 index 00000000..90ebb8dd --- /dev/null +++ b/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.3.svg @@ -0,0 +1 @@ +50 x 50 \ No newline at end of file diff --git a/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.svg b/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.svg new file mode 100644 index 00000000..90ebb8dd --- /dev/null +++ b/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.svg @@ -0,0 +1 @@ +50 x 50 \ No newline at end of file diff --git a/tests/test_extension_image.py b/tests/test_extension_image.py index 329d0742..af67e103 100644 --- a/tests/test_extension_image.py +++ b/tests/test_extension_image.py @@ -8,40 +8,48 @@ ) +actual_png = base64.b64decode( + b"iVBORw0KGgoAAAANSUhEUgAAADIAAAAyBAMAAADsEZWCAAAAG1BMVEXMzMy" + b"Wlpaqqqq3t7exsbGcnJy+vr6jo6PFxcUFpPI/AAAACXBIWXMAAA7EAAAOxA" + b"GVKw4bAAAAQUlEQVQ4jWNgGAWjgP6ASdncAEaiAhaGiACmFhCJLsMaIiDAE" + b"QEi0WXYEiMCOCJAJIY9KuYGTC0gknpuHwXDGwAA5fsIZw0iYWYAAAAASUVO" + b"RK5CYII=" +) +actual_svg = ( + '' + '' + '' + '' + '' + '' + '' + '50 x 50' +) + + @pytest.fixture def snapshot_png(snapshot): return snapshot.use_extension(PNGImageSnapshotExtension) -def test_image(snapshot_png, snapshot_svg): - actual_png = base64.b64decode( - b"iVBORw0KGgoAAAANSUhEUgAAADIAAAAyBAMAAADsEZWCAAAAG1BMVEXMzMy" - b"Wlpaqqqq3t7exsbGcnJy+vr6jo6PFxcUFpPI/AAAACXBIWXMAAA7EAAAOxA" - b"GVKw4bAAAAQUlEQVQ4jWNgGAWjgP6ASdncAEaiAhaGiACmFhCJLsMaIiDAE" - b"QEi0WXYEiMCOCJAJIY9KuYGTC0gknpuHwXDGwAA5fsIZw0iYWYAAAAASUVO" - b"RK5CYII=" - ) +def test_image(snapshot_png): assert actual_png == snapshot_png -@pytest.fixture -def snapshot_svg(snapshot): - return snapshot.use_extension(SVGImageSnapshotExtension) +def test_image_vector(snapshot): + """ + Example of creating a previewable svg snapshot + """ + assert snapshot(extension_class=SVGImageSnapshotExtension) == actual_svg -def test_image_vector(snapshot_svg): +def test_multiple_snapshot_extensions(snapshot): """ - Example of creating a previewable svg snapshot + Example of switching extension classes on the fly. + These should be indexed in order of assertion. """ - actual_svg = ( - '' - '' - '' - '' - '' - '' - '' - '50 x 50' - ) - assert snapshot_svg == actual_svg + assert actual_svg == snapshot(extension_class=SVGImageSnapshotExtension) + assert actual_svg == snapshot # uses initial extension class + assert actual_png == snapshot(extension_class=PNGImageSnapshotExtension) + assert actual_svg == snapshot(extension_class=SVGImageSnapshotExtension)