Skip to content

Commit

Permalink
Add DWConvTranspose2d() module (#7881)
Browse files Browse the repository at this point in the history
* Add DWConvTranspose2d() module

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add DWConvTranspose2d() module

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

* Fix

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
glenn-jocher and pre-commit-ci[bot] committed May 20, 2022
1 parent a9a92ae commit 5774a15
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 12 deletions.
6 changes: 6 additions & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride
super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)


class DWConvTranspose2d(nn.ConvTranspose2d):
# Depth-wise transpose convolution class
def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))


class TransformerLayer(nn.Module):
# Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
def __init__(self, c, num_heads):
Expand Down
47 changes: 36 additions & 11 deletions models/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
import torch.nn as nn
from tensorflow import keras

from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv, Focus, autopad
from models.common import (C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
DWConvTranspose2d, Focus, autopad)
from models.experimental import MixConv2d, attempt_load
from models.yolo import Detect
from utils.activations import SiLU
Expand Down Expand Up @@ -108,6 +109,29 @@ def call(self, inputs):
return self.act(self.bn(self.conv(inputs)))


class TFDWConvTranspose2d(keras.layers.Layer):
# Depthwise ConvTranspose2d
def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None):
# ch_in, ch_out, weights, kernel, stride, padding, groups
super().__init__()
assert c1 == c2, f'TFDWConv() output={c2} must be equal to input={c1} channels'
assert k == 4 and p1 == 1, 'TFDWConv() only valid for k=4 and p1=1'
weight, bias = w.weight.permute(2, 3, 1, 0).numpy(), w.bias.numpy()
self.c1 = c1
self.conv = [
keras.layers.Conv2DTranspose(filters=1,
kernel_size=k,
strides=s,
padding='VALID',
output_padding=p2,
use_bias=True,
kernel_initializer=keras.initializers.Constant(weight[..., i:i + 1]),
bias_initializer=keras.initializers.Constant(bias[i])) for i in range(c1)]

def call(self, inputs):
return tf.concat([m(x) for m, x in zip(self.conv, tf.split(inputs, self.c1, 3))], 3)[:, 1:-1, 1:-1]


class TFFocus(keras.layers.Layer):
# Focus wh information into c-space
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
Expand Down Expand Up @@ -152,15 +176,14 @@ class TFConv2d(keras.layers.Layer):
def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
super().__init__()
assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
self.conv = keras.layers.Conv2D(
c2,
k,
s,
'VALID',
use_bias=bias,
kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()),
bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None,
)
self.conv = keras.layers.Conv2D(filters=c2,
kernel_size=k,
strides=s,
padding='VALID',
use_bias=bias,
kernel_initializer=keras.initializers.Constant(
w.weight.permute(2, 3, 1, 0).numpy()),
bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None)

def call(self, inputs):
return self.conv(inputs)
Expand Down Expand Up @@ -340,7 +363,9 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
pass

n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3x]:
if m in [
nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv,
BottleneckCSP, C3, C3x]:
c1, c2 = ch[f], args[0]
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2

Expand Down
2 changes: 1 addition & 1 deletion models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)

n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, C3x):
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x):
c1, c2 = ch[f], args[0]
if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8)
Expand Down

0 comments on commit 5774a15

Please sign in to comment.