Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/docstrings' into docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Feb 25, 2024
2 parents 92f72fe + 08299c7 commit 1a7d425
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ def process_batch(self, detections, labels):
self.matrix[dc, self.nc] += 1 # predicted background

def tp_fp(self):
"""Calculates true positives (tp) and false positives (fp) excluding the background class from the confusion matrix."""
"""Calculates true positives (tp) and false positives (fp) excluding the background class from the confusion
matrix.
"""
tp = self.matrix.diagonal() # true positives
fp = self.matrix.sum(1) - tp # false positives
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
Expand Down Expand Up @@ -226,7 +228,11 @@ def print(self):


def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
"""Calculates IoU, GIoU, DIoU, or CIoU between two boxes, supporting xywh/xyxy formats. Input shapes are box1(1,4) to box2(n,4)."""
"""
Calculates IoU, GIoU, DIoU, or CIoU between two boxes, supporting xywh/xyxy formats.
Input shapes are box1(1,4) to box2(n,4).
"""

# Get the coordinates of bounding boxes
if xywh: # transform from xywh to xyxy
Expand Down Expand Up @@ -316,7 +322,9 @@ def bbox_ioa(box1, box2, eps=1e-7):


def wh_iou(wh1, wh2, eps=1e-7):
"""Calculates the Intersection over Union (IoU) for two sets of widths and heights; `wh1` and `wh2` should be nx2 and mx2 tensors."""
"""Calculates the Intersection over Union (IoU) for two sets of widths and heights; `wh1` and `wh2` should be nx2
and mx2 tensors.
"""
wh1 = wh1[:, None] # [N,1,2]
wh2 = wh2[None] # [1,M,2]
inter = torch.min(wh1, wh2).prod(2) # [N,M]
Expand All @@ -328,7 +336,9 @@ def wh_iou(wh1, wh2, eps=1e-7):

@threaded
def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=()):
"""Plots precision-recall curve, optionally per class, saving to `save_dir`; `px`, `py` are lists, `ap` is Nx2 array, `names` optional."""
"""Plots precision-recall curve, optionally per class, saving to `save_dir`; `px`, `py` are lists, `ap` is Nx2
array, `names` optional.
"""
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
py = np.stack(py, axis=1)

Expand Down

0 comments on commit 1a7d425

Please sign in to comment.