Add DWConvClass() (ultralytics#4274)
* Add `DWConvClass()`

* Cleanup

* Cleanup2
glenn-jocher committed Aug 2, 2021
1 parent 23414d3 commit 98e61a6
Showing 3 changed files with 12 additions and 5 deletions.
11 changes: 9 additions & 2 deletions models/common.py
@@ -30,7 +30,7 @@ def autopad(k, p=None):  # kernel, padding
 
 
 def DWConv(c1, c2, k=1, s=1, act=True):
-    # Depth-wise convolution
+    # Depth-wise convolution function
     return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
 
 
@@ -45,10 +45,17 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, k
     def forward(self, x):
         return self.act(self.bn(self.conv(x)))
 
-    def fuseforward(self, x):
+    def forward_fuse(self, x):
         return self.act(self.conv(x))
 
 
+class DWConvClass(Conv):
+    # Depth-wise convolution class
+    def __init__(self, c1, c2, k=1, s=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
+        super().__init__(c1, c2, k, s, act)
+        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k), groups=math.gcd(c1, c2), bias=False)
+
+
 class TransformerLayer(nn.Module):
     # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
     def __init__(self, c, num_heads):
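
DWConvClass builds on Conv: it calls Conv's constructor to get the batchnorm and activation, then swaps in a grouped nn.Conv2d with groups=math.gcd(c1, c2), so each channel group is convolved separately and the layer becomes fully depth-wise when c1 == c2. Below is a minimal self-contained sketch of that idea, not the repository's class: the name DepthwiseConvSketch is made up, and nn.SiLU() stands in for whatever activation Conv uses by default.

import math
import torch
import torch.nn as nn

# Minimal sketch of a depth-wise convolution block in the spirit of DWConvClass.
# Assumes k is an odd int so padding=k // 2 preserves spatial size.
class DepthwiseConvSketch(nn.Module):
    def __init__(self, c1, c2, k=1, s=1):
        super().__init__()
        # groups = gcd(c1, c2): each group sees c1 / groups input channels,
        # i.e. one channel per group (true depth-wise) when c1 == c2
        self.conv = nn.Conv2d(c1, c2, k, s, padding=k // 2,
                              groups=math.gcd(c1, c2), bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

x = torch.randn(1, 64, 32, 32)
print(DepthwiseConvSketch(64, 64, k=3)(x).shape)  # torch.Size([1, 64, 32, 32])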
2 changes: 1 addition & 1 deletion models/experimental.py
@@ -72,7 +72,7 @@ def forward(self, x):
 
 
 class MixConv2d(nn.Module):
-    # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
+    # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
     def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
         super().__init__()
         groups = len(k)
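
MixConv2d follows the mixed depth-wise convolution of the linked paper: the input channels are split into groups and each group is processed with a different kernel size before the results are concatenated. The sketch below illustrates only that splitting idea under a made-up name (MixedDWConvSketch) and a simple even channel split; the repository's MixConv2d handles channel assignment via its equal_ch option and is not reproduced here.

import torch
import torch.nn as nn

# Illustration of the mixed depth-wise idea: split channels into groups,
# run each group through a depth-wise conv with its own kernel size, concatenate.
# Hypothetical sketch, not the MixConv2d from models/experimental.py.
class MixedDWConvSketch(nn.Module):
    def __init__(self, c, k=(1, 3, 5)):
        super().__init__()
        splits = [c // len(k)] * len(k)
        splits[0] += c - sum(splits)  # give any remainder to the first group
        self.splits = splits
        self.convs = nn.ModuleList(
            nn.Conv2d(ci, ci, ki, padding=ki // 2, groups=ci, bias=False)
            for ci, ki in zip(splits, k)
        )

    def forward(self, x):
        xs = torch.split(x, self.splits, dim=1)
        return torch.cat([conv(xi) for conv, xi in zip(self.convs, xs)], dim=1)

x = torch.randn(1, 48, 16, 16)
print(MixedDWConvSketch(48)(x).shape)  # torch.Size([1, 48, 16, 16])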
4 changes: 2 additions & 2 deletions models/yolo.py
@@ -202,10 +202,10 @@ def _print_biases(self):
     def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
         LOGGER.info('Fusing layers... ')
         for m in self.model.modules():
-            if type(m) is Conv and hasattr(m, 'bn'):
+            if isinstance(m, (Conv, DWConvClass)) and hasattr(m, 'bn'):
                 m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                 delattr(m, 'bn')  # remove batchnorm
-                m.forward = m.fuseforward  # update forward
+                m.forward = m.forward_fuse  # update forward
         self.info()
         return self
 
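
The fuse() change widens the layer-fusion pass to cover the new class: because DWConvClass inherits from Conv, the isinstance check lets its batchnorm be folded into its grouped convolution just like a plain Conv, after which the patched forward_fuse skips the removed bn. The repository does the folding in its fuse_conv_and_bn helper; the function below is a standard-formula sketch of that step, not a copy of it.

import torch
import torch.nn as nn

# Standard Conv2d + BatchNorm2d folding: w' = w * gamma / sqrt(var + eps),
# b' = (b - mean) * gamma / sqrt(var + eps) + beta.
# Hedged sketch only; not the repository's fuse_conv_and_bn implementation.
@torch.no_grad()
def fuse_conv_bn_sketch(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # per-channel gamma / sigma
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))  # scale each output channel
    b = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.copy_((b - bn.running_mean) * scale + bn.bias)  # fold mean and beta into the bias
    return fused

# quick check: the fused layer matches conv -> bn in eval mode
conv, bn = nn.Conv2d(8, 8, 3, padding=1, groups=8, bias=False), nn.BatchNorm2d(8)
bn.eval()
x = torch.randn(1, 8, 16, 16)
print(torch.allclose(bn(conv(x)), fuse_conv_bn_sketch(conv, bn)(x), atol=1e-6))  # True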
