diff --git a/models/common.py b/models/common.py index 7ac3a4a29672..c30c8ee94777 100644 --- a/models/common.py +++ b/models/common.py @@ -28,18 +28,20 @@ from utils.torch_utils import copy_attr, smart_inference_mode -def autopad(k, p=None): # kernel, padding - # Pad to 'same' +def autopad(k, p=None, d=1): # kernel, padding, dilation + # Pad to 'same' shape outputs + if d > 1: + k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size if p is None: p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad return p class Conv(nn.Module): - # Standard convolution - def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + # Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation) + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True): super().__init__() - self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False) self.bn = nn.BatchNorm2d(c2) self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) @@ -51,13 +53,13 @@ def forward_fuse(self, x): class DWConv(Conv): - # Depth-wise convolution class + # Depth-wise convolution def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act) class DWConvTranspose2d(nn.ConvTranspose2d): - # Depth-wise transpose convolution class + # Depth-wise transpose convolution def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2)) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index abf0bbc19a98..8a3366ca3e27 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -251,6 +251,7 @@ def fuse_conv_and_bn(conv, bn): kernel_size=conv.kernel_size, stride=conv.stride, padding=conv.padding, + dilation=conv.dilation, groups=conv.groups, bias=True).requires_grad_(False).to(conv.weight.device)