diff --git a/models/common.py b/models/common.py index 9d181ed7a40a..00641eaa8d15 100644 --- a/models/common.py +++ b/models/common.py @@ -31,7 +31,7 @@ def autopad(k, p=None): # kernel, padding # Pad to 'same' if p is None: - p = k // 2 if isinstance(k, int) else (x // 2 for x in k) # auto-pad + p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad return p @@ -124,6 +124,20 @@ def forward(self, x): return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1)))) +class CrossConv(nn.Module): + # Cross Convolution Downsample + def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): + # ch_in, ch_out, kernel, stride, groups, expansion, shortcut + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, (1, k), (1, s)) + self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + class C3(nn.Module): # CSP Bottleneck with 3 convolutions def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion @@ -133,12 +147,19 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu self.cv2 = Conv(c1, c_, 1, 1) self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2) self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) - # self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n))) def forward(self, x): return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)) +class C3x(C3): + # C3 module with cross-convolutions + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) + self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n))) + + class C3TR(C3): # C3 module with TransformerBlock() def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): diff --git a/models/experimental.py b/models/experimental.py index 1ffe0fcb5971..7bf249e80984 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -12,20 +12,6 @@ from utils.downloads import attempt_download -class CrossConv(nn.Module): - # Cross Convolution Downsample - def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): - # ch_in, ch_out, kernel, stride, groups, expansion, shortcut - super().__init__() - c_ = int(c2 * e) # hidden channels - self.cv1 = Conv(c1, c_, (1, k), (1, s)) - self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) - self.add = shortcut and c1 == c2 - - def forward(self, x): - return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) - - class Sum(nn.Module): # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070 def __init__(self, n, weight=False): # n: number of inputs diff --git a/models/tf.py b/models/tf.py index 3a26256ea2b1..b70e37488009 100644 --- a/models/tf.py +++ b/models/tf.py @@ -27,8 +27,8 @@ import torch.nn as nn from tensorflow import keras -from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, Concat, Conv, DWConv, Focus, autopad -from models.experimental import CrossConv, MixConv2d, attempt_load +from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv, Focus, autopad +from models.experimental import MixConv2d, attempt_load from models.yolo import Detect from utils.activations import SiLU from utils.general import LOGGER, make_divisible, print_args @@ -50,10 +50,13 @@ def call(self, inputs): class TFPad(keras.layers.Layer): - + # Pad inputs in spatial dimensions 1 and 2 def __init__(self, pad): super().__init__() - self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]]) + if isinstance(pad, int): + self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]]) + else: # tuple/list + self.pad = tf.constant([[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]]) def call(self, inputs): return tf.pad(inputs, self.pad, mode='constant', constant_values=0) @@ -65,10 +68,8 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None): # ch_in, ch_out, weights, kernel, stride, padding, groups super().__init__() assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument" - assert isinstance(k, int), "Convolution with multiple kernels are not allowed." # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding) # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch - conv = keras.layers.Conv2D( filters=c2, kernel_size=k, @@ -90,8 +91,7 @@ class TFDWConv(keras.layers.Layer): def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None): # ch_in, ch_out, weights, kernel, stride, padding, groups super().__init__() - assert isinstance(k, int), "Convolution with multiple kernels are not allowed." - + assert g == c1 == c2, f'TFDWConv() groups={g} must equal input={c1} and output={c2} channels' conv = keras.layers.DepthwiseConv2D( kernel_size=k, strides=s, @@ -133,6 +133,19 @@ def call(self, inputs): return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs)) +class TFCrossConv(keras.layers.Layer): + # Cross Convolution + def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False, w=None): + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = TFConv(c1, c_, (1, k), (1, s), w=w.cv1) + self.cv2 = TFConv(c_, c2, (k, 1), (s, 1), g=g, w=w.cv2) + self.add = shortcut and c1 == c2 + + def call(self, inputs): + return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs)) + + class TFConv2d(keras.layers.Layer): # Substitution for PyTorch nn.Conv2D def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None): @@ -187,6 +200,22 @@ def call(self, inputs): return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3)) +class TFC3x(keras.layers.Layer): + # 3 module with cross-convolutions + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None): + # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1) + self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2) + self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3) + self.m = keras.Sequential([ + TFCrossConv(c_, c_, k=3, s=1, g=g, e=1.0, shortcut=shortcut, w=w.m[j]) for j in range(n)]) + + def call(self, inputs): + return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3)) + + class TFSPP(keras.layers.Layer): # Spatial pyramid pooling layer used in YOLOv3-SPP def __init__(self, c1, c2, k=(5, 9, 13), w=None): @@ -310,12 +339,12 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3) pass n = max(round(n * gd), 1) if n > 1 else n # depth gain - if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]: + if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3x]: c1, c2 = ch[f], args[0] c2 = make_divisible(c2 * gw, 8) if c2 != no else c2 args = [c1, c2, *args[1:]] - if m in [BottleneckCSP, C3]: + if m in [BottleneckCSP, C3, C3x]: args.insert(2, n) n = 1 elif m is nn.BatchNorm2d: diff --git a/models/yolo.py b/models/yolo.py index 55356a6a9b44..9695ed7ff186 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -266,13 +266,13 @@ def parse_model(d, ch): # model_dict, input_channels(3) n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, - BottleneckCSP, C3, C3TR, C3SPP, C3Ghost): + BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, C3x): c1, c2 = ch[f], args[0] if c2 != no: # if not output c2 = make_divisible(c2 * gw, 8) args = [c1, c2, *args[1:]] - if m in [BottleneckCSP, C3, C3TR, C3Ghost]: + if m in [BottleneckCSP, C3, C3TR, C3Ghost, C3x]: args.insert(2, n) # number of repeats n = 1 elif m is nn.BatchNorm2d: