Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/docstrings' into docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Feb 25, 2024
2 parents 2fa2a5b + 3ca48a8 commit 06bca84
Showing 1 changed file with 80 additions and 19 deletions.
99 changes: 80 additions & 19 deletions models/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,12 @@ def call(self, inputs):
class TFPad(keras.layers.Layer):
# Pad inputs in spatial dimensions 1 and 2
def __init__(self, pad):
"""Initializes a padding layer for spatial dimensions 1 and 2 with specified padding, supporting both int and tuple inputs. Inputs are """
"""
Initializes a padding layer for spatial dimensions 1 and 2 with specified padding, supporting both int and tuple
inputs.
Inputs are
"""
super().__init__()
if isinstance(pad, int):
self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
Expand All @@ -84,7 +89,12 @@ def call(self, inputs):
class TFConv(keras.layers.Layer):
# Standard convolution
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
"""Initializes a standard convolution layer with optional batch normalization and activation; supports only group=1. Inputs are ch_in, ch_out, weights, kernel, stride, padding, groups."""
"""
Initializes a standard convolution layer with optional batch normalization and activation; supports only
group=1.
Inputs are ch_in, ch_out, weights, kernel, stride, padding, groups.
"""
super().__init__()
assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
# TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
Expand All @@ -110,7 +120,12 @@ def call(self, inputs):
class TFDWConv(keras.layers.Layer):
# Depthwise convolution
def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None):
"""Initializes a depthwise convolution layer with optional batch normalization and activation for TensorFlow models. Input are ch_in, ch_out, weights, kernel, stride, padding, groups."""
"""
Initializes a depthwise convolution layer with optional batch normalization and activation for TensorFlow
models.
Input are ch_in, ch_out, weights, kernel, stride, padding, groups.
"""
super().__init__()
assert c2 % c1 == 0, f"TFDWConv() output={c2} must be a multiple of input={c1} channels"
conv = keras.layers.DepthwiseConv2D(
Expand All @@ -134,7 +149,11 @@ def call(self, inputs):
class TFDWConvTranspose2d(keras.layers.Layer):
# Depthwise ConvTranspose2d
def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None):
"""Initializes depthwise ConvTranspose2D layer with specific channel, kernel, stride, and padding settings. Inputs are ch_in, ch_out, weights, kernel, stride, padding, groups."""
"""
Initializes depthwise ConvTranspose2D layer with specific channel, kernel, stride, and padding settings.
Inputs are ch_in, ch_out, weights, kernel, stride, padding, groups.
"""
super().__init__()
assert c1 == c2, f"TFDWConv() output={c2} must be equal to input={c1} channels"
assert k == 4 and p1 == 1, "TFDWConv() only valid for k=4 and p1=1"
Expand Down Expand Up @@ -162,7 +181,12 @@ def call(self, inputs):
class TFFocus(keras.layers.Layer):
# Focus wh information into c-space
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
"""Initializes TFFocus layer to focus width and height information into channel space with custom convolution parameters. Inputs are ch_in, ch_out, kernel, stride, padding, groups."""
"""
Initializes TFFocus layer to focus width and height information into channel space with custom convolution
parameters.
Inputs are ch_in, ch_out, kernel, stride, padding, groups.
"""
super().__init__()
self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)

Expand All @@ -182,7 +206,9 @@ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None): # ch_in, ch_out,
self.add = shortcut and c1 == c2

def call(self, inputs):
"""Performs forward pass; if shortcut is True & input/output channels match, adds input to the convolution result."""
"""Performs forward pass; if shortcut is True & input/output channels match, adds input to the convolution
result.
"""
return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))


Expand All @@ -204,7 +230,9 @@ def call(self, inputs):
class TFConv2d(keras.layers.Layer):
# Substitution for PyTorch nn.Conv2D
def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
"""Initializes a TensorFlow 2D convolution layer, mimicking PyTorch's nn.Conv2D functionality for given filter sizes and stride."""
"""Initializes a TensorFlow 2D convolution layer, mimicking PyTorch's nn.Conv2D functionality for given filter
sizes and stride.
"""
super().__init__()
assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
self.conv = keras.layers.Conv2D(
Expand All @@ -225,7 +253,12 @@ def call(self, inputs):
class TFBottleneckCSP(keras.layers.Layer):
# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
"""Initializes CSP bottleneck layer with specified channel sizes, count, shortcut option, groups, and expansion ratio. Inputs are ch_in, ch_out, number, shortcut, groups, expansion."""
"""
Initializes CSP bottleneck layer with specified channel sizes, count, shortcut option, groups, and expansion
ratio.
Inputs are ch_in, ch_out, number, shortcut, groups, expansion.
"""
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
Expand All @@ -237,7 +270,9 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])

def call(self, inputs):
"""Processes input through the model layers, concatenates, normalizes, activates, and reduces the output dimensions."""
"""Processes input through the model layers, concatenates, normalizes, activates, and reduces the output
dimensions.
"""
y1 = self.cv3(self.m(self.cv1(inputs)))
y2 = self.cv2(inputs)
return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))
Expand All @@ -246,7 +281,11 @@ def call(self, inputs):
class TFC3(keras.layers.Layer):
# CSP Bottleneck with 3 convolutions
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
"""Initializes CSP Bottleneck with 3 convolutions, supporting optional shortcuts and group convolutions. Inputs are ch_in, ch_out, number, shortcut, groups, expansion."""
"""
Initializes CSP Bottleneck with 3 convolutions, supporting optional shortcuts and group convolutions.
Inputs are ch_in, ch_out, number, shortcut, groups, expansion.
"""
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
Expand All @@ -255,14 +294,22 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])

def call(self, inputs):
"""Processes input through a sequence of transformations for object detection (YOLOv5). See https://github.com/ultralytics/yolov5."""
"""
Processes input through a sequence of transformations for object detection (YOLOv5).
See https://github.com/ultralytics/yolov5.
"""
return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))


class TFC3x(keras.layers.Layer):
# 3 module with cross-convolutions
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
"""Initializes layer with cross-convolutions for enhanced feature extraction in object detection models. Inputs are ch_in, ch_out, number, shortcut, groups, expansion."""
"""
Initializes layer with cross-convolutions for enhanced feature extraction in object detection models.
Inputs are ch_in, ch_out, number, shortcut, groups, expansion.
"""
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
Expand Down Expand Up @@ -296,15 +343,19 @@ def call(self, inputs):
class TFSPPF(keras.layers.Layer):
# Spatial pyramid pooling-Fast layer
def __init__(self, c1, c2, k=5, w=None):
"""Initializes a fast spatial pyramid pooling layer with customizable in/out channels, kernel size, and weights."""
"""Initializes a fast spatial pyramid pooling layer with customizable in/out channels, kernel size, and
weights.
"""
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding="SAME")

def call(self, inputs):
"""Executes the model's forward pass, concatenating input features with three max-pooled versions before final convolution."""
"""Executes the model's forward pass, concatenating input features with three max-pooled versions before final
convolution.
"""
x = self.cv1(inputs)
y1 = self.m(x)
y2 = self.m(y1)
Expand Down Expand Up @@ -365,7 +416,9 @@ def _make_grid(nx=20, ny=20):
class TFSegment(TFDetect):
# YOLOv5 Segment head for segmentation models
def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), imgsz=(640, 640), w=None):
"""Initializes YOLOv5 Segment head with specified channel depths, anchors, and input size for segmentation models."""
"""Initializes YOLOv5 Segment head with specified channel depths, anchors, and input size for segmentation
models.
"""
super().__init__(nc, anchors, ch, imgsz, w)
self.nm = nm # number of masks
self.npr = npr # number of protos
Expand All @@ -385,7 +438,9 @@ def call(self, x):

class TFProto(keras.layers.Layer):
def __init__(self, c1, c_=256, c2=32, w=None):
"""Initializes TFProto layer with convolutional and upsampling layers for feature extraction and transformation."""
"""Initializes TFProto layer with convolutional and upsampling layers for feature extraction and
transformation.
"""
super().__init__()
self.cv1 = TFConv(c1, c_, k=3, w=w.cv1)
self.upsample = TFUpsample(None, scale_factor=2, mode="nearest")
Expand Down Expand Up @@ -566,7 +621,9 @@ def predict(

@staticmethod
def _xywh2xyxy(xywh):
"""Converts bounding box format from [x, y, w, h] to [x1, y1, x2, y2], where xy1=top-left and xy2=bottom-right."""
"""Converts bounding box format from [x, y, w, h] to [x1, y1, x2, y2], where xy1=top-left and xy2=bottom-
right.
"""
x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)

Expand Down Expand Up @@ -628,7 +685,9 @@ def activations(act=nn.SiLU):


def representative_dataset_gen(dataset, ncalib=100):
"""Generates a representative dataset for calibration by yielding transformed numpy arrays from the input dataset."""
"""Generates a representative dataset for calibration by yielding transformed numpy arrays from the input
dataset.
"""
for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
im = np.transpose(img, [1, 2, 0])
im = np.expand_dims(im, axis=0).astype(np.float32)
Expand Down Expand Up @@ -664,7 +723,9 @@ def run(


def parse_opt():
"""Parses and returns command-line options for model inference, including weights path, image size, batch size, and dynamic batching."""
"""Parses and returns command-line options for model inference, including weights path, image size, batch size, and
dynamic batching.
"""
parser = argparse.ArgumentParser()
parser.add_argument("--weights", type=str, default=ROOT / "yolov5s.pt", help="weights path")
parser.add_argument("--imgsz", "--img", "--img-size", nargs="+", type=int, default=[640], help="inference size h,w")
Expand Down

0 comments on commit 06bca84

Please sign in to comment.