Commit

fix download bug when use multi gpus (#13610)
changdazhou authored Aug 6, 2024
1 parent d7ea48e commit 20de659
Showing 1 changed file with 70 additions and 12 deletions.
82 changes: 70 additions & 12 deletions ppocr/utils/network.py
@@ -14,33 +14,91 @@

 import os
 import sys
+import time
+import shutil
 import tarfile
 import requests
+import os.path as osp
+import paddle.distributed as dist
 from tqdm import tqdm

 from ppocr.utils.logging import get_logger

 MODELS_DIR = os.path.expanduser("~/.paddleocr/models/")
+DOWNLOAD_RETRY_LIMIT = 3


 def download_with_progressbar(url, save_path):
     logger = get_logger()
     if save_path and os.path.exists(save_path):
         logger.info(f"Path {save_path} already exists. Skipping...")
         return
-    response = requests.get(url, stream=True)
-    if response.status_code == 200:
-        total_size_in_bytes = int(response.headers.get("content-length", 1))
-        block_size = 1024  # 1 Kibibyte
-        progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-        with open(save_path, "wb") as file:
-            for data in response.iter_content(block_size):
-                progress_bar.update(len(data))
-                file.write(data)
-        progress_bar.close()
-    else:
-        logger.error("Something went wrong while downloading models")
-        sys.exit(0)
+    # Mainly used when running on multiple processes or machines: rank 0
+    # downloads the data once, and every other process waits until the file
+    # appears instead of downloading its own copy.
+    if dist.get_rank() == 0:
+        _download(url, save_path)
+    else:
+        while not os.path.exists(save_path):
+            time.sleep(1)
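
The gate above serializes the download across the processes spawned by paddle.distributed.launch: only global rank 0 transfers data, and every other rank polls for the finished file. A minimal sketch of the calling pattern (the model URL and destination path below are illustrative, not taken from this commit):

    # Run with: python -m paddle.distributed.launch --gpus "0,1" fetch_model.py
    # All processes execute the same call; rank 0 downloads, the others wait
    # in the while/sleep loop until save_path exists.
    from ppocr.utils.network import download_with_progressbar

    download_with_progressbar(
        "https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar",
        "./models/ch_PP-OCRv4_det_infer.tar",
    )

Because the waiting ranks only test os.path.exists(save_path), the downloader must never create save_path before the file is complete; _download below guarantees that with a temp-file-then-move scheme.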


+def _download(url, save_path):
+    """
+    Download from url, save to path.
+    url (str): download url
+    save_path (str): download to given path
+    """
+    logger = get_logger()
+
+    fname = osp.split(url)[-1]
+    retry_cnt = 0
+
+    while not osp.exists(save_path):
+        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
+            retry_cnt += 1
+        else:
+            raise RuntimeError(
+                "Download from {} failed. Retry limit reached".format(url)
+            )
+
+        try:
+            req = requests.get(url, stream=True)
+        except Exception as e:  # requests.exceptions.ConnectionError
+            logger.info(
+                "Downloading {} from {} failed {} times with exception {}".format(
+                    fname, url, retry_cnt + 1, str(e)
+                )
+            )
+            time.sleep(1)
+            continue
+
+        if req.status_code != 200:
+            raise RuntimeError(
+                "Downloading from {} failed with code {}!".format(url, req.status_code)
+            )
+
+        # To protect against interrupted downloads, write to tmp_file first,
+        # then move tmp_file to save_path once the download has finished.
+        tmp_file = save_path + ".tmp"
+        total_size = req.headers.get("content-length")
+        with open(tmp_file, "wb") as f:
+            if total_size:
+                with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
+                    for chunk in req.iter_content(chunk_size=1024):
+                        f.write(chunk)
+                        pbar.update(1)
+            else:
+                for chunk in req.iter_content(chunk_size=1024):
+                    if chunk:
+                        f.write(chunk)
+        shutil.move(tmp_file, save_path)
+
+    return save_path
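
The temp-file dance is what keeps the rank gate in download_with_progressbar safe: save_path only becomes visible after shutil.move, so a polling rank can never open a half-written archive. A condensed sketch of the pattern (hypothetical helper name, error handling omitted):

    import shutil

    def write_atomically(save_path, chunks):
        # Stage bytes in a sibling ".tmp" file, then rename it into place;
        # os.path.exists(save_path) in the waiting ranks only turns True
        # once the move has completed.
        tmp_file = save_path + ".tmp"
        with open(tmp_file, "wb") as f:
            for chunk in chunks:
                if chunk:  # skip keep-alive chunks, as _download does
                    f.write(chunk)
        shutil.move(tmp_file, save_path)

Since the temp file lives next to the target, shutil.move reduces to an os.rename on the same filesystem, which POSIX makes atomic; the while not osp.exists(save_path) loop keeps retrying (up to DOWNLOAD_RETRY_LIMIT attempts) until a complete file is in place.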


def maybe_download(model_storage_directory, url):
