Skip to content

Commit

Permalink
Fixing module loading
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Mar 18, 2021
1 parent 023dfbe commit 12065f3
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch

from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.yolotr import darknet_pan_tr_backbone
from yolort.models.transformer import darknet_tan_backbone
from yolort.models.anchor_utils import AnchorGenerator
from yolort.models.box_head import YoloHead, PostProcess, SetCriterion

Expand Down Expand Up @@ -112,7 +112,7 @@ def _init_test_backbone_with_pan_tr(self):
backbone_name = 'darknet_s_r4_0'
depth_multiple = 0.33
width_multiple = 0.5
backbone_with_fpn_tr = darknet_pan_tr_backbone(backbone_name, depth_multiple, width_multiple)
backbone_with_fpn_tr = darknet_tan_backbone(backbone_name, depth_multiple, width_multiple)
return backbone_with_fpn_tr

def test_backbone_with_pan_tr(self):
Expand Down
2 changes: 1 addition & 1 deletion yolort/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def darknet_tan_backbone(
version: str = 'v4.0',
):
"""
Constructs a specified DarkNet backbone with PAN on top. Freezes the specified number of
Constructs a specified DarkNet backbone with TAN on top. Freezes the specified number of
layers in the backbone.
Examples::
Expand Down
10 changes: 5 additions & 5 deletions yolort/models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from torchvision.models.utils import load_state_dict_from_url

from .backbone_utils import darknet_pan_backbone
from .transformer import darknet_pan_tr_backbone
from .transformer import darknet_tan_backbone
from .anchor_utils import AnchorGenerator
from .box_head import YoloHead, SetCriterion, PostProcess

from typing import Tuple, Any, List, Dict, Optional

__all__ = ['YOLO', 'yolov5_darknet_pan_s_r31', 'yolov5_darknet_pan_m_r31', 'yolov5_darknet_pan_l_r31',
'yolov5_darknet_pan_s_r40', 'yolov5_darknet_pan_m_r40', 'yolov5_darknet_pan_l_r40',
'yolov5_darknet_pan_s_tr']
'yolov5_darknet_tan_s_r40']


class YOLO(nn.Module):
Expand Down Expand Up @@ -303,8 +303,8 @@ def yolov5_darknet_pan_l_r40(pretrained: bool = False, progress: bool = True, nu
pretrained=pretrained, progress=progress, num_classes=num_classes, **kwargs)


def yolov5_darknet_pan_s_tr(pretrained: bool = False, progress: bool = True, num_classes: int = 80,
**kwargs: Any) -> YOLO:
def yolov5_darknet_tan_s_r40(pretrained: bool = False, progress: bool = True, num_classes: int = 80,
**kwargs: Any) -> YOLO:
r"""yolov5 small with a transformer block model from
`"dingyiwei/yolov5" <https://github.com/ultralytics/yolov5/pull/2333>`_.
Args:
Expand All @@ -317,7 +317,7 @@ def yolov5_darknet_pan_s_tr(pretrained: bool = False, progress: bool = True, num
width_multiple = 0.5
version = 'v4.0'

backbone = darknet_pan_tr_backbone(backbone_name, depth_multiple, width_multiple, version=version)
backbone = darknet_tan_backbone(backbone_name, depth_multiple, width_multiple, version=version)

anchor_grids = [[10, 13, 16, 30, 33, 23],
[30, 61, 62, 45, 59, 119],
Expand Down

0 comments on commit 12065f3

Please sign in to comment.