diff --git a/utils/datasets.py b/utils/datasets.py index d3714d745b88..eac0c7834308 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -23,8 +23,8 @@ from torch.utils.data import Dataset from tqdm import tqdm -from utils.general import check_requirements, check_file, check_dataset, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, \ - segment2box, segments2boxes, resample_segments, clean_str +from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \ + xyn2xy, segment2box, segments2boxes, resample_segments, clean_str from utils.torch_utils import torch_distributed_zero_first # Parameters @@ -192,7 +192,7 @@ def __next__(self): img = letterbox(img0, self.img_size, stride=self.stride)[0] # Convert - img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 + img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB and HWC to CHW img = np.ascontiguousarray(img) return path, img, img0, self.cap @@ -255,7 +255,7 @@ def __next__(self): img = letterbox(img0, self.img_size, stride=self.stride)[0] # Convert - img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 + img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB and HWC to CHW img = np.ascontiguousarray(img) return img_path, img, img0, None @@ -336,7 +336,7 @@ def __next__(self): img = np.stack(img, 0) # Convert - img = img[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB, to bsx3x416x416 + img = img[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB and BHWC to BCHW img = np.ascontiguousarray(img) return self.sources, img, img0, None @@ -552,9 +552,7 @@ def __getitem__(self, index): nL = len(labels) # number of labels if nL: - labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) # convert xyxy to xywh - labels[:, [2, 4]] /= img.shape[0] # normalized height 0-1 - labels[:, [1, 3]] /= img.shape[1] # normalized width 0-1 + labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0]) # xyxy to xywh normalized if self.augment: # flip up-down diff --git a/utils/general.py b/utils/general.py index 555975f07c5d..6a5b42f374e6 100755 --- a/utils/general.py +++ b/utils/general.py @@ -393,6 +393,16 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): return y +def xyxy2xywhn(x, w=640, h=640): + # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center + y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center + y[:, 2] = (x[:, 2] - x[:, 0]) / w # width + y[:, 3] = (x[:, 3] - x[:, 1]) / h # height + return y + + def xyn2xy(x, w=640, h=640, padw=0, padh=0): # Convert normalized segments into pixel segments, shape (n,2) y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)