diff --git a/models/export.py b/models/export.py index 6f8799e55593..3c04b07fdc95 100644 --- a/models/export.py +++ b/models/export.py @@ -9,13 +9,15 @@ import time from pathlib import Path -sys.path.append(Path(__file__).parent.parent.absolute().__str__()) # to run '$ python *.py' files in subdirectories - import torch import torch.nn as nn from torch.utils.mobile_optimizer import optimize_for_mobile -import models +FILE = Path(__file__).absolute() +sys.path.append(FILE.parents[1].as_posix()) # add yolov5/ to path + +from models.common import Conv +from models.yolo import Detect from models.experimental import attempt_load from utils.activations import Hardswish, SiLU from utils.general import colorstr, check_img_size, check_requirements, file_size, set_logging @@ -56,12 +58,12 @@ def export(weights='./yolov5s.pt', # weights path model.train() if train else model.eval() # training mode = no Detect() layer grid construction for k, m in model.named_modules(): m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility - if isinstance(m, models.common.Conv): # assign export-friendly activations + if isinstance(m, Conv): # assign export-friendly activations if isinstance(m.act, nn.Hardswish): m.act = Hardswish() elif isinstance(m.act, nn.SiLU): m.act = SiLU() - elif isinstance(m, models.yolo.Detect): + elif isinstance(m, Detect): m.inplace = inplace m.onnx_dynamic = dynamic # m.forward = m.forward_export # assign forward (optional) diff --git a/models/yolo.py b/models/yolo.py index 1a7be913023c..4a2514edd295 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -10,8 +10,8 @@ from copy import deepcopy from pathlib import Path -sys.path.append(Path(__file__).parent.parent.absolute().__str__()) # to run '$ python *.py' files in subdirectories -logger = logging.getLogger(__name__) +FILE = Path(__file__).absolute() +sys.path.append(FILE.parents[1].as_posix()) # add yolov5/ to path from models.common import * from models.experimental import * @@ -25,6 +25,8 @@ except ImportError: thop = None +logger = logging.getLogger(__name__) + class Detect(nn.Module): stride = None # strides computed during build