Skip to content

Commit

Permalink
Merge pull request ultralytics#11 from Laughing-q/instance_seg
Browse files (browse the repository at this point in the history)
update bias init && update obj loss
  • Loading branch information
AyushExel authored Aug 3, 2022
2 parents ab93a4e + 29e433b commit 483d13e
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 10 deletions.
8 changes: 4 additions & 4 deletions evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@
import torch
import torch.nn.functional as F
from PIL import Image
# import pycocotools.mask as mask_util
import pycocotools.mask as mask_util
from tqdm import tqdm

from models.experimental import attempt_load
from seg_dataloaders import create_dataloader
from utils.general import (box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, )
from utils.general import (check_dataset, check_img_size, check_suffix, )
from utils.general import (check_dataset, check_img_size, check_suffix)
from utils.general import (coco80_to_coco91_class, increment_path, colorstr, )
from utils.plots import output_to_target, plot_images_boxes_and_masks
from utils.seg_metrics import ap_per_class, ap_per_class_box_and_mask, ConfusionMatrix
from utils.segment import (non_max_suppression_masks, mask_iou, process_mask, process_mask_upsample, scale_masks, )
from utils.torch_utils import select_device, time_sync
from utils.torch_utils import select_device, time_sync, de_parallel


def save_one_txt(predn, save_conf, shape, file):
Expand Down Expand Up @@ -304,7 +304,7 @@ def inference(self, model, img, targets, masks=None, compute_loss=None):
targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(self.device) # to pixels
t3 = time_sync()
out = self.nms(prediction=out, conf_thres=self.conf_thres, iou_thres=self.iou_thres, multi_label=True,
agnostic=self.single_cls, )
agnostic=self.single_cls, mask_dim=de_parallel(model).model[-1].mask_dim)
self.dt[2] += time_sync() - t3
return out, train_out

Expand Down
2 changes: 1 addition & 1 deletion models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is
for mi, s in zip(m.m, m.stride): # from
b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
b.data[:, 5+m.mask_dim:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

def _print_biases(self):
Expand Down
2 changes: 2 additions & 0 deletions train_instseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
g[0].append(v.weight)

# hyp['lr0'] = hyp['lr0'] / batch_size * 128
# hyp['warmup_bias_lr'] = 0.01
if opt.optimizer == 'Adam':
optimizer = Adam(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
elif opt.optimizer == 'AdamW':
Expand Down
18 changes: 13 additions & 5 deletions utils/seg_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def loss_segment(self, preds, targets, masks):
if self.sort_obj_iou:
sort_id = torch.argsort(score_iou)
b, a, gj, gi, score_iou = (b[sort_id], a[sort_id], gj[sort_id], gi[sort_id], score_iou[sort_id],)
tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * score_iou # iou ratio
tobj[b, a, gj, gi] = 0.5 * ((1.0 - self.gr) + self.gr * score_iou) # iou ratio

# Classification
if self.nc > 1: # cls loss (only if multiple classes)
Expand Down Expand Up @@ -170,7 +170,13 @@ def loss_segment(self, preds, targets, masks):
psi = ps[index][:, 5: self.nm]
proto = proto_out[bi]

batch_lseg += self.single_mask_loss(mask_gti, psi, proto, mxyxy, mw, mh)
one_lseg, iou = self.single_mask_loss(mask_gti, psi, proto, mxyxy, mw, mh)
batch_lseg += one_lseg

# update tobj
iou = iou.detach().clamp(0).type(tobj.dtype)
tobj[b[index], a[index], gj[index], gi[index]] += 0.5 * iou[0]

lseg += batch_lseg / len(b.unique())

obji = self.BCEobj(pi[..., 4], tobj)
Expand All @@ -193,10 +199,12 @@ def single_mask_loss(self, gt_mask, pred, proto, xyxy, w, h):
"""mask loss of single pic."""
# (80, 80, 32) @ (32, n) -> (80, 80, n)
pred_mask = proto @ pred.tanh().T
# lseg_iou = self.mask_loss(pred_mask, gt_mask, xyxy)
iou = self.mask_loss(pred_mask, gt_mask, xyxy, return_iou=True)
lseg = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
lseg = crop(lseg, xyxy)
lseg = lseg.mean(dim=(0, 1)) / w / h
return lseg.mean()
return lseg.mean(), iou# + lseg_iou.mean()

def build_targets(self, p, targets):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
Expand Down Expand Up @@ -334,7 +342,7 @@ class MaskIOULoss(nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, pred_mask, gt_mask, mxyxy=None):
def forward(self, pred_mask, gt_mask, mxyxy=None, return_iou=False):
"""
Args:
pred_mask (torch.Tensor): prediction of masks, (80/160, 80/160, n)
Expand All @@ -349,7 +357,7 @@ def forward(self, pred_mask, gt_mask, mxyxy=None):
pred_mask = pred_mask.permute(2, 0, 1).view(n, -1)
gt_mask = gt_mask.permute(2, 0, 1).view(n, -1)
iou = masks_iou(pred_mask, gt_mask)
return 1.0 - iou
return iou if return_iou else (1.0 - iou)


import math
Expand Down

0 comments on commit 483d13e

Please sign in to comment.