Skip to content

Commit

Permalink
Fixing normalization in dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Jun 25, 2021
1 parent e52771e commit a7ebfdf
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions yolort/models/transform.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified by Zhiqiang Wang (me@zhiqwang.com)
# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved.
import math
import torch
from torch import nn, Tensor
import torch.nn.functional as F

import torchvision
from torchvision.ops import box_convert

from typing import Dict, Optional, List, Tuple

Expand Down Expand Up @@ -146,7 +147,7 @@ def resize(
return image, target

bbox = target["boxes"]
bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
bbox = normalize_boxes(bbox, (h, w))
target["boxes"] = bbox

return image, target
Expand Down Expand Up @@ -297,3 +298,17 @@ def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -
ymin = ymin * ratio_height
ymax = ymax * ratio_height
return torch.stack((xmin, ymin, xmax, ymax), dim=1)


def normalize_boxes(boxes: Tensor, original_size: List[int]) -> Tensor:
    """
    Scale corner-format boxes by the image size and return them as cxcywh.

    Args:
        boxes (Tensor): boxes in ``(xmin, ymin, xmax, ymax)`` order; the
            tensor is unbound along dim 1 into exactly four columns.
        original_size (List[int]): image size given as ``(height, width)``.

    Returns:
        Tensor: the boxes with x coordinates divided by the width and
        y coordinates divided by the height, converted from ``xyxy`` to
        ``cxcywh`` with ``torchvision.ops.box_convert``.
    """
    # The divisors are kept as 0-dim float32 tensors on the boxes' device.
    img_h = torch.tensor(original_size[0], dtype=torch.float32, device=boxes.device)
    img_w = torch.tensor(original_size[1], dtype=torch.float32, device=boxes.device)

    left, top, right, bottom = boxes.unbind(1)
    scaled = torch.stack(
        (left / img_w, top / img_h, right / img_w, bottom / img_h),
        dim=1,
    )
    # Convert xyxy to cxcywh
    return box_convert(scaled, in_fmt='xyxy', out_fmt='cxcywh')

0 comments on commit a7ebfdf

Please sign in to comment.