Refactor for simplification (#9054)
* Refactor for simplification

* cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
glenn-jocher and pre-commit-ci[bot] committed Aug 20, 2022
1 parent f258cf8 commit c725511
Showing 5 changed files with 13 additions and 15 deletions.
2 changes: 1 addition & 1 deletion utils/downloads.py
@@ -46,7 +46,7 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
     except Exception as e:  # url2
         file.unlink(missing_ok=True)  # remove partial downloads
         LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
-        os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -")  # curl download, retry and resume on fail
+        os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -")  # curl download, retry and resume on fail
     finally:
         if not file.exists() or file.stat().st_size < min_bytes:  # check
             file.unlink(missing_ok=True)  # remove partial downloads
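The only functional change above is curl's -# flag, which switches curl to a compact progress bar; -L follows redirects, --retry 3 retries transient failures, and -C - resumes a partial file. A minimal standalone sketch of the same fallback call (the curl_download name and arguments are illustrative, not part of the repo):

import os
from pathlib import Path

def curl_download(url, file, retries=3):
    # Shell out to curl: -# progress bar, -L follow redirects, -C - resume a partial download
    file = Path(file)
    rc = os.system(f"curl -# -L '{url}' -o '{file}' --retry {retries} -C -")
    return rc == 0 and file.exists() and file.stat().st_size > 0  # basic success check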
5 changes: 3 additions & 2 deletions utils/general.py
@@ -582,7 +582,7 @@ def url2file(url):


 def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
-    # Multi-threaded file download and unzip function, used in data.yaml for autodownload
+    # Multithreaded file download and unzip function, used in data.yaml for autodownload
     def download_one(url, dir):
         # Download 1 file
         success = True
@@ -594,7 +594,8 @@ def download_one(url, dir):
             for i in range(retry + 1):
                 if curl:
                     s = 'sS' if threads > 1 else ''  # silent
-                    r = os.system(f'curl -{s}L "{url}" -o "{f}" --retry 9 -C -')  # curl download with retry, continue
+                    r = os.system(
+                        f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -')  # curl download with retry, continue
                     success = r == 0
                 else:
                     torch.hub.download_url_to_file(url, f, progress=threads == 1)  # torch download
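Here the single curl call is split across two lines (a pre-commit line-length fix) and gains the same -# progress-bar flag; the -sS variant is used only when several download threads share the console. A simplified sketch of that retry loop, reduced to the curl path (download_one and the flag handling mirror the diff, the rest is illustrative):

import os
from pathlib import Path

def download_one(url, dir, retry=3, threads=1):
    # Download one file with curl, retrying on failure; silence output (-sS) when multithreaded
    f = Path(dir) / url.split('/')[-1]  # destination filename taken from the URL
    s = 'sS' if threads > 1 else ''
    for i in range(retry + 1):
        r = os.system(f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -')
        if r == 0:
            return True  # curl exited cleanly
    return False  # all attempts failed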
2 changes: 1 addition & 1 deletion utils/metrics.py
@@ -141,7 +141,7 @@ def process_batch(self, detections, labels):
         """
         if detections is None:
             gt_classes = labels.int()
-            for i, gc in enumerate(gt_classes):
+            for gc in gt_classes:
                 self.matrix[self.nc, gc] += 1  # background FN
             return

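The refactor only drops an unused loop index; the behaviour is unchanged: when there are no detections, every ground-truth class is counted as a background false negative in the extra row of the confusion matrix. A tiny numpy illustration of that accumulation (class count and labels are made up):

import numpy as np

nc = 3  # number of classes (illustrative)
matrix = np.zeros((nc + 1, nc + 1))  # extra row/column for background
gt_classes = np.array([0, 2, 2])  # ground-truth classes, no detections at all
for gc in gt_classes:
    matrix[nc, gc] += 1  # each missed label is a background FN
print(matrix[nc])  # [1. 0. 2. 0.]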
8 changes: 3 additions & 5 deletions utils/plots.py
@@ -3,6 +3,7 @@
 Plotting utils
 """

+import contextlib
 import math
 import os
 from copy import copy
@@ -180,8 +181,7 @@ def output_to_target(output):
     # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
     targets = []
     for i, o in enumerate(output):
-        for *box, conf, cls in o.cpu().numpy():
-            targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
+        targets.extend([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf] for *box, conf, cls in o.cpu().numpy())
     return np.array(targets)

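The nested append loop collapses into a single list.extend over a generator expression, producing exactly one list entry per detection. A minimal illustration of the pattern with dummy detections (the xyxy2xywh conversion is omitted, so boxes stay in x1, y1, x2, y2 form):

import numpy as np

output = [np.array([[10., 20., 50., 80., 0.9, 1.],   # image 0: x1, y1, x2, y2, conf, cls
                    [5., 5., 15., 25., 0.4, 0.]]),
          np.array([[0., 0., 8., 8., 0.7, 2.]])]     # image 1
targets = []
for i, o in enumerate(output):
    targets.extend([i, cls, *box, conf] for *box, conf, cls in o)
print(np.array(targets).shape)  # (3, 7): [batch_id, class_id, x1, y1, x2, y2, conf]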

@@ -357,10 +357,8 @@ def plot_labels(labels, names=(), save_dir=Path('')):
     matplotlib.use('svg')  # faster
     ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
     y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
-    try:  # color histogram bars by class
+    with contextlib.suppress(Exception):  # color histogram bars by class
         [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)]  # known issue #3195
-    except Exception:
-        pass
     ax[0].set_ylabel('instances')
     if 0 < len(names) < 30:
         ax[0].set_xticks(range(len(names)))
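contextlib.suppress(Exception) is the standard-library replacement for the removed try / except Exception: pass; anything raised inside the with block is swallowed and execution resumes after the block (the whole block is abandoned at the first exception, just like the try form). A small self-contained example:

import contextlib

values = ['3', 'oops', '7']
total = 0
with contextlib.suppress(ValueError):  # equivalent to try: ... except ValueError: pass
    for v in values:
        total += int(v)  # raises on 'oops'; the rest of the block is skipped silently
print(total)  # 3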
11 changes: 5 additions & 6 deletions utils/torch_utils.py
@@ -45,11 +45,10 @@ def decorate(fn):
 def smartCrossEntropyLoss(label_smoothing=0.0):
     # Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0
     if check_version(torch.__version__, '1.10.0'):
-        return nn.CrossEntropyLoss(label_smoothing=label_smoothing)  # loss function
-    else:
-        if label_smoothing > 0:
-            LOGGER.warning(f'WARNING: label smoothing {label_smoothing} requires torch>=1.10.0')
-        return nn.CrossEntropyLoss()  # loss function
+        return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+    if label_smoothing > 0:
+        LOGGER.warning(f'WARNING: label smoothing {label_smoothing} requires torch>=1.10.0')
+    return nn.CrossEntropyLoss()


 def smart_DDP(model):
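The else branch is flattened into guard-clause style: if the installed torch supports label_smoothing the function returns immediately, otherwise it warns (when smoothing was requested) and falls back to a plain loss. A self-contained sketch of the same pattern, with check_version replaced by a simple tuple comparison for illustration:

import torch
import torch.nn as nn

def smart_cross_entropy(label_smoothing=0.0):
    # Return CrossEntropyLoss with label smoothing on torch>=1.10.0, else warn and fall back
    major, minor = (int(x) for x in torch.__version__.split('+')[0].split('.')[:2])
    if (major, minor) >= (1, 10):
        return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    if label_smoothing > 0:
        print(f'WARNING: label smoothing {label_smoothing} requires torch>=1.10.0')
    return nn.CrossEntropyLoss()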
@@ -118,7 +117,7 @@ def select_device(device='', batch_size=0, newline=True):
         assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
             f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"

-    if not (cpu or mps) and torch.cuda.is_available():  # prefer GPU if available
+    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
         devices = device.split(',') if device else '0'  # range(torch.cuda.device_count())  # i.e. 0,1,6,7
         n = len(devices)  # device count
         if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count
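The rewritten condition is logically identical (De Morgan: not (cpu or mps) equals not cpu and not mps); the surrounding code then splits the comma-separated --device string and, with more than one GPU, requires the batch size to be divisible by the device count. A standalone sketch of just that parsing and check (check_devices is an illustrative helper, not part of the repo):

def check_devices(device='0,1', batch_size=16):
    # Parse a comma-separated CUDA device string and validate the batch size against it
    devices = device.split(',') if device else ['0']
    n = len(devices)  # device count
    if n > 1 and batch_size > 0:
        assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
    return devices

print(check_devices('0,1', 16))  # ['0', '1']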
