diff --git a/models/tf.py b/models/tf.py
index 7e0d61729e36..3a26256ea2b1 100644
--- a/models/tf.py
+++ b/models/tf.py
@@ -70,25 +70,38 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
         # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
         conv = keras.layers.Conv2D(
-            c2,
-            k,
-            s,
-            'SAME' if s == 1 else 'VALID',
-            use_bias=False if hasattr(w, 'bn') else True,
+            filters=c2,
+            kernel_size=k,
+            strides=s,
+            padding='SAME' if s == 1 else 'VALID',
+            use_bias=not hasattr(w, 'bn'),
             kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
             bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
         self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
         self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
+        self.act = activations(w.act) if act else tf.identity
+
+    def call(self, inputs):
+        return self.act(self.bn(self.conv(inputs)))

-        # YOLOv5 activations
-        if isinstance(w.act, nn.LeakyReLU):
-            self.act = (lambda x: keras.activations.relu(x, alpha=0.1)) if act else tf.identity
-        elif isinstance(w.act, nn.Hardswish):
-            self.act = (lambda x: x * tf.nn.relu6(x + 3) * 0.166666667) if act else tf.identity
-        elif isinstance(w.act, (nn.SiLU, SiLU)):
-            self.act = (lambda x: keras.activations.swish(x)) if act else tf.identity
-        else:
-            raise Exception(f'no matching TensorFlow activation found for {w.act}')
+
+class TFDWConv(keras.layers.Layer):
+    # Depthwise convolution
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
+        # ch_in, ch_out, weights, kernel, stride, padding, groups
+        super().__init__()
+        assert isinstance(k, int), "Convolution with multiple kernels is not allowed."
+
+        conv = keras.layers.DepthwiseConv2D(
+            kernel_size=k,
+            strides=s,
+            padding='SAME' if s == 1 else 'VALID',
+            use_bias=not hasattr(w, 'bn'),
+            depthwise_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
+            bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
+        self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
+        self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
+        self.act = activations(w.act) if act else tf.identity

     def call(self, inputs):
         return self.act(self.bn(self.conv(inputs)))
@@ -103,10 +116,8 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):

     def call(self, inputs):  # x(b,w,h,c) -> y(b,w/2,h/2,4c)
         # inputs = inputs / 255  # normalize 0-255 to 0-1
-        return self.conv(
-            tf.concat(
-                [inputs[:, ::2, ::2, :], inputs[:, 1::2, ::2, :], inputs[:, ::2, 1::2, :], inputs[:, 1::2, 1::2, :]],
-                3))
+        inputs = [inputs[:, ::2, ::2, :], inputs[:, 1::2, ::2, :], inputs[:, ::2, 1::2, :], inputs[:, 1::2, 1::2, :]]
+        return self.conv(tf.concat(inputs, 3))


 class TFBottleneck(keras.layers.Layer):
@@ -439,6 +450,18 @@ def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25):  # agnostic NMS
     return padded_boxes, padded_scores, padded_classes, valid_detections


+def activations(act=nn.SiLU):
+    # Returns TF activation from input PyTorch activation
+    if isinstance(act, nn.LeakyReLU):
+        return lambda x: keras.activations.relu(x, alpha=0.1)
+    elif isinstance(act, nn.Hardswish):
+        return lambda x: x * tf.nn.relu6(x + 3) * 0.166666667
+    elif isinstance(act, (nn.SiLU, SiLU)):
+        return lambda x: keras.activations.swish(x)
+    else:
+        raise Exception(f'no matching TensorFlow activation found for PyTorch activation {act}')
+
+
 def representative_dataset_gen(dataset, ncalib=100):
     # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
     for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
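
Note: the activations() helper added above dispatches on the PyTorch module type so that the Keras export reproduces the original activation numerically; its Hardswish branch uses 0.166666667 as an approximation of 1/6, since hard-swish(x) = x * relu6(x + 3) / 6. The following standalone sketch (not part of the patch; the helper name check_hardswish is ours, and it assumes numpy, tensorflow and torch are installed) illustrates that this approximation matches PyTorch's nn.Hardswish:

# Standalone sanity check (illustrative sketch, not part of models/tf.py):
# compares the TF hard-swish expression used above against PyTorch's nn.Hardswish.
import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn

def check_hardswish():
    x = np.linspace(-6.0, 6.0, 25, dtype=np.float32)
    xt = tf.constant(x)
    tf_y = (xt * tf.nn.relu6(xt + 3) * 0.166666667).numpy()  # 0.166666667 ~= 1/6
    pt_y = nn.Hardswish()(torch.from_numpy(x)).numpy()
    assert np.allclose(tf_y, pt_y, atol=1e-6), 'TF/PyTorch Hardswish mismatch'

check_hardswish()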