Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor for simplification #9054

Merged
merged 3 commits into from
Aug 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion utils/downloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
except Exception as e: # url2
file.unlink(missing_ok=True) # remove partial downloads
LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
finally:
if not file.exists() or file.stat().st_size < min_bytes: # check
file.unlink(missing_ok=True) # remove partial downloads
Expand Down
5 changes: 3 additions & 2 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def url2file(url):


def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
# Multi-threaded file download and unzip function, used in data.yaml for autodownload
# Multithreaded file download and unzip function, used in data.yaml for autodownload
def download_one(url, dir):
# Download 1 file
success = True
Expand All @@ -594,7 +594,8 @@ def download_one(url, dir):
for i in range(retry + 1):
if curl:
s = 'sS' if threads > 1 else '' # silent
r = os.system(f'curl -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
r = os.system(
f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
success = r == 0
else:
torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
Expand Down
2 changes: 1 addition & 1 deletion utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def process_batch(self, detections, labels):
"""
if detections is None:
gt_classes = labels.int()
for i, gc in enumerate(gt_classes):
for gc in gt_classes:
self.matrix[self.nc, gc] += 1 # background FN
return

Expand Down
8 changes: 3 additions & 5 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Plotting utils
"""

import contextlib
import math
import os
from copy import copy
Expand Down Expand Up @@ -180,8 +181,7 @@ def output_to_target(output):
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
targets = []
for i, o in enumerate(output):
for *box, conf, cls in o.cpu().numpy():
targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
targets.extend([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf] for *box, conf, cls in o.cpu().numpy())
return np.array(targets)


Expand Down Expand Up @@ -357,10 +357,8 @@ def plot_labels(labels, names=(), save_dir=Path('')):
matplotlib.use('svg') # faster
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
try: # color histogram bars by class
with contextlib.suppress(Exception): # color histogram bars by class
[y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
except Exception:
pass
ax[0].set_ylabel('instances')
if 0 < len(names) < 30:
ax[0].set_xticks(range(len(names)))
Expand Down
11 changes: 5 additions & 6 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,10 @@ def decorate(fn):
def smartCrossEntropyLoss(label_smoothing=0.0):
# Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0
if check_version(torch.__version__, '1.10.0'):
return nn.CrossEntropyLoss(label_smoothing=label_smoothing) # loss function
else:
if label_smoothing > 0:
LOGGER.warning(f'WARNING: label smoothing {label_smoothing} requires torch>=1.10.0')
return nn.CrossEntropyLoss() # loss function
return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
if label_smoothing > 0:
LOGGER.warning(f'WARNING: label smoothing {label_smoothing} requires torch>=1.10.0')
return nn.CrossEntropyLoss()


def smart_DDP(model):
Expand Down Expand Up @@ -118,7 +117,7 @@ def select_device(device='', batch_size=0, newline=True):
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"

if not (cpu or mps) and torch.cuda.is_available(): # prefer GPU if available
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
n = len(devices) # device count
if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
Expand Down