Skip to content

Commit

Permalink
Update selectable device Profile (#12353)
Browse files Browse the repository at this point in the history
* Update selectable device Profile

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Yakuho and pre-commit-ci[bot] committed Jan 3, 2024
1 parent bd1a829 commit 69b0faf
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion classify/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def run(

# Run inference
model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup
seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
for path, im, im0s, vid_cap, s in dataset:
with dt[0]:
im = torch.Tensor(im).to(model.device)
Expand Down
2 changes: 1 addition & 1 deletion classify/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def run(
workers=workers)

model.eval()
pred, targets, loss, dt = [], [], 0, (Profile(), Profile(), Profile())
pred, targets, loss, dt = [], [], 0, (Profile(device=device), Profile(device=device), Profile(device=device))
n = len(dataloader) # number of batches
action = 'validating' if dataloader.dataset.root.stem == 'val' else 'testing'
desc = f'{pbar.desc[:-36]}{action:>36}' if pbar else f'{action}'
Expand Down
2 changes: 1 addition & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def run(

# Run inference
model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup
seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
for path, im, im0s, vid_cap, s in dataset:
with dt[0]:
im = torch.from_numpy(im).to(model.device)
Expand Down
2 changes: 1 addition & 1 deletion segment/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def run(

# Run inference
model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup
seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
for path, im, im0s, vid_cap, s in dataset:
with dt[0]:
im = torch.from_numpy(im).to(model.device)
Expand Down
2 changes: 1 addition & 1 deletion segment/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def run(
class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
s = ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P', 'R',
'mAP50', 'mAP50-95)')
dt = Profile(), Profile(), Profile()
dt = Profile(device=device), Profile(device=device), Profile(device=device)
metrics = Metrics()
loss = torch.zeros(4, device=device)
jdict, stats = [], []
Expand Down
7 changes: 4 additions & 3 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,10 @@ def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):

class Profile(contextlib.ContextDecorator):
# YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
def __init__(self, t=0.0):
def __init__(self, t=0.0, device: torch.device = None):
self.t = t
self.cuda = torch.cuda.is_available()
self.device = device
self.cuda = True if (device and str(device)[:4] == 'cuda') else False

def __enter__(self):
self.start = self.time()
Expand All @@ -196,7 +197,7 @@ def __exit__(self, type, value, traceback):

def time(self):
if self.cuda:
torch.cuda.synchronize()
torch.cuda.synchronize(self.device)
return time.time()


Expand Down
2 changes: 1 addition & 1 deletion val.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def run(
class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
s = ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'P', 'R', 'mAP50', 'mAP50-95')
tp, fp, p, r, f1, mp, mr, map50, ap50, map = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
dt = Profile(), Profile(), Profile() # profiling times
dt = Profile(device=device), Profile(device=device), Profile(device=device) # profiling times
loss = torch.zeros(3, device=device)
jdict, stats, ap, ap_class = [], [], [], []
callbacks.run('on_val_start')
Expand Down

0 comments on commit 69b0faf

Please sign in to comment.