Skip to content

Commit

Permalink
Merge pull request #19 from Laughing-q/instance_seg
Browse files Browse the repository at this point in the history
fix accuracy issue
  • Loading branch information
AyushExel committed Aug 22, 2022
2 parents 261bec1 + cabb99d commit ebc0bcb
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def forward(self, x):
y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
xy = (y[..., 0:2] * 2. + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
y = torch.cat((xy.type_as(y), wh.type_as(y), y[..., 4:]), -1)
z.append(y.view(-1, self.na * ny * nx, self.no))
Expand Down
2 changes: 1 addition & 1 deletion segment/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def run(
s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string

# Mask plotting ----------------------------------------------------------------------------------------
mcolors = [colors(int(cls)) for cls in range(len(det[:, 5]))]
mcolors = [colors(int(cls), True) for cls in det[:, 5]]
# NOTE: this plot method is faster, but the image might get blurred https://github.com/dbolya/yolact
img_masks = plot_masks(im[i], masks, mcolors) # image with masks shape(imh,imw,3)
img_masks = scale_masks(im.shape[2:], img_masks, im0.shape) # scale to original h, w
Expand Down
2 changes: 2 additions & 0 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,8 @@ def non_max_suppression(
continue
elif n > max_nms: # excess boxes
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
else:
x = x[x[:, 4].argsort(descending=True)] # sort by confidence

# Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
Expand Down

0 comments on commit ebc0bcb

Please sign in to comment.