Skip to content

Commit

Permalink
Merge pull request #636 from Nioolek/v8_seg_inference
Browse files Browse the repository at this point in the history
[Feature] Support YOLOv8 Ins Segmentation Inference
  • Loading branch information
Nioolek authored Mar 11, 2023
2 parents 3118eef + e4cc2e3 commit c0d1468
Show file tree
Hide file tree
Showing 7 changed files with 661 additions and 22 deletions.
59 changes: 59 additions & 0 deletions configs/yolov8/ins_seg/yolov8_ins_s_syncbn_fast_8xb16-500e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
_base_ = '../yolov8_s_mask-refine_syncbn_fast_8xb16-500e_coco.py'

# Batch size of a single GPU during validation
val_batch_size_per_gpu = 16
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 8

batch_shapes_cfg = dict(
_delete_=True,
type='BatchShapePolicy',
batch_size=val_batch_size_per_gpu,
img_size=_base_.img_scale[0],
# The image scale of padding should be divided by pad_size_divisor
size_divisor=32,
# Additional paddings for pixel scale
extra_pad_ratio=0.5)

# Testing take a long time due to model_test_cfg.
# If you want to speed it up, you can increase score_thr
# or decraese nms_pre and max_per_img
model_test_cfg = dict(
multi_label=True,
nms_pre=30000,
min_bbox_size=0,
score_thr=0.001,
nms=dict(type='nms', iou_threshold=0.7),
max_per_img=300,
mask_thr_binary=0.5,
# fast_test: Whether to use fast test methods. When set
# to False, the implementation here is the same as the
# official, with higher mAP. If set to True, mask will first
# be upsampled to origin image shape through Pytorch, and
# then use mask_thr_binary to determine which pixels belong
# to the object. If set to False, will first use
# mask_thr_binary to determine which pixels belong to the
# object , and then use opencv to upsample mask to origin
# image shape. Default to False.
fast_test=False)

# ===============================Unmodified in most cases====================
model = dict(
bbox_head=dict(
type='YOLOv8InsHead',
head_module=dict(
type='YOLOv8InsHeadModule', masks_channels=32,
protos_channels=256)),
test_cfg=model_test_cfg)

val_dataloader = dict(
batch_size=val_batch_size_per_gpu,
num_workers=val_num_workers,
dataset=dict(batch_shapes_cfg=batch_shapes_cfg))
test_dataloader = val_dataloader

val_evaluator = dict(metric=['bbox', 'segm'])
test_evaluator = val_evaluator

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
25 changes: 17 additions & 8 deletions mmyolo/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,17 @@ def _resize_img(self, results: dict):
self.scale)

if ratio != 1:
# resize image according to the ratio
image = mmcv.imrescale(
# resize image according to the shape
image = mmcv.imresize(
img=image,
scale=ratio,
size=(int(original_w * ratio), int(original_h * ratio)),
interpolation='area' if ratio < 1 else 'bilinear',
backend=self.backend)

resized_h, resized_w = image.shape[:2]
scale_ratio = resized_h / original_h

scale_factor = (scale_ratio, scale_ratio)
scale_ratio_h = resized_h / original_h
scale_ratio_w = resized_w / original_w
scale_factor = (scale_ratio_w, scale_ratio_h)

results['img'] = image
results['img_shape'] = image.shape[:2]
Expand Down Expand Up @@ -212,7 +212,8 @@ def _resize_img(self, results: dict):
interpolation=self.interpolation,
backend=self.backend)

scale_factor = (ratio[1], ratio[0]) # mmcv scale factor is (w, h)
scale_factor = (no_pad_shape[1] / image_shape[1],
no_pad_shape[0] / image_shape[0])

if 'scale_factor' in results:
results['scale_factor_origin'] = results['scale_factor']
Expand Down Expand Up @@ -246,7 +247,15 @@ def _resize_img(self, results: dict):
if 'pad_param' in results:
results['pad_param_origin'] = results['pad_param'] * \
np.repeat(ratio, 2)
results['pad_param'] = np.array(padding_list, dtype=np.float32)

if 'gt_masks' in results:
results['pad_param'] = np.array(
[padding_h / 2, padding_h / 2, padding_w / 2, padding_w / 2],
dtype=np.float32)
else:
# We found in object detection, using padding list with
# int type can get higher mAP.
results['pad_param'] = np.array(padding_list, dtype=np.float32)

def _resize_masks(self, results: dict):
"""Resize masks with ``results['scale']``"""
Expand Down
3 changes: 2 additions & 1 deletion mmyolo/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .yolov6_head import YOLOv6Head, YOLOv6HeadModule
from .yolov7_head import YOLOv7Head, YOLOv7HeadModule, YOLOv7p6HeadModule
from .yolov8_head import YOLOv8Head, YOLOv8HeadModule
from .yolov8_ins_head import YOLOv8InsHead, YOLOv8InsHeadModule
from .yolox_head import YOLOXHead, YOLOXHeadModule

__all__ = [
Expand All @@ -16,5 +17,5 @@
'RTMDetSepBNHeadModule', 'YOLOv7Head', 'PPYOLOEHead', 'PPYOLOEHeadModule',
'YOLOv7HeadModule', 'YOLOv7p6HeadModule', 'YOLOv8Head', 'YOLOv8HeadModule',
'RTMDetRotatedHead', 'RTMDetRotatedSepBNHeadModule', 'RTMDetInsSepBNHead',
'RTMDetInsSepBNHeadModule'
'RTMDetInsSepBNHeadModule', 'YOLOv8InsHead', 'YOLOv8InsHeadModule'
]
Loading

0 comments on commit c0d1468

Please sign in to comment.