From 2dabc58ec3d133fe3a24de1cdbed5db37e376506 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 14 Mar 2021 23:16:17 -0700 Subject: [PATCH] PyTorch Hub models default to CUDA:0 if available (#2472) * PyTorch Hub models default to CUDA:0 if available * device as string bug fix --- hubconf.py | 4 +++- utils/datasets.py | 4 ++-- utils/general.py | 2 +- utils/torch_utils.py | 6 +++--- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/hubconf.py b/hubconf.py index e51ac90da36c..b7b740d39c06 100644 --- a/hubconf.py +++ b/hubconf.py @@ -12,6 +12,7 @@ from models.yolo import Model from utils.general import set_logging from utils.google_utils import attempt_download +from utils.torch_utils import select_device dependencies = ['torch', 'yaml'] set_logging() @@ -43,7 +44,8 @@ def create(name, pretrained, channels, classes, autoshape): model.names = ckpt['model'].names # set class names attribute if autoshape: model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS - return model + device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available + return model.to(device) except Exception as e: help_url = 'https://github.com/ultralytics/yolov5/issues/36' diff --git a/utils/datasets.py b/utils/datasets.py index 9a4b3f9fcc9f..86d7be39bec0 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -385,7 +385,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r # Display cache nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total if exists: - d = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted" + d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted" tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}' @@ -485,7 +485,7 @@ def cache_labels(self, path=Path('./labels.cache'), prefix=''): nc += 1 print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}') - pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' for images and labels... " \ + pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \ f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted" if nf == 0: diff --git a/utils/general.py b/utils/general.py index e1c14bdaa4b3..621df64c6cf1 100755 --- a/utils/general.py +++ b/utils/general.py @@ -79,7 +79,7 @@ def check_git_status(): f"Use 'git pull' to update or 'git clone {url}' to download latest." else: s = f'up to date with {url} ✅' - print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) + print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe except Exception as e: print(e) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 806d29470e55..8f3538ab152a 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -1,8 +1,8 @@ # PyTorch utils - import logging import math import os +import platform import subprocess import time from contextlib import contextmanager @@ -53,7 +53,7 @@ def git_describe(): def select_device(device='', batch_size=None): # device = 'cpu' or '0' or '0,1,2,3' - s = f'YOLOv5 {git_describe()} torch {torch.__version__} ' # string + s = f'YOLOv5 🚀 {git_describe()} torch {torch.__version__} ' # string cpu = device.lower() == 'cpu' if cpu: os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False @@ -73,7 +73,7 @@ def select_device(device='', batch_size=None): else: s += 'CPU\n' - logger.info(s) # skip a line + logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe return torch.device('cuda:0' if cuda else 'cpu')