From c725511bfc14eb86daf6edefa0d257084aa24c85 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 21 Aug 2022 01:34:03 +0200 Subject: [PATCH] Refactor for simplification (#9054) * Refactor for simplification * cleanup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- utils/downloads.py | 2 +- utils/general.py | 5 +++-- utils/metrics.py | 2 +- utils/plots.py | 8 +++----- utils/torch_utils.py | 11 +++++------ 5 files changed, 13 insertions(+), 15 deletions(-) diff --git a/utils/downloads.py b/utils/downloads.py index c4d4a85c38ae..69887a579966 100644 --- a/utils/downloads.py +++ b/utils/downloads.py @@ -46,7 +46,7 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''): except Exception as e: # url2 file.unlink(missing_ok=True) # remove partial downloads LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...') - os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail + os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail finally: if not file.exists() or file.stat().st_size < min_bytes: # check file.unlink(missing_ok=True) # remove partial downloads diff --git a/utils/general.py b/utils/general.py index 42d000918c13..d9f436a36359 100755 --- a/utils/general.py +++ b/utils/general.py @@ -582,7 +582,7 @@ def url2file(url): def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3): - # Multi-threaded file download and unzip function, used in data.yaml for autodownload + # Multithreaded file download and unzip function, used in data.yaml for autodownload def download_one(url, dir): # Download 1 file success = True @@ -594,7 +594,8 @@ def download_one(url, dir): for i in range(retry + 1): if curl: s = 'sS' if threads > 1 else '' # silent - r = os.system(f'curl -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue + r = os.system( + f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue success = r == 0 else: torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download diff --git a/utils/metrics.py b/utils/metrics.py index 08880cd3f212..8fa3c7e217c7 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -141,7 +141,7 @@ def process_batch(self, detections, labels): """ if detections is None: gt_classes = labels.int() - for i, gc in enumerate(gt_classes): + for gc in gt_classes: self.matrix[self.nc, gc] += 1 # background FN return diff --git a/utils/plots.py b/utils/plots.py index 7417308c4d82..2c7a80b4c872 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -3,6 +3,7 @@ Plotting utils """ +import contextlib import math import os from copy import copy @@ -180,8 +181,7 @@ def output_to_target(output): # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] targets = [] for i, o in enumerate(output): - for *box, conf, cls in o.cpu().numpy(): - targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf]) + targets.extend([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf] for *box, conf, cls in o.cpu().numpy()) return np.array(targets) @@ -357,10 +357,8 @@ def plot_labels(labels, names=(), save_dir=Path('')): matplotlib.use('svg') # faster ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) - try: # color histogram bars by class + with contextlib.suppress(Exception): # color histogram bars by class [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195 - except Exception: - pass ax[0].set_ylabel('instances') if 0 < len(names) < 30: ax[0].set_xticks(range(len(names))) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 4de2520b26a2..88108906bfd3 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -45,11 +45,10 @@ def decorate(fn): def smartCrossEntropyLoss(label_smoothing=0.0): # Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0 if check_version(torch.__version__, '1.10.0'): - return nn.CrossEntropyLoss(label_smoothing=label_smoothing) # loss function - else: - if label_smoothing > 0: - LOGGER.warning(f'WARNING: label smoothing {label_smoothing} requires torch>=1.10.0') - return nn.CrossEntropyLoss() # loss function + return nn.CrossEntropyLoss(label_smoothing=label_smoothing) + if label_smoothing > 0: + LOGGER.warning(f'WARNING: label smoothing {label_smoothing} requires torch>=1.10.0') + return nn.CrossEntropyLoss() def smart_DDP(model): @@ -118,7 +117,7 @@ def select_device(device='', batch_size=0, newline=True): assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \ f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)" - if not (cpu or mps) and torch.cuda.is_available(): # prefer GPU if available + if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7 n = len(devices) # device count if n > 1 and batch_size > 0: # check batch_size is divisible by device_count