Skip to content

Commit

Permalink
fix tf conversion in new v6 models (ultralytics#5153)
Browse files Browse the repository at this point in the history
* fix `tf` conversion in new v6 (ultralytics#5147)

* sort imports

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
YoniChechik and glenn-jocher committed Oct 12, 2021
1 parent 4ca216e commit f63dd29
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions models/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import torch.nn as nn
from tensorflow import keras

from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3
from models.common import Bottleneck, BottleneckCSP, Concat, Conv, C3, DWConv, Focus, SPP, SPPF, autopad
from models.experimental import CrossConv, MixConv2d, attempt_load
from models.yolo import Detect
from utils.general import make_divisible, print_args, set_logging
Expand Down Expand Up @@ -183,6 +183,22 @@ def call(self, inputs):
return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))


class TFSPPF(keras.layers.Layer):
# Spatial pyramid pooling-Fast layer
def __init__(self, c1, c2, k=5, w=None):
super(TFSPPF, self).__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')

def call(self, inputs):
x = self.cv1(inputs)
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))


class TFDetect(keras.layers.Layer):
def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer
super(TFDetect, self).__init__()
Expand Down Expand Up @@ -272,7 +288,7 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
pass

n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
c1, c2 = ch[f], args[0]
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2

Expand Down

0 comments on commit f63dd29

Please sign in to comment.