Skip to content

Commit

Permalink
refactor: make write_snapshot a classmethod
Browse files Browse the repository at this point in the history
  • Loading branch information
Noah Negin-Ulster committed Dec 1, 2022
1 parent 6e7fc50 commit ae07435
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 49 deletions.
3 changes: 2 additions & 1 deletion src/syrupy/extensions/amber/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ def _read_snapshot_data_from_location(
snapshot = snapshots.get(snapshot_name)
return snapshot.data if snapshot else None

@classmethod
def _write_snapshot_collection(
self, *, snapshot_collection: "SnapshotCollection"
cls, *, snapshot_collection: "SnapshotCollection"
) -> None:
DataSerializer.write_file(snapshot_collection, merge=True)

Expand Down
50 changes: 24 additions & 26 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,15 @@ def get_snapshot_name(
index_suffix = f".{index}"
return f"{test_location.snapshot_name}{index_suffix}"

@classmethod
def get_location(
self, *, test_location: "PyTestLocation", index: "SnapshotIndex"
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex"
) -> str:
"""Returns full location where snapshot data is stored."""
basename = self._get_file_basename(test_location=test_location, index=index)
fileext = f".{self._file_extension}" if self._file_extension else ""
"""Returns full filepath where snapshot data is stored."""
basename = cls._get_file_basename(test_location=test_location, index=index)
fileext = f".{cls._file_extension}" if cls._file_extension else ""
return str(
Path(self.dirname(test_location=test_location)).joinpath(
Path(cls.dirname(test_location=test_location)).joinpath(
f"{basename}{fileext}"
)
)
Expand Down Expand Up @@ -155,8 +156,9 @@ def read_snapshot(
raise SnapshotDoesNotExist()
return snapshot_data

@classmethod
def write_snapshot(
self,
cls,
*,
test_location: "PyTestLocation",
snapshots: List[Tuple["SerializedData", "SnapshotIndex"]],
Expand All @@ -170,13 +172,19 @@ def write_snapshot(
# Amber extension.
locations: DefaultDict[str, List["Snapshot"]] = defaultdict(list)
for data, index in snapshots:
location = self.get_location(test_location=test_location, index=index)
snapshot_name = self.get_snapshot_name(
location = cls.get_location(test_location=test_location, index=index)
snapshot_name = cls.get_snapshot_name(
test_location=test_location, index=index
)
locations[location].append(Snapshot(name=snapshot_name, data=data))

self.__ensure_snapshot_dir(test_location=test_location, index=index)
# Ensures the folder path for the snapshot file exists.
try:
Path(
cls.get_location(test_location=test_location, index=index)
).parent.mkdir(parents=True)
except FileExistsError:
pass

for location, location_snapshots in locations.items():
snapshot_collection = SnapshotCollection(location=location)
Expand Down Expand Up @@ -208,7 +216,7 @@ def write_snapshot(
)
warnings.warn(warning_msg)

self._write_snapshot_collection(snapshot_collection=snapshot_collection)
cls._write_snapshot_collection(snapshot_collection=snapshot_collection)

@abstractmethod
def delete_snapshots(
Expand Down Expand Up @@ -238,38 +246,28 @@ def _read_snapshot_data_from_location(
"""
raise NotImplementedError

@classmethod
@abstractmethod
def _write_snapshot_collection(
self, *, snapshot_collection: "SnapshotCollection"
cls, *, snapshot_collection: "SnapshotCollection"
) -> None:
"""
Adds the snapshot data to the snapshots in collection location
"""
raise NotImplementedError

def dirname(self, *, test_location: "PyTestLocation") -> str:
@classmethod
def dirname(cls, *, test_location: "PyTestLocation") -> str:
test_dir = Path(test_location.filepath).parent
return str(test_dir.joinpath(SNAPSHOT_DIRNAME))

@classmethod
def _get_file_basename(
self, *, test_location: "PyTestLocation", index: "SnapshotIndex"
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex"
) -> str:
"""Returns file basename without extension. Used to create full filepath."""
return test_location.basename

def __ensure_snapshot_dir(
self, *, test_location: "PyTestLocation", index: "SnapshotIndex"
) -> None:
"""
Ensures the folder path for the snapshot file exists.
"""
try:
Path(
self.get_location(test_location=test_location, index=index)
).parent.mkdir(parents=True)
except FileExistsError:
pass


class SnapshotReporter(ABC):
_context_line_count = 1
Expand Down
41 changes: 23 additions & 18 deletions src/syrupy/extensions/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
) -> "SerializedData":
return self._supported_dataclass(data)
return self.get_supported_dataclass()(data)

@classmethod
def get_snapshot_name(
Expand All @@ -66,15 +66,15 @@ def delete_snapshots(
) -> None:
Path(snapshot_location).unlink()

@classmethod
def _get_file_basename(
self, *, test_location: "PyTestLocation", index: "SnapshotIndex"
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex"
) -> str:
return self.get_snapshot_name(test_location=test_location, index=index)
return cls.get_snapshot_name(test_location=test_location, index=index)

def dirname(self, *, test_location: "PyTestLocation") -> str:
original_dirname = super(SingleFileSnapshotExtension, self).dirname(
test_location=test_location
)
@classmethod
def dirname(cls, *, test_location: "PyTestLocation") -> str:
original_dirname = AbstractSyrupyExtension.dirname(test_location=test_location)
return str(Path(original_dirname).joinpath(test_location.basename))

def _read_snapshot_collection(
Expand All @@ -89,41 +89,46 @@ def _read_snapshot_data_from_location(
) -> Optional["SerializableData"]:
try:
with open(
snapshot_location, f"r{self._write_mode}", encoding=self._write_encoding
snapshot_location,
f"r{self._write_mode}",
encoding=self.get_write_encoding(),
) as f:
return f.read()
except FileNotFoundError:
return None

@property
def _supported_dataclass(self) -> Union[Type[str], Type[bytes]]:
if self._write_mode == WriteMode.TEXT:
@classmethod
def get_supported_dataclass(cls) -> Union[Type[str], Type[bytes]]:
if cls._write_mode == WriteMode.TEXT:
return str
return bytes

@property
def _write_encoding(self) -> Optional[str]:
if self._write_mode == WriteMode.TEXT:
@classmethod
def get_write_encoding(cls) -> Optional[str]:
if cls._write_mode == WriteMode.TEXT:
return TEXT_ENCODING
return None

@classmethod
def _write_snapshot_collection(
self, *, snapshot_collection: "SnapshotCollection"
cls, *, snapshot_collection: "SnapshotCollection"
) -> None:
filepath, data = (
snapshot_collection.location,
next(iter(snapshot_collection)).data,
)
if not isinstance(data, self._supported_dataclass):
if not isinstance(data, cls.get_supported_dataclass()):
error_text = gettext(
"Can't write non supported data. Expected '{}', got '{}'"
)
raise TypeError(
error_text.format(
self._supported_dataclass.__name__, type(data).__name__
cls.get_supported_dataclass().__name__, type(data).__name__
)
)
with open(filepath, f"w{self._write_mode}", encoding=self._write_encoding) as f:
with open(
filepath, f"w{cls._write_mode}", encoding=cls.get_write_encoding()
) as f:
f.write(data)

@classmethod
Expand Down
3 changes: 2 additions & 1 deletion tests/examples/test_custom_snapshot_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@


class DifferentDirectoryExtension(AmberSnapshotExtension):
def dirname(self, *, test_location: "PyTestLocation") -> str:
@classmethod
def dirname(cls, *, test_location: "PyTestLocation") -> str:
return str(Path(test_location.filepath).parent.joinpath(DIFFERENT_DIRECTORY))


Expand Down
3 changes: 2 additions & 1 deletion tests/examples/test_custom_snapshot_directory_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

def create_versioned_fixture(version: int):
class VersionedJSONExtension(JSONSnapshotExtension):
def dirname(self, *, test_location: "PyTestLocation") -> str:
@classmethod
def dirname(cls, *, test_location: "PyTestLocation") -> str:
return str(
Path(test_location.filepath).parent.joinpath(
"__snapshots__", f"v{version}"
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/test_snapshot_outside_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ def testcases(testdir, tmp_path):
from syrupy.extensions.amber import AmberSnapshotExtension
class CustomSnapshotExtension(AmberSnapshotExtension):
def dirname(self, *, test_location):
@classmethod
def dirname(cls, *, test_location):
return {str(dirname)!r}
@pytest.fixture
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/test_snapshot_use_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def get_snapshot_name(cls, *, test_location, index):
testname = test_location.testname[::-1]
return f"{testname}.{index}"
def _get_file_basename(self, *, test_location, index):
@classmethod
def _get_file_basename(cls, *, test_location, index):
return test_location.basename[::-1]
@pytest.fixture
Expand Down

2 comments on commit ae07435

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: ae07435 Previous: 23cca84 Ratio
benchmarks/test_1000x.py::test_1000x_reads 0.4553221013446547 iter/sec (stddev: 0.07968082308403533) 0.6754078596653935 iter/sec (stddev: 0.06391877117699159) 1.48
benchmarks/test_1000x.py::test_1000x_writes 0.1819143794981502 iter/sec (stddev: 0.08262196354046505) 0.6345993135561808 iter/sec (stddev: 0.23174880874067105) 3.49
benchmarks/test_standard.py::test_standard 0.40275715520513494 iter/sec (stddev: 0.22755538663534672) 0.6315599143065584 iter/sec (stddev: 0.0923523543680502) 1.57

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: ae07435 Previous: 23cca84 Ratio
benchmarks/test_1000x.py::test_1000x_writes 0.1819143794981502 iter/sec (stddev: 0.08262196354046505) 0.6345993135561808 iter/sec (stddev: 0.23174880874067105) 3.49

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.