New TensorFlow TFCrossConv() module (#7827)
* New TensorFlow `TFCrossConv()` module

* Move from experimental to common

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add C3x

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add C3x to yolo.py

* Add C3x to tf.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* TFC3x bug fix

* TFC3x bug fix

* TFC3x bug fix

* Add TFDWConv g==c1==c2 check

* Add comment

* Update tf.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
glenn-jocher and pre-commit-ci[bot] committed May 16, 2022
1 parent d29df68 commit fb7fa5b
Showing 4 changed files with 64 additions and 28 deletions.
25 changes: 23 additions & 2 deletions models/common.py
@@ -31,7 +31,7 @@
def autopad(k, p=None): # kernel, padding
# Pad to 'same'
if p is None:
p = k // 2 if isinstance(k, int) else (x // 2 for x in k) # auto-pad
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
return p
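
The autopad change above swaps a generator expression for a list comprehension, presumably so that non-square kernels such as the (1, k) and (k, 1) pairs used by CrossConv yield a concrete padding sequence. A minimal sketch of what the new form allows (assumed values, not part of the commit):

import torch.nn as nn

k = (1, 3)               # non-square kernel, e.g. the 1xk half of a CrossConv
p = [x // 2 for x in k]  # concrete padding list: [0, 1]
layer = nn.Conv2d(64, 64, kernel_size=k, padding=p)  # accepts the list padding
print(layer)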


@@ -124,6 +124,20 @@ def forward(self, x):
return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))


class CrossConv(nn.Module):
# Cross Convolution Downsample
def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
# ch_in, ch_out, kernel, stride, groups, expansion, shortcut
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, (1, k), (1, s))
self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
self.add = shortcut and c1 == c2

def forward(self, x):
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class C3(nn.Module):
# CSP Bottleneck with 3 convolutions
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
@@ -133,12 +147,19 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
# self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)))

def forward(self, x):
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))


class C3x(C3):
# C3 module with cross-convolutions
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
super().__init__(c1, c2, n, shortcut, g, e)
c_ = int(c2 * e)
self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)))


class C3TR(C3):
# C3 module with TransformerBlock()
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
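For reference, the new PyTorch modules added in models/common.py can be exercised with a quick shape check. This is a minimal sketch, not part of the commit, and it assumes the repository root is on PYTHONPATH with its dependencies installed so that models.common imports resolve:

import torch
from models.common import C3x, CrossConv

x = torch.zeros(1, 64, 32, 32)  # NCHW dummy input

# CrossConv factors a 3x3 conv into a 1x3 conv followed by a 3x1 conv
print(CrossConv(64, 64, k=3, s=1, shortcut=True)(x).shape)  # torch.Size([1, 64, 32, 32])

# C3x is a C3 block whose inner bottlenecks are CrossConv layers
print(C3x(64, 128, n=2)(x).shape)  # torch.Size([1, 128, 32, 32])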
14 changes: 0 additions & 14 deletions models/experimental.py
@@ -12,20 +12,6 @@
from utils.downloads import attempt_download


class CrossConv(nn.Module):
# Cross Convolution Downsample
def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
# ch_in, ch_out, kernel, stride, groups, expansion, shortcut
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, (1, k), (1, s))
self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
self.add = shortcut and c1 == c2

def forward(self, x):
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class Sum(nn.Module):
# Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
def __init__(self, n, weight=False): # n: number of inputs
49 changes: 39 additions & 10 deletions models/tf.py
@@ -27,8 +27,8 @@
import torch.nn as nn
from tensorflow import keras

from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, Concat, Conv, DWConv, Focus, autopad
from models.experimental import CrossConv, MixConv2d, attempt_load
from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv, Focus, autopad
from models.experimental import MixConv2d, attempt_load
from models.yolo import Detect
from utils.activations import SiLU
from utils.general import LOGGER, make_divisible, print_args
@@ -50,10 +50,13 @@ def call(self, inputs):


class TFPad(keras.layers.Layer):

# Pad inputs in spatial dimensions 1 and 2
def __init__(self, pad):
super().__init__()
self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
if isinstance(pad, int):
self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
else: # tuple/list
self.pad = tf.constant([[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]])

def call(self, inputs):
return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
@@ -65,10 +68,8 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
# ch_in, ch_out, weights, kernel, stride, padding, groups
super().__init__()
assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
assert isinstance(k, int), "Convolution with multiple kernels are not allowed."
# TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
# see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch

conv = keras.layers.Conv2D(
filters=c2,
kernel_size=k,
@@ -90,8 +91,7 @@ class TFDWConv(keras.layers.Layer):
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
# ch_in, ch_out, weights, kernel, stride, padding, groups
super().__init__()
assert isinstance(k, int), "Convolution with multiple kernels are not allowed."

assert g == c1 == c2, f'TFDWConv() groups={g} must equal input={c1} and output={c2} channels'
conv = keras.layers.DepthwiseConv2D(
kernel_size=k,
strides=s,
@@ -133,6 +133,19 @@ def call(self, inputs):
return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))


class TFCrossConv(keras.layers.Layer):
# Cross Convolution
def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False, w=None):
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = TFConv(c1, c_, (1, k), (1, s), w=w.cv1)
self.cv2 = TFConv(c_, c2, (k, 1), (s, 1), g=g, w=w.cv2)
self.add = shortcut and c1 == c2

def call(self, inputs):
return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))


class TFConv2d(keras.layers.Layer):
# Substitution for PyTorch nn.Conv2D
def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
@@ -187,6 +200,22 @@ def call(self, inputs):
return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))


class TFC3x(keras.layers.Layer):
# C3 module with cross-convolutions
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
# ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
self.m = keras.Sequential([
TFCrossConv(c_, c_, k=3, s=1, g=g, e=1.0, shortcut=shortcut, w=w.m[j]) for j in range(n)])

def call(self, inputs):
return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))


class TFSPP(keras.layers.Layer):
# Spatial pyramid pooling layer used in YOLOv3-SPP
def __init__(self, c1, c2, k=(5, 9, 13), w=None):
@@ -310,12 +339,12 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
pass

n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3x]:
c1, c2 = ch[f], args[0]
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2

args = [c1, c2, *args[1:]]
if m in [BottleneckCSP, C3]:
if m in [BottleneckCSP, C3, C3x]:
args.insert(2, n)
n = 1
elif m is nn.BatchNorm2d:
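For orientation, the cross-convolution that TFCrossConv implements factorizes a k x k kernel into a 1 x k convolution followed by a k x 1 convolution. The sketch below illustrates that idea with plain Keras layers in NHWC layout; it is illustrative only and does not use the repo's TFConv/TFCrossConv wrappers, which additionally transfer PyTorch weights through the w argument:

import tensorflow as tf
from tensorflow import keras

k, s, c2 = 3, 1, 64
cross = keras.Sequential([
    keras.layers.Conv2D(c2, kernel_size=(1, k), strides=(1, s), padding='same'),  # 1 x k conv
    keras.layers.Conv2D(c2, kernel_size=(k, 1), strides=(s, 1), padding='same'),  # k x 1 conv
])
x = tf.zeros((1, 32, 32, 64))  # NHWC dummy input
print(cross(x).shape)  # (1, 32, 32, 64): spatial size preserved for s=1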
4 changes: 2 additions & 2 deletions models/yolo.py
@@ -266,13 +266,13 @@ def parse_model(d, ch): # model_dict, input_channels(3)

n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost):
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, C3x):
c1, c2 = ch[f], args[0]
if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8)

args = [c1, c2, *args[1:]]
if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
if m in [BottleneckCSP, C3, C3TR, C3Ghost, C3x]:
args.insert(2, n) # number of repeats
n = 1
elif m is nn.BatchNorm2d:
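The parse_model changes above simply register C3x alongside the other channel-scaled modules. As a self-contained illustration of the arg-packing logic, the sketch below uses assumed example values for gd, gw and the YAML row (they are not taken from the commit):

import math

def make_divisible(x, divisor):  # same rounding helper as utils.general.make_divisible
    return math.ceil(x / divisor) * divisor

gd, gw = 0.33, 0.50                  # depth and width multiples, e.g. a small model
f, n, m, args = -1, 3, 'C3x', [128]  # hypothetical YAML row: [from, number, module, args]
ch = [64]                            # output channels of the previous layer

n = max(round(n * gd), 1) if n > 1 else n        # depth gain: 3 repeats -> 1
c1, c2 = ch[f], make_divisible(args[0] * gw, 8)  # width gain: 128 -> 64
args = [c1, c2, n, *args[1:]]                    # C3x(c1, c2, n, ...): repeats folded into args
print(args)  # [64, 64, 1]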
