From b1123cfd543d35d282de9bb28067e48ebec18afe Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 15 Feb 2024 16:37:11 +0100 Subject: [PATCH] cleanup download utils (#8273) --- torchvision/datasets/utils.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 46858666c0d..344056d67db 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -12,7 +12,7 @@ import urllib.error import urllib.request import zipfile -from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, IO, Iterable, List, Optional, Tuple, TypeVar, Union from urllib.parse import urlparse import numpy as np @@ -24,24 +24,12 @@ USER_AGENT = "pytorch/vision" -def _save_response_content( - content: Iterator[bytes], - destination: Union[str, pathlib.Path], - length: Optional[int] = None, -) -> None: - with open(destination, "wb") as fh, tqdm(total=length) as pbar: - for chunk in content: - # filter out keep-alive new chunks - if not chunk: - continue - - fh.write(chunk) - pbar.update(len(chunk)) - - def _urlretrieve(url: str, filename: Union[str, pathlib.Path], chunk_size: int = 1024 * 32) -> None: with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response: - _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length) + with open(filename, "wb") as fh, tqdm(total=response.length) as pbar: + while chunk := response.read(chunk_size): + fh.write(chunk) + pbar.update(len(chunk)) def calculate_md5(fpath: Union[str, pathlib.Path], chunk_size: int = 1024 * 1024) -> str: