Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(model): support hub load #1189

Merged
merged 2 commits into from
Mar 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

"""
Usage example:
import torch
model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_s")
"""
dependencies = ["torch"]

from yolox.models import ( # isort:skip # noqa: F401, E402
yolox_tiny,
yolox_nano,
yolox_s,
yolox_m,
yolox_l,
yolox_x,
yolov3,
)
1 change: 1 addition & 0 deletions yolox/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

from .build import *
from .darknet import CSPDarknet, Darknet
from .losses import IOUloss
from .yolo_fpn import YOLOFPN
Expand Down
91 changes: 91 additions & 0 deletions yolox/models/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import torch
from torch import nn
from torch.hub import load_state_dict_from_url

__all__ = [
"create_yolox_model",
"yolox_nano",
"yolox_tiny",
"yolox_s",
"yolox_m",
"yolox_l",
"yolox_x",
"yolov3",
]

_CKPT_ROOT_URL = "https://github.com/Megvii-BaseDetection/YOLOX/releases/download"
_CKPT_FULL_PATH = {
"yolox-nano": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_nano.pth",
"yolox-tiny": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_tiny.pth",
"yolox-s": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_s.pth",
"yolox-m": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_m.pth",
"yolox-l": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_l.pth",
"yolox-x": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_x.pth",
"yolov3": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_darknet.pth",
}


def create_yolox_model(
name: str, pretrained: bool = True, num_classes: int = 80, device=None
) -> nn.Module:
"""creates and loads a YOLOX model

Args:
name (str): name of model. for example, "yolox-s", "yolox-tiny".
pretrained (bool): load pretrained weights into the model. Default to True.
num_classes (int): number of model classes. Defalut to 80.
device (str): default device to for model. Defalut to None.

Returns:
YOLOX model (nn.Module)
"""
from yolox.exp import get_exp, Exp

if device is None:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device)

assert name in _CKPT_FULL_PATH, f"user should use one of value in {_CKPT_FULL_PATH.keys()}"
exp: Exp = get_exp(exp_name=name)
exp.num_classes = num_classes
yolox_model = exp.get_model()
if pretrained and num_classes == 80:
weights_url = _CKPT_FULL_PATH[name]
ckpt = load_state_dict_from_url(weights_url, map_location="cpu")
if "model" in ckpt:
ckpt = ckpt["model"]
yolox_model.load_state_dict(ckpt)

yolox_model.to(device)
return yolox_model


def yolox_nano(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-nano", pretrained, num_classes, device)


def yolox_tiny(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-tiny", pretrained, num_classes, device)


def yolox_s(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-s", pretrained, num_classes, device)


def yolox_m(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-m", pretrained, num_classes, device)


def yolox_l(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-l", pretrained, num_classes, device)


def yolox_x(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-x", pretrained, num_classes, device)


def yolov3(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
12 changes: 8 additions & 4 deletions yolox/models/yolo_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

from yolox.utils import bboxes_iou
from yolox.utils import bboxes_iou, meshgrid

from .losses import IOUloss
from .network_blocks import BaseConv, DWConv
Expand Down Expand Up @@ -220,7 +220,7 @@ def get_output_and_grid(self, output, k, stride, dtype):
n_ch = 5 + self.num_classes
hsize, wsize = output.shape[-2:]
if grid.shape[2:4] != output.shape[2:4]:
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij")
yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
self.grids[k] = grid

Expand All @@ -237,7 +237,7 @@ def decode_outputs(self, outputs, dtype):
grids = []
strides = []
for (hsize, wsize), stride in zip(self.hw, self.strides):
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij")
yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
grid = torch.stack((xv, yv), 2).view(1, -1, 2)
grids.append(grid)
shape = grid.shape[:2]
Expand Down Expand Up @@ -321,7 +321,11 @@ def get_losses(
labels,
imgs,
)
except RuntimeError:
except RuntimeError as e:
# TODO: the string might change, consider a better way
if "CUDA out of memory. " not in str(e):
raise # RuntimeError might not caused by CUDA OOM

logger.error(
"OOM RuntimeError is raised due to the huge memory cost during label assignment. \
CPU mode is applied in this batch. If you want to avoid this issue, \
Expand Down
1 change: 1 addition & 0 deletions yolox/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .allreduce_norm import *
from .boxes import *
from .checkpoint import load_ckpt, save_checkpoint
from .compat import meshgrid
from .demo_utils import *
from .dist import *
from .ema import *
Expand Down
15 changes: 15 additions & 0 deletions yolox/utils/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import torch

_TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]]

__all__ = ["meshgrid"]


def meshgrid(*tensors):
if _TORCH_VER >= [1, 10]:
return torch.meshgrid(*tensors, indexing="ij")
else:
return torch.meshgrid(*tensors)