Skip to content

Commit

Permalink
Refactor load_from_yolov5 in YOLOv5 (#179)
Browse files Browse the repository at this point in the history
* Rename YOLOModule to YOLOv5

* Adopt classmethod to implement load_from_yolov5

* Fix unit-test for load_from_yolov5

* Update What-it-is and badges

* Update README.md
  • Loading branch information
zhiqwang committed Sep 30, 2021
1 parent 750e3b4 commit 00f6e13
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 38 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ lightning_logs
*.ipynb
runs
*.pt
*.onnx
yolort/version.py
.idea
*.ttf
Expand Down
42 changes: 31 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,33 @@

**YOLOv5 Runtime Stack**

[![CI testing](https://github.com/zhiqwang/yolov5-rt-stack/workflows/CI%20testing/badge.svg)](https://github.com/zhiqwang/yolov5-rt-stack/actions?query=workflow%3A%22CI+testing%22)
[![codecov](https://codecov.io/gh/zhiqwang/yolov5-rt-stack/branch/master/graph/badge.svg?token=1GX96EA72Y)](https://codecov.io/gh/zhiqwang/yolov5-rt-stack)
______________________________________________________________________

[Documentation](#%EF%B8%8F-usage)
[Installation Instructions](#installation-and-inference-examples)
[Deployment](#-deployment)
[Contributing](#-contributing)
[Reporting Issues](https://github.com/zhiqwang/yolov5-rt-stack/issues/new?assignees=&labels=&template=bug-report.yml)

______________________________________________________________________

[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/yolort)](https://pypi.org/project/yolort/)
[![PyPI version](https://badge.fury.io/py/yolort.svg)](https://badge.fury.io/py/yolort)
[![PyPI downloads](https://static.pepy.tech/personalized-badge/yolort?period=total&units=international_system&left_color=grey&right_color=blue&left_text=pypi%20downloads)](https://pepy.tech/project/yolort)
[![Github downloads](https://img.shields.io/github/downloads/zhiqwang/yolov5-rt-stack/total?color=blue&label=downloads&logo=github&logoColor=lightgrey)](https://img.shields.io/github/downloads/zhiqwang/yolov5-rt-stack/total?color=blue&label=Downloads&logo=github&logoColor=lightgrey)

[![CI testing](https://github.com/zhiqwang/yolov5-rt-stack/workflows/CI%20testing/badge.svg)](https://github.com/zhiqwang/yolov5-rt-stack/actions?query=workflow%3A%22CI+testing%22)
[![codecov](https://codecov.io/gh/zhiqwang/yolov5-rt-stack/branch/master/graph/badge.svg?token=1GX96EA72Y)](https://codecov.io/gh/zhiqwang/yolov5-rt-stack)
[![license](https://img.shields.io/github/license/zhiqwang/yolov5-rt-stack?color=brightgreen)](LICENSE)
[![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://join.slack.com/t/yolort/shared_invite/zt-mqwc7235-940aAh8IaKYeWclrJx10SA)

---
______________________________________________________________________

</div>

**What it is.** Yet another implementation of Ultralytics's [YOLOv5](https://github.com/ultralytics/yolov5), and with modules refactoring to adapt to different deployment scenarios such as `libtorch`, `onnxruntime`, `tvm` and so on.
## 🤗 Introduction

**What it is.** Yet another implementation of Ultralytics's [YOLOv5](https://github.com/ultralytics/yolov5). `yolort` aims to make the training and inference of the object detection integrate more seamlessly together. `yolort` now adopts the same model structure as the official YOLOv5. The significant difference is that we adopt the dynamic shape mechanism, and within this, we can embed both pre-processing (`letterbox`) and post-processing (`nms`) into the model graph, which simplifies the deployment strategy. In this sense, `yolort` makes it possible to be deployed more friendly on `LibTorch`, `ONNXRuntime`, `TVM` and so on.

**About the code.** Follow the design principle of [detr](https://github.com/facebookresearch/detr):

Expand All @@ -32,8 +47,7 @@
- *Nov. 21, 2020*. Add graph visualization tools.
- *Nov. 17, 2020*. Support exporting to `ONNX`, and inferencing with `ONNXRuntime` Python interface.
- *Nov. 16, 2020*. Refactor YOLO modules and support *dynamic shape/batch* inference.
- *Nov. 4, 2020*. Add `TorchScript` C++ inference example.
- *Oct. 10, 2020*. Support inferencing with `LibTorch` C++ interface.
- *Nov. 4, 2020*. Add `LibTorch` C++ inference example.
- *Oct. 8, 2020*. Support exporting to `TorchScript` model.

## 🛠️ Usage
Expand Down Expand Up @@ -93,25 +107,31 @@ model = torch.hub.load('zhiqwang/yolov5-rt-stack', 'yolov5s', pretrained=True)

### Loading checkpoint from official yolov5

The module state of `yolort` has some differences comparing to `ultralytics/yolov5`. We can load ultralytics's trained model checkpoint with minor changes, and we have converted ultralytics's release [v3.1](https://github.com/ultralytics/yolov5/releases/tag/v3.1) and [v4.0](https://github.com/ultralytics/yolov5/releases/tag/v4.0). And now we supply an interface to load the checkpoint weights trained with `ultralytics/yolov5` as follows. See our [how-to-align-with-ultralytics-yolov5](http://github.com/zhiqwang/yolov5-rt-stack/blob/master/notebooks/how-to-align-with-ultralytics-yolov5.ipynb) notebook for more details.
The following is the interface for loading the checkpoint weights trained with `ultralytics/yolov5`. See our [how-to-align-with-ultralytics-yolov5](http://github.com/zhiqwang/yolov5-rt-stack/blob/master/notebooks/how-to-align-with-ultralytics-yolov5.ipynb) notebook for more details.

```python
from yolort.models import yolov5s
from yolort.models import YOLOv5

# Model
yolov5 = YOLOv5()

# 'yolov5s.pt' is downloaded from https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5s.pt
ckpt_path_from_ultralytics = 'yolov5s.pt'
model = yolov5s(score_thresh=0.25)
model.load_from_yolov5(ckpt_path_from_ultralytics)
model = yolov5.load_from_yolov5(ckpt_path_from_ultralytics, score_thresh=0.25)

model.eval()
img_path = 'test/assets/bus.jpg'
predictions = model.predict(img_path)
```

### Inference on `LibTorch` backend 🚀
## 🚀 Deployment

### Inference on `LibTorch` backend

We provide a [notebook](notebooks/inference-pytorch-export-libtorch.ipynb) to demonstrate how the model is transformed into `torchscript`. And we provide an [C++ example](./deployment/libtorch) of how to infer with the transformed `torchscript` model. For details see the [GitHub Actions](.github/workflows/ci_test.yml).

### Inference on `ONNXRuntime` backend

## 🎨 Model Graph Visualization

Now, `yolort` can draw the model graph directly, checkout our [model-graph-visualization](notebooks/model-graph-visualization.ipynb) notebook to see how to use and visualize the model graph.
Expand Down
9 changes: 5 additions & 4 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch import Tensor

from yolort import models
from yolort.models import YOLOv5
from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.transformer import darknet_tan_backbone
from yolort.models.anchor_utils import AnchorGenerator
Expand Down Expand Up @@ -290,10 +291,10 @@ def test_load_from_yolov5(arch, version, hash_prefix):
yolov5s_r40_url = f'https://github.com/ultralytics/yolov5/releases/download/{version}/{arch}.pt'
torch.hub.download_url_to_file(yolov5s_r40_url, yolov5s_r40_path, hash_prefix=hash_prefix)

model_load_from_yolov5 = models.__dict__[arch](score_thresh=0.25)
model_load_from_yolov5.load_from_yolov5(yolov5s_r40_path)
model_load_from_yolov5.eval()
out_from_yolov5 = model_load_from_yolov5.predict(img_path)
yolov5 = YOLOv5()
model_yolov5 = yolov5.load_from_yolov5(yolov5s_r40_path, score_thresh=0.25)
model_yolov5.eval()
out_from_yolov5 = model_yolov5.predict(img_path)
assert isinstance(out_from_yolov5[0], dict)
assert isinstance(out_from_yolov5[0]['boxes'], Tensor)
assert isinstance(out_from_yolov5[0]['labels'], Tensor)
Expand Down
18 changes: 10 additions & 8 deletions yolort/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from yolort.utils.activations import Hardswish, SiLU
from yolort.v5 import Conv

from .yolo_module import YOLOModule
from .yolo_module import YOLOv5

__all__ = ['YOLOv5', 'yolov5s', 'yolov5m', 'yolov5l', 'yolotr']


def yolov5s(upstream_version: str = 'r4.0', export_friendly: bool = False, **kwargs: Any):
Expand All @@ -17,9 +19,9 @@ def yolov5s(upstream_version: str = 'r4.0', export_friendly: bool = False, **kwa
Default: False.
"""
if upstream_version == 'r3.1':
model = YOLOModule(arch="yolov5_darknet_pan_s_r31", **kwargs)
model = YOLOv5(arch="yolov5_darknet_pan_s_r31", **kwargs)
elif upstream_version == 'r4.0':
model = YOLOModule(arch="yolov5_darknet_pan_s_r40", **kwargs)
model = YOLOv5(arch="yolov5_darknet_pan_s_r40", **kwargs)
else:
raise NotImplementedError("Currently only supports r3.1 and r4.0 versions")

Expand All @@ -38,9 +40,9 @@ def yolov5m(upstream_version: str = 'r4.0', export_friendly: bool = False, **kwa
Default: False.
"""
if upstream_version == 'r3.1':
model = YOLOModule(arch="yolov5_darknet_pan_m_r31", **kwargs)
model = YOLOv5(arch="yolov5_darknet_pan_m_r31", **kwargs)
elif upstream_version == 'r4.0':
model = YOLOModule(arch="yolov5_darknet_pan_m_r40", **kwargs)
model = YOLOv5(arch="yolov5_darknet_pan_m_r40", **kwargs)
else:
raise NotImplementedError("Currently only supports r3.1 and r4.0 versions")

Expand All @@ -59,9 +61,9 @@ def yolov5l(upstream_version: str = 'r4.0', export_friendly: bool = False, **kwa
Default: False.
"""
if upstream_version == 'r3.1':
model = YOLOModule(arch="yolov5_darknet_pan_l_r31", **kwargs)
model = YOLOv5(arch="yolov5_darknet_pan_l_r31", **kwargs)
elif upstream_version == 'r4.0':
model = YOLOModule(arch="yolov5_darknet_pan_l_r40", **kwargs)
model = YOLOv5(arch="yolov5_darknet_pan_l_r40", **kwargs)
else:
raise NotImplementedError("Currently only supports r3.1 and r4.0 versions")

Expand All @@ -80,7 +82,7 @@ def yolotr(upstream_version: str = 'r4.0', export_friendly: bool = False, **kwar
Default: False.
"""
if upstream_version == 'r4.0':
model = YOLOModule(arch="yolov5_darknet_tan_s_r40", **kwargs)
model = YOLOv5(arch="yolov5_darknet_tan_s_r40", **kwargs)
else:
raise NotImplementedError("Currently only supports r4.0 versions")

Expand Down
2 changes: 1 addition & 1 deletion yolort/models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

__all__ = ['YOLO', 'yolov5_darknet_pan_s_r31', 'yolov5_darknet_pan_m_r31', 'yolov5_darknet_pan_l_r31',
'yolov5_darknet_pan_s_r40', 'yolov5_darknet_pan_m_r40', 'yolov5_darknet_pan_l_r40',
'yolov5_darknet_tan_s_r40']
'yolov5_darknet_tan_s_r40', '_yolov5_darknet_pan']


class YOLO(nn.Module):
Expand Down
47 changes: 40 additions & 7 deletions yolort/models/yolo_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@

from yolort.data import COCOEvaluator, contains_any_tensor
from yolort.utils.update_module_state import ModuleStateUpdate
from yolort.v5 import load_yolov5_model
from yolort.v5 import load_yolov5_model, get_yolov5_size

from . import yolo
from .transform import YOLOTransform
from ._utils import _evaluate_iou

__all__ = ['YOLOModule']
__all__ = ['YOLOv5']


class YOLOModule(LightningModule):
class YOLOv5(LightningModule):
"""
PyTorch Lightning wrapper of `YOLO`
"""
Expand Down Expand Up @@ -118,7 +118,7 @@ def _forward_impl(

if torch.jit.is_scripting():
if not self._has_warned:
warnings.warn("YOLOModule always returns a (Losses, Detections) tuple in scripting.")
warnings.warn("YOLOv5 always returns a (Losses, Detections) tuple in scripting.")
self._has_warned = True
return losses, detections
else:
Expand Down Expand Up @@ -272,15 +272,48 @@ def add_model_specific_args(parent_parser):
metavar='W', help='weight decay (default: 5e-4)')
return parser

def load_from_yolov5(self, checkpoint_path: str):
@classmethod
def load_from_yolov5(
cls,
checkpoint_path: str,
lr: float = 0.01,
size: Tuple[int, int] = (640, 640),
score_thresh: float = 0.25,
nms_thresh: float = 0.45,
version: str = 'r4.0',
):
"""
Load model state from the checkpoint trained by YOLOv5.
Args:
checkpoint_path (str): Path of the YOLOv5 checkpoint model.
"""
checkpoint_yolov5 = load_yolov5_model(checkpoint_path)
module_state_updater = ModuleStateUpdate(arch=self.arch, num_classes=self.num_classes)
num_classes = checkpoint_yolov5.yaml['nc']
anchor_grids = checkpoint_yolov5.yaml['anchors']
depth_multiple = checkpoint_yolov5.yaml['depth_multiple']
width_multiple = checkpoint_yolov5.yaml['width_multiple']

module_state_updater = ModuleStateUpdate(
arch=None,
depth_multiple=depth_multiple,
width_multiple=width_multiple,
version=version,
num_classes=num_classes,
)
module_state_updater.updating(checkpoint_yolov5)
state_dict = module_state_updater.model.state_dict()
self.model.load_state_dict(state_dict)
yolov5_size = get_yolov5_size(depth_multiple, width_multiple)
arch = f"yolov5_darknet_pan_{yolov5_size}_{version.replace('.', '')}"
yolov5_custom = cls(
lr=lr,
arch=arch,
size=size,
num_classes=num_classes,
anchor_grids=anchor_grids,
score_thresh=score_thresh,
nms_thresh=nms_thresh,
)

yolov5_custom.model.load_state_dict(state_dict)
return yolov5_custom
31 changes: 25 additions & 6 deletions yolort/utils/update_module_state.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved.
from typing import Any
from typing import Optional

from functools import reduce
from torch import nn

from yolort.models import yolo
from yolort.v5 import load_yolov5_model
from yolort.v5 import load_yolov5_model, get_yolov5_size


ARCHITECTURE_MAPS = {
Expand All @@ -23,7 +23,7 @@ def update_module_state_from_ultralytics(
num_classes: int = 80,
set_fp16: bool = True,
verbose: bool = False,
**kwargs: Any,
**kwargs,
):
"""
Allows the user to specify a file to use when loading an ultralytics model for conversion.
Expand All @@ -47,7 +47,7 @@ def update_module_state_from_ultralytics(
key_arch = f'{arch}_{feature_fusion_type.lower()}_v4.0'
if key_arch not in ARCHITECTURE_MAPS:
raise NotImplementedError(
"Currently does't supports this architecture, "
"Currently does't support this architecture, "
"fell free to file an issue labeled enhancement to us"
)

Expand All @@ -71,7 +71,10 @@ class ModuleStateUpdate:
"""
def __init__(
self,
arch: str = 'yolov5_darknet_pan_s_r31',
arch: Optional[str] = 'yolov5_darknet_pan_s_r31',
depth_multiple: Optional[float] = None,
width_multiple: Optional[float] = None,
version: str = 'r4.0',
num_classes: int = 80,
inner_block_maps: dict = {'0': '9', '1': '10', '3': '13', '4': '14'},
layer_block_maps: dict = {'0': '17', '1': '18', '2': '20', '3': '21', '4': '23'},
Expand All @@ -84,7 +87,23 @@ def __init__(
self.head_ind = head_ind
self.head_name = head_name
# Set model
self.model = yolo.__dict__[arch](num_classes=num_classes)
if arch is not None:
model = yolo.__dict__[arch](num_classes=num_classes)
elif depth_multiple is not None and width_multiple is not None:
yolov5_size = get_yolov5_size(depth_multiple, width_multiple)
backbone_name = f"darknet_{yolov5_size}_{version.replace('.', '_')}"
weights_name = f"yolov5_darknet_pan_{yolov5_size}_{version.replace('.', '')}_coco"
model = yolo._yolov5_darknet_pan(
backbone_name,
depth_multiple,
width_multiple,
version,
weights_name,
num_classes=num_classes,
)
else:
raise NotImplementedError("Currently either arch or multiples must be set.")
self.model = model

def updating(self, state_dict):
# Obtain module state
Expand Down
16 changes: 15 additions & 1 deletion yolort/v5/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .models.yolo import Model
from .utils import attempt_download, set_logging

__all__ = ['add_yolov5_context', 'load_yolov5_model']
__all__ = ['add_yolov5_context', 'load_yolov5_model', 'get_yolov5_size']


@contextlib.contextmanager
Expand All @@ -27,6 +27,20 @@ def add_yolov5_context():
sys.path.remove(path_ultralytics_yolov5)


def get_yolov5_size(depth_multiple, width_multiple):
if depth_multiple == 0.33 and width_multiple == 0.5:
return 's'
elif depth_multiple == 0.67 and width_multiple == 0.75:
return 'm'
elif depth_multiple == 1.0 and width_multiple == 1.0:
return 'l'
else:
raise NotImplementedError(
f"Currently does't support architecture with depth: {depth_multiple} "
f"and {width_multiple}, fell free to create a ticket labeled enhancement to us"
)


def load_yolov5_model(checkpoint_path: str, autoshape: bool = False, verbose: bool = True):
"""
Creates a specified YOLOv5 model
Expand Down

0 comments on commit 00f6e13

Please sign in to comment.