Skip to content

Commit

Permalink
Implement @torch.no_grad() decorator (ultralytics#3312)
Browse files Browse the repository at this point in the history
* `@torch.no_grad()` decorator

* Update detect.py

(cherry picked from commit 61ea23c)
  • Loading branch information
glenn-jocher authored and Lechtr committed May 24, 2021
1 parent 0eb3716 commit 5191519
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
12 changes: 6 additions & 6 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from utils.torch_utils import select_device, load_classifier, time_synchronized


@torch.no_grad()
def detect(opt):
source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
save_img = not opt.nosave and not source.endswith('.txt') # save inference images
Expand Down Expand Up @@ -175,10 +176,9 @@ def detect(opt):
print(opt)
check_requirements(exclude=('tensorboard', 'pycocotools', 'thop'))

with torch.no_grad():
if opt.update: # update all models (to fix SourceChangeWarning)
for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
detect(opt=opt)
strip_optimizer(opt.weights)
else:
if opt.update: # update all models (to fix SourceChangeWarning)
for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
detect(opt=opt)
strip_optimizer(opt.weights)
else:
detect(opt=opt)
32 changes: 16 additions & 16 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from utils.torch_utils import select_device, time_synchronized


@torch.no_grad()
def test(data,
weights=None,
batch_size=32,
Expand Down Expand Up @@ -105,22 +106,21 @@ def test(data,
targets = targets.to(device)
nb, _, height, width = img.shape # batch size, channels, height, width

with torch.no_grad():
# Run model
t = time_synchronized()
out, train_out = model(img, augment=augment) # inference and training outputs
t0 += time_synchronized() - t

# Compute loss
if compute_loss:
loss += compute_loss([x.float() for x in train_out], targets)[1][:3] # box, obj, cls

# Run NMS
targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
t = time_synchronized()
out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
t1 += time_synchronized() - t
# Run model
t = time_synchronized()
out, train_out = model(img, augment=augment) # inference and training outputs
t0 += time_synchronized() - t

# Compute loss
if compute_loss:
loss += compute_loss([x.float() for x in train_out], targets)[1][:3] # box, obj, cls

# Run NMS
targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
t = time_synchronized()
out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
t1 += time_synchronized() - t

# Statistics per image
for si, pred in enumerate(out):
Expand Down

0 comments on commit 5191519

Please sign in to comment.