Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Algorithm] Support Rotated YOLOX (CVPR'21) #409

Draft
wants to merge 19 commits into
base: dev
Choose a base branch
from
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ A summary can be found in the [Model Zoo](docs/en/model_zoo.md) page.
- [x] [S<sup>2</sup>A-Net](configs/s2anet/README.md) (TGRS'2021)
- [x] [ReDet](configs/redet/README.md) (CVPR'2021)
- [x] [Beyond Bounding-Box](configs/cfa/README.md) (CVPR'2021)
- [x] [Rotated YOLOX](configs/rotated_yolox/README.md) (CVPR'2021)
- [x] [Oriented R-CNN](configs/oriented_rcnn/README.md) (ICCV'2021)
- [x] [GWD](configs/gwd/README.md) (ICML'2021)
- [x] [KLD](configs/kld/README.md) (NeurIPS'2021)
Expand Down
48 changes: 48 additions & 0 deletions configs/rotated_yolox/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Rotated YOLOX

> [YOLOX: Exceeding YOLO Series in 2021](https://arxiv.org/abs/2107.08430)

<!-- [ALGORITHM] -->

## Abstract

In this report, we present some experienced improvements to YOLO series, forming a new high-performance detector --
YOLOX. We switch the YOLO detector to an anchor-free manner and conduct other advanced detection techniques, i.e., a
decoupled head and the leading label assignment strategy SimOTA to achieve state-of-the-art results across a large scale
range of models: For YOLO-Nano with only 0.91M parameters and 1.08G FLOPs, we get 25.3% AP on COCO, surpassing NanoDet
by 1.8% AP; for YOLOv3, one of the most widely used detectors in industry, we boost it to 47.3% AP on COCO,
outperforming the current best practice by 3.0% AP; for YOLOX-L with roughly the same amount of parameters as
YOLOv4-CSP, YOLOv5-L, we achieve 50.0% AP on COCO at a speed of 68.9 FPS on Tesla V100, exceeding YOLOv5-L by 1.8% AP.
Further, we won the 1st Place on Streaming Perception Challenge (Workshop on Autonomous Driving at CVPR 2021) using a
single YOLOX-L model. We hope this report can provide useful experience for developers and researchers in practical
scenes, and we also provide deploy versions with ONNX, TensorRT, NCNN, and Openvino supported.

<div align=center>
<img src="https://user-images.githubusercontent.com/40661020/144001736-9fb303dd-eac7-46b0-ad45-214cfa51e928.png"/>
</div>

## Results and Models

| Backbone | Bbox Loss Type | Size | mAP | FPS | Config | Download |
| :-------------: | :------------------: | :---------: | :---: | :--: | :-----------------------------------------------: | :------: |
| Rotated YOLOX-s | Rotated IoU | (1024,1024) | 74.36 | 53.1 | [config](./rotated_yolox_s_300e_dota_le90.py) | - |
| Rotated YOLOX-s | Horizontal IoU + CSL | (1024,1024) | 74.71 | 46.8 | [config](./rotated_yolox_s_csl_300e_dota_le90.py) | - |
| Rotated YOLOX-s | KLD | (1024,1024) | 75.23 | 53.0 | [config](./rotated_yolox_s_kld_300e_dota_le90.py) | - |

**Note**:

- Rotated YOLOX with KLD Loss is unstable during training, which will lead to nan and then cause CUDA Error.
- Compared with original YOLOX in mmdet, Rotated YOLOX enable `grad_clip` to prevent nan at training process.
- All models are trained with batch size 8 on one GPU.
- FPS and speed are tested on a single RTX3090.

## Citation

```latex
@article{yolox2021,
title={{YOLOX}: Exceeding YOLO Series in 2021},
author={Ge, Zheng and Liu, Songtao and Wang, Feng and Li, Zeming and Sun, Jian},
journal={arXiv preprint arXiv:2107.08430},
year={2021}
}
```
197 changes: 197 additions & 0 deletions configs/rotated_yolox/rotated_yolox_s_300e_dota_le90.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
_base_ = ['../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py']

angle_version = 'le90'
img_scale = (1024, 1024) # height, width

# model settings
model = dict(
type='RotatedYOLOX',
input_size=img_scale,
random_size_range=(25, 35),
random_size_interval=10,
backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
neck=dict(
type='YOLOXPAFPN',
in_channels=[128, 256, 512],
out_channels=128,
num_csp_blocks=1),
bbox_head=dict(
type='RotatedYOLOXHead',
num_classes=15,
in_channels=128,
feat_channels=128,
separate_angle=False,
with_angle_l1=True,
angle_norm_factor=5,
edge_swap=angle_version,
loss_bbox=dict(
type='RotatedIoULoss',
mode='square',
eps=1e-16,
reduction='sum',
loss_weight=5.0),
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0),
loss_obj=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0),
loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0),
),
train_cfg=dict(assigner=dict(type='RSimOTAAssigner', center_radius=2.5)),
test_cfg=dict(
score_thr=0.01, nms=dict(type='nms_rotated', iou_threshold=0.10)))

# dataset settings
dataset_type = 'DOTADataset'
data_root = '/datasets/dota_mmrotate_ss/'

train_pipeline = [
dict(type='RMosaic', img_scale=img_scale, pad_val=114.0),
dict(
type='PolyRandomAffine',
version=angle_version,
scaling_ratio_range=(0.1, 2),
bbox_clip_border=False,
border=(-img_scale[0] // 2, -img_scale[1] // 2)),
dict(
type='PolyMixUp',
version=angle_version,
bbox_clip_border=False,
img_scale=img_scale,
ratio_range=(0.8, 1.6),
pad_val=114.0),
dict(type='YOLOXHSVRandomAug'),
dict(
type='RRandomFlip',
flip_ratio=[0.25, 0.25, 0.25],
direction=['horizontal', 'vertical', 'diagonal'],
version=angle_version),
# According to the official implementation, multi-scale
# training is not considered here but in the
# 'mmrotate/models/detectors/rotated_yolox.py.'
dict(type='RResize', img_scale=img_scale),
dict(
type='Pad',
pad_to_square=True,
# If the image is three-channel, the pad value needs
# to be set separately for each channel.
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(
type='FilterRotatedAnnotations',
min_gt_bbox_wh=(1, 1),
keep_empty=False),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
train_dataset = dict(
type='MultiImageMixDataset',
dataset=dict(
type=dataset_type,
version=angle_version,
ann_file=data_root + 'trainval/annfiles/',
img_prefix=data_root + 'trainval/images/',
pipeline=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
],
filter_empty_gt=False,
),
pipeline=train_pipeline)

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
flip=False,
transforms=[
dict(type='RResize'),
dict(
type='Pad',
pad_to_square=True,
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img'])
])
]

data = dict(
samples_per_gpu=8,
workers_per_gpu=8,
persistent_workers=True,
train=train_dataset,
val=dict(
type=dataset_type,
version=angle_version,
ann_file=data_root + 'trainval/annfiles/',
img_prefix=data_root + 'trainval/images/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
version=angle_version,
ann_file=data_root + 'test/images/',
img_prefix=data_root + 'test/images/',
pipeline=test_pipeline))

# optimizer
optimizer = dict(
type='SGD',
lr=0.01 / 8,
momentum=0.9,
weight_decay=5e-4,
nesterov=True,
paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))

max_epochs = 300
num_last_epochs = 15
resume_from = None
interval = 10

# learning policy
lr_config = dict(
_delete_=True,
policy='YOLOX',
warmup='exp',
by_epoch=False,
warmup_by_epoch=True,
warmup_ratio=1,
warmup_iters=5, # 5 epoch
num_last_epochs=num_last_epochs,
min_lr_ratio=0.05)

runner = dict(type='EpochBasedRunner', max_epochs=max_epochs)

custom_hooks = [
dict(
type='YOLOXModeSwitchHook',
num_last_epochs=num_last_epochs,
skip_type_keys=('RMosaic', 'PolyRandomAffine', 'PolyMixUp'),
priority=48),
dict(
type='SyncNormHook',
num_last_epochs=num_last_epochs,
interval=interval,
priority=48),
dict(
type='ExpMomentumEMAHook',
resume_from=resume_from,
momentum=0.0001,
priority=49)
]
checkpoint_config = dict(interval=interval)
evaluation = dict(
save_best='auto',
# The evaluation interval is 'interval' when running epoch is
# less than ‘max_epochs - num_last_epochs’.
# The evaluation interval is 1 when running epoch is greater than
# or equal to ‘max_epochs - num_last_epochs’.
interval=interval,
dynamic_intervals=[(max_epochs - num_last_epochs, 1)],
metric='mAP')
log_config = dict(interval=50)
26 changes: 26 additions & 0 deletions configs/rotated_yolox/rotated_yolox_s_csl_300e_dota_le90.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
_base_ = './rotated_yolox_s_300e_dota_le90.py'

# model settings
model = dict(
bbox_head=dict(
separate_angle=True,
with_angle_l1=False,
loss_bbox=dict(
_delete_=True,
type='IoULoss',
mode='square',
eps=1e-16,
reduction='sum',
loss_weight=5.0),
angle_coder=dict(
type='CSLCoder',
angle_version='le90',
omega=1,
window='gaussian',
radius=2),
loss_angle=dict(
type='SmoothFocalLoss',
gamma=2.0,
alpha=0.25,
reduction='sum',
loss_weight=1.0)))
3 changes: 3 additions & 0 deletions configs/rotated_yolox/rotated_yolox_s_fp16_300e_dota_le90.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = './rotated_yolox_s_300e_dota_le90.py'

fp16 = dict(loss_scale='dynamic')
14 changes: 14 additions & 0 deletions configs/rotated_yolox/rotated_yolox_s_kld_300e_dota_le90.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
_base_ = './rotated_yolox_s_300e_dota_le90.py'

# model settings
model = dict(
bbox_head=dict(
loss_bbox=dict(
_delete_=True,
type='GDLoss',
loss_type='kld',
fun='log1p',
tau=1,
sqrt=True,
reduction='sum',
loss_weight=27.5)))
10 changes: 5 additions & 5 deletions mmrotate/core/bbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
GVFixCoder, GVRatioCoder, MidpointOffsetCoder)
from .iou_calculators import RBboxOverlaps2D, rbbox_overlaps
from .samplers import RRandomSampler
from .transforms import (bbox_mapping_back, gaussian2bbox, gt2gaussian,
hbb2obb, norm_angle, obb2hbb, obb2poly, obb2poly_np,
obb2xyxy, poly2obb, poly2obb_np, rbbox2result,
rbbox2roi)
from .transforms import (bbox_mapping_back, find_inside_polygons,
gaussian2bbox, gt2gaussian, hbb2obb, norm_angle,
obb2hbb, obb2poly, obb2poly_np, obb2xyxy, poly2obb,
poly2obb_np, rbbox2result, rbbox2roi)
from .utils import GaussianMixture

__all__ = [
Expand All @@ -20,5 +20,5 @@
'GVRatioCoder', 'ConvexAssigner', 'MaxConvexIoUAssigner', 'SASAssigner',
'ATSSKldAssigner', 'gaussian2bbox', 'gt2gaussian', 'GaussianMixture',
'build_assigner', 'build_bbox_coder', 'build_sampler', 'bbox_mapping_back',
'CSLCoder', 'ATSSObbAssigner'
'CSLCoder', 'ATSSObbAssigner', 'find_inside_polygons'
]
3 changes: 2 additions & 1 deletion mmrotate/core/bbox/assigners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from .atss_obb_assigner import ATSSObbAssigner
from .convex_assigner import ConvexAssigner
from .max_convex_iou_assigner import MaxConvexIoUAssigner
from .r_sim_ota_assinger import RSimOTAAssigner
from .sas_assigner import SASAssigner

__all__ = [
'ConvexAssigner', 'MaxConvexIoUAssigner', 'SASAssigner', 'ATSSKldAssigner',
'ATSSObbAssigner'
'ATSSObbAssigner', 'RSimOTAAssigner'
]
Loading