Skip to content

Commit

Permalink
Remake ultralytics#7
Browse files Browse the repository at this point in the history
  • Loading branch information
manole-alexandru committed Mar 25, 2023
1 parent 620f1c3 commit 23d61c6
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def __call__(self, preds, targets, seg_masks): # predictions, targets

# return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()
# return total_loss, torch.cat((lbox, lobj, lcls, lseg)).detach()
return (lbox + lobj + lcls) * bs * 1, lseg * bs * 1, torch.cat((lbox, lobj, lcls, lseg)).detach()
return (lbox + lobj + lcls) * bs * 0, lseg * bs * 1, torch.cat((lbox, lobj, lcls, lseg)).detach()

def build_targets(self, p, targets):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
Expand Down
2 changes: 1 addition & 1 deletion utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
g_seg[0].append(p) # weight (with decay)

det_optimizer = get_optimizer(name, lr, momentum, decay, g_det)
seg_optimizer = get_optimizer(name, lr * 5, momentum, decay, g_seg)
seg_optimizer = get_optimizer(name, lr, momentum, decay, g_seg)
return det_optimizer, seg_optimizer


Expand Down
14 changes: 9 additions & 5 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def compute_seg_iou(pred, target, n_classes=2):
# print(pred)

# Ignore IoU for background class ("0")
for cls in range(1, n_classes): # This goes from 1:n_classes-1 -> class "0" is ignored
for cls in range(0, n_classes): # This goes from 1:n_classes-1 -> class "0" is ignored
pred_inds = pred == cls
target_inds = target == cls
intersection = (pred_inds[target_inds]).long().sum().data.cpu() # Cast to long to prevent overflows
Expand Down Expand Up @@ -210,9 +210,10 @@ def run(
if isinstance(names, (list, tuple)): # old format
names = dict(enumerate(names))
class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
s = ('%22s' + '%11s' * 7) % ('Class', 'Images', 'Instances', 'P', 'R', 'mAP50', 'mAP50-95', 'Seg IoU')
s = ('%22s' + '%11s' * 8) % ('Class', 'Images', 'Instances', 'P', 'R', 'mAP50', 'mAP50-95', 'Seg mIoU', 'Rail IoU')
tp, fp, p, r, f1, mp, mr, map50, ap50, map = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
iou = 0.0
rail_iou = 0.0
dt = Profile(), Profile(), Profile() # profiling times
loss = torch.zeros(4, device=device)
jdict, stats, ap, ap_class = [], [], [], []
Expand All @@ -237,7 +238,8 @@ def run(

ious = compute_seg_iou(pred_mask, segs)
# print('\n------------ IoU: ', ious, '------------\n')
iou += ious[0]
iou += (ious[0] + ious[1]) / 2
rail_iou += ious[1]

# Loss
if compute_loss:
Expand Down Expand Up @@ -309,8 +311,10 @@ def run(
nt = np.bincount(stats[3].astype(int), minlength=nc) # number of targets per class

# Print results
pf = '%22s' + '%11i' * 2 + '%11.3g' * 5 # print format
LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map, iou))
pf = '%22s' + '%11i' * 2 + '%11.3g' * 6 # print format
iou = iou / len(pbar)
rail_iou = rail_iou / len(pbar)
LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map, iou, rail_iou))
if nt.sum() == 0:
LOGGER.warning(f'WARNING ⚠️ no labels found in {task} set, can not compute metrics without labels')

Expand Down

0 comments on commit 23d61c6

Please sign in to comment.