From 69b0faf2f7d43b32833ec4b46b2faf50b5b7ceed Mon Sep 17 00:00:00 2001
From: Yakuho
Date: Wed, 3 Jan 2024 15:09:23 +0800
Subject: [PATCH] Update selectable device Profile (#12353)

* Update selectable device Profile

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 classify/predict.py | 2 +-
 classify/val.py     | 2 +-
 detect.py           | 2 +-
 segment/predict.py  | 2 +-
 segment/val.py      | 2 +-
 utils/general.py    | 7 ++++---
 val.py              | 2 +-
 7 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/classify/predict.py b/classify/predict.py
index 653c374f768f..b056a0cd707b 100644
--- a/classify/predict.py
+++ b/classify/predict.py
@@ -106,7 +106,7 @@ def run(
 
     # Run inference
     model.warmup(imgsz=(1 if pt else bs, 3, *imgsz))  # warmup
-    seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
+    seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
     for path, im, im0s, vid_cap, s in dataset:
         with dt[0]:
             im = torch.Tensor(im).to(model.device)
diff --git a/classify/val.py b/classify/val.py
index 4b92e9f105db..6814c4d780e1 100644
--- a/classify/val.py
+++ b/classify/val.py
@@ -97,7 +97,7 @@ def run(
                                                       workers=workers)
 
     model.eval()
-    pred, targets, loss, dt = [], [], 0, (Profile(), Profile(), Profile())
+    pred, targets, loss, dt = [], [], 0, (Profile(device=device), Profile(device=device), Profile(device=device))
     n = len(dataloader)  # number of batches
     action = 'validating' if dataloader.dataset.root.stem == 'val' else 'testing'
     desc = f'{pbar.desc[:-36]}{action:>36}' if pbar else f'{action}'
diff --git a/detect.py b/detect.py
index fd9637138dd6..1ea4e0b60dd7 100644
--- a/detect.py
+++ b/detect.py
@@ -116,7 +116,7 @@ def run(
 
     # Run inference
     model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
-    seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
+    seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
     for path, im, im0s, vid_cap, s in dataset:
         with dt[0]:
             im = torch.from_numpy(im).to(model.device)
diff --git a/segment/predict.py b/segment/predict.py
index 113bc472e637..8e3d97dfeb92 100644
--- a/segment/predict.py
+++ b/segment/predict.py
@@ -117,7 +117,7 @@ def run(
 
     # Run inference
     model.warmup(imgsz=(1 if pt else bs, 3, *imgsz))  # warmup
-    seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
+    seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
     for path, im, im0s, vid_cap, s in dataset:
         with dt[0]:
             im = torch.from_numpy(im).to(model.device)
diff --git a/segment/val.py b/segment/val.py
index dc8081840e37..304d0c751314 100644
--- a/segment/val.py
+++ b/segment/val.py
@@ -233,7 +233,7 @@ def run(
     class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
     s = ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P', 'R',
                                   'mAP50', 'mAP50-95)')
-    dt = Profile(), Profile(), Profile()
+    dt = Profile(device=device), Profile(device=device), Profile(device=device)
     metrics = Metrics()
     loss = torch.zeros(4, device=device)
     jdict, stats = [], []
diff --git a/utils/general.py b/utils/general.py
index 135141e21436..73925ce5fb95 100644
--- a/utils/general.py
+++ b/utils/general.py
@@ -182,9 +182,10 @@ def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
 
 class Profile(contextlib.ContextDecorator):
     # YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
-    def __init__(self, t=0.0):
+    def __init__(self, t=0.0, device: torch.device = None):
         self.t = t
-        self.cuda = torch.cuda.is_available()
+        self.device = device
+        self.cuda = True if (device and str(device)[:4] == 'cuda') else False
 
     def __enter__(self):
         self.start = self.time()
@@ -196,7 +197,7 @@ def __exit__(self, type, value, traceback):
 
     def time(self):
         if self.cuda:
-            torch.cuda.synchronize()
+            torch.cuda.synchronize(self.device)
         return time.time()
 
 
diff --git a/val.py b/val.py
index b3d05f4305ce..1a4219c38962 100644
--- a/val.py
+++ b/val.py
@@ -191,7 +191,7 @@ def run(
     class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
     s = ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'P', 'R', 'mAP50', 'mAP50-95')
     tp, fp, p, r, f1, mp, mr, map50, ap50, map = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
-    dt = Profile(), Profile(), Profile()  # profiling times
+    dt = Profile(device=device), Profile(device=device), Profile(device=device)  # profiling times
     loss = torch.zeros(3, device=device)
     jdict, stats, ap, ap_class = [], [], [], []
     callbacks.run('on_val_start')
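
Note: the patch threads the selected inference device into each Profile timer, so torch.cuda.synchronize() waits only on that device rather than being gated on the global torch.cuda.is_available() check. Below is a minimal standalone sketch of the patched behaviour. The Profile class mirrors the utils/general.py hunk above (with __enter__/__exit__ bodies following the upstream class); the dummy Conv2d model, the 640x640 input and the __main__ harness are illustrative assumptions, not part of the patch.

import contextlib
import time

import torch


class Profile(contextlib.ContextDecorator):
    # YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
    def __init__(self, t=0.0, device: torch.device = None):
        self.t = t
        self.device = device
        self.cuda = True if (device and str(device)[:4] == 'cuda') else False

    def __enter__(self):
        self.start = self.time()
        return self

    def __exit__(self, type, value, traceback):
        self.dt = self.time() - self.start  # delta-time
        self.t += self.dt  # accumulate dt

    def time(self):
        # Wait only for the selected CUDA device (if any) before reading the wall clock
        if self.cuda:
            torch.cuda.synchronize(self.device)
        return time.time()


if __name__ == '__main__':
    # Illustrative usage only: time a forward pass with the 3-slot pattern used by the callers updated above
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = torch.nn.Conv2d(3, 16, 3).to(device)  # dummy stand-in for a YOLOv5 model
    im = torch.randn(1, 3, 640, 640, device=device)
    dt = Profile(device=device), Profile(device=device), Profile(device=device)
    with dt[1]:  # slot 1 = inference, mirroring detect.py / val.py
        _ = model(im)
    print(f'inference {dt[1].dt * 1E3:.1f}ms on {device}')

The practical effect of passing the device is that on multi-GPU hosts only the device actually running inference is synchronized before a timestamp is taken, and a CPU-only run (where CUDA may still be available but unused) skips the synchronize call entirely.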