diff --git a/models/common.py b/models/common.py index e6b7b5182283..30e7319f98a0 100644 --- a/models/common.py +++ b/models/common.py @@ -30,7 +30,7 @@ def autopad(k, p=None): # kernel, padding def DWConv(c1, c2, k=1, s=1, act=True): - # Depth-wise convolution + # Depth-wise convolution function return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act) @@ -45,10 +45,17 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, k def forward(self, x): return self.act(self.bn(self.conv(x))) - def fuseforward(self, x): + def forward_fuse(self, x): return self.act(self.conv(x)) +class DWConvClass(Conv): + # Depth-wise convolution class + def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__(c1, c2, k, s, act) + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k), groups=math.gcd(c1, c2), bias=False) + + class TransformerLayer(nn.Module): # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance) def __init__(self, c, num_heads): diff --git a/models/experimental.py b/models/experimental.py index 276ca954b173..581c7b14b61e 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -72,7 +72,7 @@ def forward(self, x): class MixConv2d(nn.Module): - # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595 + # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595 def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): super().__init__() groups = len(k) diff --git a/models/yolo.py b/models/yolo.py index 2e7a20f813e2..9f05c8329f38 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -202,10 +202,10 @@ def _print_biases(self): def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers LOGGER.info('Fusing layers... ') for m in self.model.modules(): - if type(m) is Conv and hasattr(m, 'bn'): + if isinstance(m, (Conv, DWConvClass)) and hasattr(m, 'bn'): m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv delattr(m, 'bn') # remove batchnorm - m.forward = m.fuseforward # update forward + m.forward = m.forward_fuse # update forward self.info() return self