Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional transforms argument to LoadStreams() #9105

Merged
merged 4 commits into from
Aug 23, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 25 additions & 29 deletions utils/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def __next__(self):
s = f'image {self.count}/{self.nf} {path}: '

if self.transforms:
im = self.transforms(cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)) # classify transforms
im = self.transforms(cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)) # transforms
else:
im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
Expand Down Expand Up @@ -289,30 +289,28 @@ def __next__(self):
raise StopIteration

# Read frame
ret_val, img0 = self.cap.read()
img0 = cv2.flip(img0, 1) # flip left-right
ret_val, im0 = self.cap.read()
im0 = cv2.flip(im0, 1) # flip left-right

# Print
assert ret_val, f'Camera Error {self.pipe}'
img_path = 'webcam.jpg'
s = f'webcam {self.count}: '

# Padded resize
img = letterbox(img0, self.img_size, stride=self.stride)[0]
# Process
im = letterbox(im0, self.img_size, stride=self.stride)[0] # resize
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous

# Convert
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
img = np.ascontiguousarray(img)

return img_path, img, img0, None, s
return img_path, im, im0, None, s

def __len__(self):
return 0


class LoadStreams:
# YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None):
self.mode = 'stream'
self.img_size = img_size
self.stride = stride
Expand All @@ -326,7 +324,6 @@ def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
n = len(sources)
self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
self.sources = [clean_str(x) for x in sources] # clean source names for later
self.auto = auto
for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream
st = f'{i + 1}/{n}: {s}... '
Expand All @@ -353,8 +350,10 @@ def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
LOGGER.info('') # newline

# check for common shapes
s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
self.auto = auto and self.rect
self.transforms = transforms # optional
if not self.rect:
LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')

Expand Down Expand Up @@ -385,18 +384,15 @@ def __next__(self):
cv2.destroyAllWindows()
raise StopIteration

# Letterbox
img0 = self.imgs.copy()
img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]

# Stack
img = np.stack(img, 0)

# Convert
img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
img = np.ascontiguousarray(img)
im0 = self.imgs.copy()
if self.transforms:
im = np.stack([self.transforms(cv2.cvtColor(x, cv2.COLOR_BGR2RGB)) for x in im0]) # transforms
else:
im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0]) # resize
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
im = np.ascontiguousarray(im) # contiguous

return self.sources, img, img0, None, ''
return self.sources, im, im0, None, ''

def __len__(self):
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
Expand Down Expand Up @@ -836,7 +832,7 @@ def collate_fn(batch):

@staticmethod
def collate_fn4(batch):
img, label, path, shapes = zip(*batch) # transposed
im, label, path, shapes = zip(*batch) # transposed
n = len(shapes) // 4
im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

Expand All @@ -846,13 +842,13 @@ def collate_fn4(batch):
for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
i *= 4
if random.random() < 0.5:
im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
align_corners=False)[0].type(img[i].type())
im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
align_corners=False)[0].type(im[i].type())
lb = label[i]
else:
im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
im4.append(im)
im4.append(im1)
label4.append(lb)

for i, lb in enumerate(label4):
Expand Down