Skip to content

Commit

Permalink
Move convert_yolov5_checkpoint into _checkpoint.py (#374)
Browse files Browse the repository at this point in the history
* Move convert_yolov5_checkpoint into _checkpoint.py

* Add convert_yolov5_checkpoint into _checkpoint.py

* Fix path in tools/yolov5_to_yolort.py

* Rename to convert_yolov5_to_yolort.py

* Fix lint

* Minor updates

* Workaround for the moment by adding explicit dependecy of jinja2
  • Loading branch information
zhiqwang authored Mar 25, 2022
1 parent 5942ce8 commit a6a08dd
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 55 deletions.
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ jupyterlab
ipython
sphinx
sphinx-material
jinja2==3.0.3
nbsphinx
PyYAML>=5.3.1
26 changes: 8 additions & 18 deletions tools/yolov5_to_yolort.py → tools/convert_yolov5_to_yolort.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,24 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
# Copyright (c) 2021, yolort team. All rights reserved.

import argparse
from pathlib import Path

from yolort.utils import convert_yolov5_to_yolort
from yolort.models._checkpoint import convert_yolov5_checkpoint


def get_parser():
parser = argparse.ArgumentParser("Convert checkpoints from yolov5 to yolort", add_help=True)

parser.add_argument(
"--checkpoint_path",
type=str,
required=True,
help="Path of the checkpoint weights",
)
parser.add_argument("--checkpoint_path", type=str, required=True, help="Path of the checkpoints")
parser.add_argument(
"--version",
type=str,
default="r6.0",
help="Upstream version released by the ultralytics/yolov5, Possible "
"values are ['r3.1', 'r4.0', 'r6.0']. Default: 'r6.0'.",
choices=["r3.1", "r4.0", "r6.0"],
help="Upstream version released by the ultralytics/yolov5 (default: 'r6.0').",
)
# Dataset Configuration
parser.add_argument(
"--image_path",
type=str,
default="./test/assets/zidane.jpg",
help="Path of the test image",
)

parser.add_argument("--image_path", type=str, default="zidane.jpg", help="Path of the test image")
parser.add_argument("--output_path", type=str, default=None, help="Path where to save")
return parser

Expand All @@ -45,7 +35,7 @@ def cli_main():
output_path = Path(args.output_path)
output_path.mkdir(parents=True, exist_ok=True)

convert_yolov5_to_yolort(checkpoint_path, output_path, version=args.version)
convert_yolov5_checkpoint(checkpoint_path, output_path, version=args.version)


if __name__ == "__main__":
Expand Down
31 changes: 30 additions & 1 deletion yolort/models/_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from functools import reduce
from typing import Dict, List, Optional

import torch
from torch import nn
from yolort.v5 import get_yolov5_size, load_yolov5_model

from .backbone_utils import darknet_pan_backbone
from .box_head import YOLOHead

__all__ = ["load_from_ultralytics"]
__all__ = ["convert_yolov5_checkpoint", "load_from_ultralytics"]


def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"):
Expand Down Expand Up @@ -93,6 +94,34 @@ def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"):
}


def convert_yolov5_checkpoint(
checkpoint_path: str,
output_path: str,
version: str = "r6.0",
prefix: str = "yolov5_darknet_pan",
postfix: str = "custom.pt",
):
"""
Convert model checkpoint trained with ultralytics/yolov5 to yolort.
Args:
checkpoint_path (str): Path of the YOLOv5 checkpoint model.
output_path (str): Path of the converted yolort checkpoint model.
version (str): upstream version released by the ultralytics/yolov5, Possible
values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0".
prefix (str): The prefix string of the saved model. Default: "yolov5_darknet_pan".
postfix (str): The postfix string of the saved model. Default: "custom.pt".
"""

model_info = load_from_ultralytics(checkpoint_path, version=version)
model_state_dict = model_info["state_dict"]

size = model_info["size"]
use_p6 = "6" if model_info["use_p6"] else ""
output_postfix = f"{prefix}_{size}{use_p6}_{version.replace('.', '')}_{postfix}"
torch.save(model_state_dict, output_path / output_postfix)


class ModelWrapper(nn.Module):
def __init__(self, backbone, head):
super().__init__()
Expand Down
4 changes: 1 addition & 3 deletions yolort/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2020, yolort team. All rights reserved.

from typing import Any, Callable, Dict, Mapping, Sequence, Type, Union
from typing import Any, Dict, Callable, Mapping, Sequence, Type, Union

from torch import Tensor

Expand All @@ -12,7 +12,6 @@
from .dependency import check_version
from .hooks import FeatureExtractor
from .image_utils import cv2_imshow, get_image_from_url, read_image_to_tensor
from .update_module_state import convert_yolov5_to_yolort
from .visualizer import Visualizer


Expand All @@ -22,7 +21,6 @@
"cv2_imshow",
"get_image_from_url",
"get_callable_dict",
"convert_yolov5_to_yolort",
"load_state_dict_from_url",
"read_image_to_tensor",
"FeatureExtractor",
Expand Down
33 changes: 0 additions & 33 deletions yolort/utils/update_module_state.py

This file was deleted.

0 comments on commit a6a08dd

Please sign in to comment.