From a7ebfdfd132269fb1db526c18a6f244dd2d17bb0 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Fri, 25 Jun 2021 04:11:51 -0400 Subject: [PATCH] Fixing normalization in dataloader --- yolort/models/transform.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/yolort/models/transform.py b/yolort/models/transform.py index d9bf91b2..9ed5e764 100644 --- a/yolort/models/transform.py +++ b/yolort/models/transform.py @@ -1,11 +1,12 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -# Modified by Zhiqiang Wang (me@zhiqwang.com) +# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved. import math import torch from torch import nn, Tensor import torch.nn.functional as F import torchvision +from torchvision.ops import box_convert from typing import Dict, Optional, List, Tuple @@ -146,7 +147,7 @@ def resize( return image, target bbox = target["boxes"] - bbox = resize_boxes(bbox, (h, w), image.shape[-2:]) + bbox = normalize_boxes(bbox, (h, w)) target["boxes"] = bbox return image, target @@ -297,3 +298,17 @@ def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) - ymin = ymin * ratio_height ymax = ymax * ratio_height return torch.stack((xmin, ymin, xmax, ymax), dim=1) + + +def normalize_boxes(boxes: Tensor, original_size: List[int]) -> Tensor: + height = torch.tensor(original_size[0], dtype=torch.float32, device=boxes.device) + width = torch.tensor(original_size[1], dtype=torch.float32, device=boxes.device) + xmin, ymin, xmax, ymax = boxes.unbind(1) + + xmin = xmin / width + xmax = xmax / width + ymin = ymin / height + ymax = ymax / height + boxes = torch.stack((xmin, ymin, xmax, ymax), dim=1) + # Convert xyxy to cxcywh + return box_convert(boxes, in_fmt='xyxy', out_fmt='cxcywh')