diff --git a/anomalib/data/btech.py b/anomalib/data/btech.py index aef0664ca1..4c990cff13 100644 --- a/anomalib/data/btech.py +++ b/anomalib/data/btech.py @@ -41,7 +41,7 @@ from tqdm import tqdm from anomalib.data.inference import InferenceDataset -from anomalib.data.utils import DownloadProgressBar, read_image +from anomalib.data.utils import DownloadProgressBar, hash_check, read_image from anomalib.data.utils.split import ( create_validation_set_from_test_set, split_normal_images_in_train_set, @@ -359,7 +359,8 @@ def prepare_data(self) -> None: filename=zip_filename, reporthook=progress_bar.update_to, ) # nosec - + logger.info("Checking hash") + hash_check(zip_filename, "c1fa4d56ac50dd50908ce04e81037a8e") logger.info("Extracting the dataset.") with zipfile.ZipFile(zip_filename, "r") as zip_file: zip_file.extractall(self.root.parent) diff --git a/anomalib/data/mvtec.py b/anomalib/data/mvtec.py index 5ecb4419b7..efc0b892fd 100644 --- a/anomalib/data/mvtec.py +++ b/anomalib/data/mvtec.py @@ -57,7 +57,7 @@ from torchvision.datasets.folder import VisionDataset from anomalib.data.inference import InferenceDataset -from anomalib.data.utils import DownloadProgressBar, read_image +from anomalib.data.utils import DownloadProgressBar, hash_check, read_image from anomalib.data.utils.split import ( create_validation_set_from_test_set, split_normal_images_in_train_set, @@ -378,19 +378,22 @@ def prepare_data(self) -> None: logger.info("Downloading the Mvtec AD dataset.") url = "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094" dataset_name = "mvtec_anomaly_detection.tar.xz" + zip_filename = self.root / dataset_name with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc="MVTec AD") as progress_bar: urlretrieve( url=f"{url}/{dataset_name}", - filename=self.root / dataset_name, + filename=zip_filename, reporthook=progress_bar.update_to, ) + logger.info("Checking hash") + hash_check(zip_filename, 
"eefca59f2cede9c3fc5b6befbfec275e") logger.info("Extracting the dataset.") - with tarfile.open(self.root / dataset_name) as tar_file: + with tarfile.open(zip_filename) as tar_file: tar_file.extractall(self.root) logger.info("Cleaning the tar file") - (self.root / dataset_name).unlink() + (zip_filename).unlink() def setup(self, stage: Optional[str] = None) -> None: """Setup train, validation and test data. diff --git a/anomalib/data/utils/__init__.py b/anomalib/data/utils/__init__.py index c493058051..45c94c587f 100644 --- a/anomalib/data/utils/__init__.py +++ b/anomalib/data/utils/__init__.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions # and limitations under the License. -from .download import DownloadProgressBar +from .download import DownloadProgressBar, hash_check from .image import get_image_filenames, read_image -__all__ = ["get_image_filenames", "read_image", "DownloadProgressBar"] +__all__ = ["get_image_filenames", "hash_check", "read_image", "DownloadProgressBar"] diff --git a/anomalib/data/utils/download.py b/anomalib/data/utils/download.py index 26af24834a..c043aa1a77 100644 --- a/anomalib/data/utils/download.py +++ b/anomalib/data/utils/download.py @@ -1,7 +1,4 @@ -"""Helper to show progress bars with `urlretrieve`. - -Based on https://stackoverflow.com/a/53877507 -""" +"""Helper to show progress bars with `urlretrieve`, check hash of file.""" # Copyright (C) 2020 Intel Corporation # @@ -17,7 +14,9 @@ # See the License for the specific language governing permissions # and limitations under the License. 
def hash_check(file_path: Path, expected_hash: str):
    """Verify that the MD5 digest of a file matches an expected value.

    The file is hashed in fixed-size chunks so that large dataset archives
    never have to be held in memory in one piece.

    Args:
        file_path (Path): Path to file.
        expected_hash (str): Expected MD5 hex digest of the file.

    Raises:
        AssertionError: If the calculated digest differs from ``expected_hash``.
    """
    chunk_size = 1024 * 1024  # 1 MiB per read keeps memory use flat regardless of file size
    # MD5 is used here as an integrity check on the file contents, not as a security measure.
    digest = hashlib.md5()  # nosec
    with open(file_path, "rb") as hash_file:
        # ``read`` returns b"" at end of file, which terminates the iterator.
        for chunk in iter(lambda: hash_file.read(chunk_size), b""):
            digest.update(chunk)

    # Raise explicitly instead of using an ``assert`` statement: asserts are removed
    # when Python runs with -O, which would silently disable the check.
    if digest.hexdigest() != expected_hash:
        raise AssertionError(f"Downloaded file {file_path} does not match the required hash.")