Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/docstrings' into docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Feb 25, 2024
2 parents 2de9026 + 791f5cf commit 5040f3a
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,21 @@ def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), inplace=True):
self.detect = Detect.forward

def forward(self, x):
"""Processes input through the network, returning detections and prototypes; adjusts output based on training/export mode."""
"""Processes input through the network, returning detections and prototypes; adjusts output based on
training/export mode.
"""
p = self.proto(x[0])
x = self.detect(self, x)
return (x, p) if self.training else (x[0], p) if self.export else (x[0], p, x[1])


class BaseModel(nn.Module):
"""YOLOv5 base model."""

def forward(self, x, profile=False, visualize=False):
"""Executes a single-scale inference or training pass on the YOLOv5 base model, with options for profiling and visualization."""
"""Executes a single-scale inference or training pass on the YOLOv5 base model, with options for profiling and
visualization.
"""
return self._forward_once(x, profile, visualize) # single-scale inference, train

def _forward_once(self, x, profile=False, visualize=False):
Expand Down Expand Up @@ -192,7 +197,9 @@ def info(self, verbose=False, img_size=640): # print model information
model_info(self, verbose, img_size)

def _apply(self, fn):
"""Applies transformations like to(), cpu(), cuda(), half() to model tensors excluding parameters or registered buffers."""
"""Applies transformations like to(), cpu(), cuda(), half() to model tensors excluding parameters or registered
buffers.
"""
self = super()._apply(fn)
m = self.model[-1] # Detect()
if isinstance(m, (Detect, Segment)):
Expand Down Expand Up @@ -284,7 +291,9 @@ def _descale_pred(self, p, flips, scale, img_size):
return p

def _clip_augmented(self, y):
"""Clips augmented inference tails for YOLOv5 models, affecting first and last tensors based on grid points and layer counts."""
"""Clips augmented inference tails for YOLOv5 models, affecting first and last tensors based on grid points and
layer counts.
"""
nl = self.model[-1].nl # number of detection layers (P3-P5)
g = sum(4**x for x in range(nl)) # grid points
e = 1 # exclude layer count
Expand Down Expand Up @@ -324,7 +333,9 @@ def __init__(self, cfg=None, model=None, nc=1000, cutoff=10): # yaml, model, nu
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)

def _from_detection_model(self, model, nc=1000, cutoff=10):
"""Creates a classification model from a YOLOv5 detection model, slicing at `cutoff` and adding a classification layer."""
"""Creates a classification model from a YOLOv5 detection model, slicing at `cutoff` and adding a classification
layer.
"""
if isinstance(model, DetectMultiBackend):
model = model.model # unwrap DetectMultiBackend
model.model = model.model[:cutoff] # backbone
Expand Down

0 comments on commit 5040f3a

Please sign in to comment.