Skip to content

Commit

Permalink
Resolve precommit utils/segment/general
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Aug 20, 2022
1 parent 74eabbf commit a752e67
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions utils/segment/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.nn.functional as F
import torchvision

from ..general import xywh2xyxy
from ..general import LOGGER, xywh2xyxy
from ..metrics import box_iou


Expand Down Expand Up @@ -53,11 +53,11 @@ def non_max_suppression_masks(

# Cat apriori labels if autolabelling
if labels and len(labels[xi]):
l = labels[xi]
v = torch.zeros((len(l), nc + 5), device=x.device)
v[:, :4] = l[:, 1:5] # box
lb = labels[xi]
v = torch.zeros((len(lb), nc + 5), device=x.device)
v[:, :4] = lb[:, 1:5] # box
v[:, 4] = 1.0 # conf
v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
x = torch.cat((x, v), 0)

# If none remain process next image
Expand Down Expand Up @@ -101,7 +101,7 @@ def non_max_suppression_masks(
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
if i.shape[0] > max_det: # limit detections
i = i[:max_det]
if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
weights = iou * scores[None] # box weights
Expand All @@ -111,7 +111,7 @@ def non_max_suppression_masks(

output[xi] = x[i]
if (time.time() - t) > time_limit:
print(f"WARNING: NMS time limit {time_limit}s exceeded")
LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
break # time limit exceeded

return output
Expand Down

0 comments on commit a752e67

Please sign in to comment.