diff --git a/README.md b/README.md index 37a667e0..36923cc0 100644 --- a/README.md +++ b/README.md @@ -288,6 +288,7 @@ See examples of how syrupy can be used and extended in the [test examples](https - [Custom snapshot directory](https://github.com/tophat/syrupy/tree/master/tests/examples/test_custom_snapshot_directory.py) - [Custom snapshot name](https://github.com/tophat/syrupy/tree/master/tests/examples/test_custom_snapshot_name.py) - [Custom object snapshots](https://github.com/tophat/syrupy/tree/master/tests/examples/test_custom_object_repr.py) +- [Custom comparator](https://github.com/tophat/syrupy/tree/master/tests/integration/test_custom_comparator.py) - [JPEG image extension](https://github.com/tophat/syrupy/tree/master/tests/examples/test_custom_image_extension.py) - [Built-in image extensions](https://github.com/tophat/syrupy/blob/master/tests/syrupy/extensions/image/test_image_svg.py) diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index 5e627aa4..e5428e7a 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -189,7 +189,9 @@ def _assert(self, data: "SerializableData") -> bool: try: snapshot_data = self._recall_data() serialized_data = self._serialize(data) - matches = snapshot_data is not None and serialized_data == snapshot_data + matches = snapshot_data is not None and self.extension.matches( + serialized_data=serialized_data, snapshot_data=snapshot_data + ) assertion_success = matches if not matches and self._update_snapshots: self.extension.write_snapshot( diff --git a/src/syrupy/extensions/base.py b/src/syrupy/extensions/base.py index 77f8e230..7a8930e0 100644 --- a/src/syrupy/extensions/base.py +++ b/src/syrupy/extensions/base.py @@ -378,7 +378,23 @@ def __strip_ends(self, line: str) -> str: return line.rstrip("".join(self._ends.keys())) -class AbstractSyrupyExtension(SnapshotSerializer, SnapshotFossilizer, SnapshotReporter): +class SnapshotComparator(ABC): + def matches( + self, + *, + serialized_data: "SerializableData", + snapshot_data: "SerializableData", + ) -> bool: + """ + Compares serialized data and snapshot data and returns + whether they match. + """ + return bool(serialized_data == snapshot_data) + + +class AbstractSyrupyExtension( + SnapshotSerializer, SnapshotFossilizer, SnapshotReporter, SnapshotComparator +): def __init__(self, test_location: "PyTestLocation"): self._test_location = test_location diff --git a/tests/integration/test_custom_comparator.py b/tests/integration/test_custom_comparator.py new file mode 100644 index 00000000..eee35d2d --- /dev/null +++ b/tests/integration/test_custom_comparator.py @@ -0,0 +1,84 @@ +import pytest + + +@pytest.fixture +def testcases_initial(testdir): + testdir.makeconftest( + """ + import pytest + import math + + from syrupy.extensions.amber import AmberSnapshotExtension + + class CustomSnapshotExtension(AmberSnapshotExtension): + def matches(self, *, serialized_data, snapshot_data): + try: + a = float(serialized_data) + b = float(snapshot_data) + return math.isclose(a, b, rel_tol=1e-5) + except: + return False + + @pytest.fixture + def snapshot_custom(snapshot): + return snapshot.use_extension(CustomSnapshotExtension) + """ + ) + return { + "passed": ( + """ + def test_passed_custom(snapshot_custom): + assert snapshot_custom == 3.0 + """ + ), + "failed": ( + """ + def test_passed_custom(snapshot_custom): + # this comment is required or the test breaks + assert snapshot_custom == 4.0 + """ + ), + } + + +@pytest.fixture +def generate_snapshots(testdir, testcases_initial): + testdir.makepyfile(test_file=testcases_initial["passed"]) + result = testdir.runpytest("-v", "--snapshot-update") + return result, testdir, testcases_initial + + +def test_generated_snapshots(generate_snapshots): + result = generate_snapshots[0] + result.stdout.re_match_lines((r"1 snapshot generated\.")) + assert "snapshots unused" not in result.stdout.str() + assert result.ret == 0 + + +def test_approximate_match(generate_snapshots): + testdir = generate_snapshots[1] + testdir.makepyfile( + test_file=""" + def test_passed_custom(snapshot_custom): + assert snapshot_custom == 3.2 + """ + ) + result = testdir.runpytest("-v") + result.stdout.re_match_lines((r"test_file.py::test_passed_custom PASSED")) + assert result.ret == 0 + + +def test_failed_snapshots(generate_snapshots): + testdir = generate_snapshots[1] + testdir.makepyfile(test_file=generate_snapshots[2]["failed"]) + result = testdir.runpytest("-v") + result.stdout.re_match_lines((r"1 snapshot failed\.")) + assert result.ret == 1 + + +def test_updated_snapshots(generate_snapshots): + _, testdir, initial = generate_snapshots + testdir.makepyfile(test_file=initial["failed"]) + result = testdir.runpytest("-v", "--snapshot-update") + result.stdout.re_match_lines((r"1 snapshot updated\.")) + assert result.ret == 0