-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
491 lines (403 loc) · 15.9 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
from __future__ import division, print_function, absolute_import
import copy
import numpy as np
import os.path as osp
import tarfile
import zipfile
import torch
from utils import read_image, download_url, mkdir_if_missing
class Dataset(object):
    """An abstract class representing a Dataset.

    This is the base class for ``ImageDataset`` and ``VideoDataset``.

    Args:
        train (list): contains tuples of (img_path(s), pid, camid).
        query (list): contains tuples of (img_path(s), pid, camid).
        gallery (list): contains tuples of (img_path(s), pid, camid).
        transform: transform function.
        k_tfm (int): number of times to apply augmentation to an image
            independently. If k_tfm > 1, the transform function will be
            applied k_tfm times to an image. This variable will only be
            useful for training and is currently valid for image datasets only.
        mode (str): 'train', 'query' or 'gallery'.
        combineall (bool): combines train, query and gallery in a
            dataset for training.
        verbose (bool): show information.
    """

    # junk_pids contains useless person IDs, e.g. background,
    # false detections, distractors. These IDs will be ignored
    # when combining all images in a dataset for training, i.e.
    # combineall=True
    _junk_pids = []

    # Some datasets are only used for training, like CUHK-SYSU.
    # In this case, "combineall=True" is not used for them.
    _train_only = False

    def __init__(
        self,
        train,
        query,
        gallery,
        transform=None,
        k_tfm=1,
        mode='train',
        combineall=False,
        verbose=True,
        **kwargs
    ):
        # Extend 3-tuple (img_path(s), pid, camid) to
        # 4-tuple (img_path(s), pid, camid, dsetid) by adding a
        # dataset indicator "dsetid". Empty splits are skipped
        # (previously ``train[0]`` raised IndexError on an empty list).
        if train and len(train[0]) == 3:
            train = [(*items, 0) for items in train]
        if query and len(query[0]) == 3:
            query = [(*items, 0) for items in query]
        if gallery and len(gallery[0]) == 3:
            gallery = [(*items, 0) for items in gallery]

        self.train = train
        self.query = query
        self.gallery = gallery
        self.transform = transform
        self.k_tfm = k_tfm
        self.mode = mode
        self.combineall = combineall
        self.verbose = verbose

        self.num_train_pids = self.get_num_pids(self.train)
        self.num_train_cams = self.get_num_cams(self.train)
        self.num_datasets = self.get_num_datasets(self.train)

        if self.combineall:
            self.combine_all()

        if self.mode == 'train':
            self.data = self.train
        elif self.mode == 'query':
            self.data = self.query
        elif self.mode == 'gallery':
            self.data = self.gallery
        else:
            raise ValueError(
                'Invalid mode. Got {}, but expected to be '
                'one of [train | query | gallery]'.format(self.mode)
            )

        if self.verbose:
            self.show_summary()

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        return len(self.data)

    def __add__(self, other):
        """Adds two datasets together (only the train set)."""
        train = copy.deepcopy(self.train)

        # Shift the other dataset's IDs past this dataset's ID ranges
        # so the merged label spaces do not collide.
        for img_path, pid, camid, dsetid in other.train:
            pid += self.num_train_pids
            camid += self.num_train_cams
            dsetid += self.num_datasets
            train.append((img_path, pid, camid, dsetid))

        ###################################
        # Note that
        # 1. set verbose=False to avoid unnecessary print
        # 2. set combineall=False because combineall would have been applied
        #    if it was True for a specific dataset; setting it to True will
        #    create new IDs that should have already been included
        ###################################
        if isinstance(train[0][0], str):
            return ImageDataset(
                train,
                self.query,
                self.gallery,
                transform=self.transform,
                k_tfm=self.k_tfm,  # previously dropped, silently resetting to 1
                mode=self.mode,
                combineall=False,
                verbose=False
            )
        else:
            return VideoDataset(
                train,
                self.query,
                self.gallery,
                transform=self.transform,
                k_tfm=self.k_tfm,
                mode=self.mode,
                combineall=False,
                verbose=False,
                seq_len=self.seq_len,
                sample_method=self.sample_method
            )

    def __radd__(self, other):
        """Supports sum([dataset1, dataset2, dataset3])."""
        if other == 0:
            # sum() starts from the int 0; return self unchanged
            return self
        else:
            return self.__add__(other)

    def get_num_pids(self, data):
        """Returns the number of unique person identities in ``data``.

        Each tuple in data contains (img_path(s), pid, camid, dsetid).
        """
        return len({items[1] for items in data})

    def get_num_cams(self, data):
        """Returns the number of unique cameras in ``data``.

        Each tuple in data contains (img_path(s), pid, camid, dsetid).
        """
        return len({items[2] for items in data})

    def get_num_datasets(self, data):
        """Returns the number of datasets included in ``data``.

        Each tuple in data contains (img_path(s), pid, camid, dsetid).
        """
        return len({items[3] for items in data})

    def show_summary(self):
        """Shows dataset statistics (overridden by subclasses)."""
        pass

    def combine_all(self):
        """Combines train, query and gallery in a dataset for training."""
        if self._train_only:
            return

        combined = copy.deepcopy(self.train)

        # relabel pids in gallery (query shares the same scope)
        g_pids = set()
        for items in self.gallery:
            pid = items[1]
            if pid in self._junk_pids:
                continue
            g_pids.add(pid)
        # sort so the relabeling is deterministic across runs
        # (set iteration order is not guaranteed)
        pid2label = {pid: i for i, pid in enumerate(sorted(g_pids))}

        def _combine_data(data):
            # Appends relabeled (non-junk) items to ``combined``,
            # offsetting new pids past the original train pid range.
            for img_path, pid, camid, dsetid in data:
                if pid in self._junk_pids:
                    continue
                pid = pid2label[pid] + self.num_train_pids
                combined.append((img_path, pid, camid, dsetid))

        _combine_data(self.query)
        _combine_data(self.gallery)

        self.train = combined
        self.num_train_pids = self.get_num_pids(self.train)

    def download_dataset(self, dataset_dir, dataset_url):
        """Downloads and extracts dataset.

        Args:
            dataset_dir (str): dataset directory.
            dataset_url (str): url to download dataset.

        Raises:
            RuntimeError: if ``dataset_url`` is None and the dataset
                directory does not already exist.
        """
        if osp.exists(dataset_dir):
            return

        if dataset_url is None:
            raise RuntimeError(
                '{} dataset needs to be manually '
                'prepared, please follow the '
                'document to prepare this dataset'.format(
                    self.__class__.__name__
                )
            )

        print('Creating directory "{}"'.format(dataset_dir))
        mkdir_if_missing(dataset_dir)
        fpath = osp.join(dataset_dir, osp.basename(dataset_url))

        print(
            'Downloading {} dataset to "{}"'.format(
                self.__class__.__name__, dataset_dir
            )
        )
        download_url(dataset_url, fpath)

        print('Extracting "{}"'.format(fpath))
        # Try tar first; fall back to zip only when the file is not a
        # tar archive. The previous bare "except:" masked every error,
        # including KeyboardInterrupt; "with" also guarantees the
        # archive handles are closed on failure.
        try:
            with tarfile.open(fpath) as tar:
                tar.extractall(path=dataset_dir)
        except tarfile.TarError:
            with zipfile.ZipFile(fpath, 'r') as zip_ref:
                zip_ref.extractall(dataset_dir)

        print('{} dataset is ready'.format(self.__class__.__name__))

    def check_before_run(self, required_files):
        """Checks if required files exist before going deeper.

        Args:
            required_files (str or list): string file name(s).

        Raises:
            RuntimeError: if any required file is missing.
        """
        if isinstance(required_files, str):
            required_files = [required_files]

        for fpath in required_files:
            if not osp.exists(fpath):
                raise RuntimeError('"{}" is not found'.format(fpath))

    def __repr__(self):
        num_train_pids = self.get_num_pids(self.train)
        num_train_cams = self.get_num_cams(self.train)
        num_query_pids = self.get_num_pids(self.query)
        num_query_cams = self.get_num_cams(self.query)
        num_gallery_pids = self.get_num_pids(self.gallery)
        num_gallery_cams = self.get_num_cams(self.gallery)
        msg = '  ----------------------------------------\n' \
              '  subset   | # ids | # items | # cameras\n' \
              '  ----------------------------------------\n' \
              '  train    | {:5d} | {:7d} | {:9d}\n' \
              '  query    | {:5d} | {:7d} | {:9d}\n' \
              '  gallery  | {:5d} | {:7d} | {:9d}\n' \
              '  ----------------------------------------\n' \
              '  items: images/tracklets for image/video dataset\n'.format(
                  num_train_pids, len(self.train), num_train_cams,
                  num_query_pids, len(self.query), num_query_cams,
                  num_gallery_pids, len(self.gallery), num_gallery_cams
              )
        return msg

    def _transform_image(self, tfm, k_tfm, img0):
        """Transforms a raw image (img0) k_tfm times with
        the transform function tfm.

        Returns the single transformed image when k_tfm == 1,
        otherwise a list of k_tfm independently transformed images.
        """
        img = [tfm(img0) for _ in range(k_tfm)]
        if len(img) == 1:
            img = img[0]
        return img
class ImageDataset(Dataset):
    """A base class representing ImageDataset.

    All other image datasets should subclass it.

    ``__getitem__`` returns an image given index.
    It will return ``img``, ``pid``, ``camid`` and ``img_path``
    where ``img`` has shape (channel, height, width). As a result,
    data in each batch has shape (batch_size, channel, height, width).
    """

    def __init__(self, train, query, gallery, **kwargs):
        super(ImageDataset, self).__init__(train, query, gallery, **kwargs)

    def __getitem__(self, index):
        path, person_id, cam_id, dataset_id = self.data[index]
        image = read_image(path)
        if self.transform is not None:
            # apply the transform k_tfm times (returns a list when k_tfm > 1)
            image = self._transform_image(self.transform, self.k_tfm, image)
        return {
            'img': image,
            'pid': person_id,
            'camid': cam_id,
            'impath': path,
            'dsetid': dataset_id
        }

    def show_summary(self):
        """Prints per-subset statistics as a table."""
        divider = '  ----------------------------------------'
        rows = (
            ('  train    | {:5d} | {:8d} | {:9d}', self.train),
            ('  query    | {:5d} | {:8d} | {:9d}', self.query),
            ('  gallery  | {:5d} | {:8d} | {:9d}', self.gallery),
        )
        print('=> Loaded {}'.format(self.__class__.__name__))
        print(divider)
        print('  subset   | # ids | # images | # cameras')
        print(divider)
        for fmt, split in rows:
            print(
                fmt.format(
                    self.get_num_pids(split),
                    len(split),
                    self.get_num_cams(split)
                )
            )
        print(divider)
class VideoDataset(Dataset):
    """A base class representing VideoDataset.

    All other video datasets should subclass it.

    ``__getitem__`` returns an image given index.
    It will return ``imgs``, ``pid`` and ``camid``
    where ``imgs`` has shape (seq_len, channel, height, width). As a result,
    data in each batch has shape (batch_size, seq_len, channel, height, width).
    """

    def __init__(
        self,
        train,
        query,
        gallery,
        seq_len=15,
        sample_method='evenly',
        **kwargs
    ):
        super(VideoDataset, self).__init__(train, query, gallery, **kwargs)
        self.seq_len = seq_len
        self.sample_method = sample_method

        if self.transform is None:
            raise RuntimeError('transform must not be None')

    def _select_frame_indices(self, num_imgs):
        # Returns the array of frame indices to read from a tracklet of
        # length num_imgs, according to self.sample_method.
        if self.sample_method == 'random':
            # Randomly samples seq_len images; replicates images
            # (replace=True) only when the tracklet is too short.
            candidates = np.arange(num_imgs)
            picked = np.random.choice(
                candidates,
                size=self.seq_len,
                replace=num_imgs < self.seq_len
            )
            # sort indices to keep temporal order
            # (comment it to be order-agnostic)
            return np.sort(picked)

        if self.sample_method == 'evenly':
            # Evenly samples seq_len images from the tracklet.
            if num_imgs >= self.seq_len:
                usable = num_imgs - num_imgs % self.seq_len
                return np.arange(0, usable, usable / self.seq_len)
            # Tracklet shorter than seq_len: replicate the last image
            # until the seq_len requirement is satisfied.
            padding = np.ones(
                self.seq_len - num_imgs
            ).astype(np.int32) * (num_imgs - 1)
            selected = np.concatenate([np.arange(0, num_imgs), padding])
            assert len(selected) == self.seq_len
            return selected

        if self.sample_method == 'all':
            # Samples all images in a tracklet; batch_size must be set to 1.
            return np.arange(num_imgs)

        raise ValueError(
            'Unknown sample method: {}'.format(self.sample_method)
        )

    def __getitem__(self, index):
        img_paths, pid, camid, dsetid = self.data[index]
        frame_ids = self._select_frame_indices(len(img_paths))

        frames = []
        for fid in frame_ids:
            frame = read_image(img_paths[int(fid)])
            if self.transform is not None:
                frame = self.transform(frame)
            frames.append(frame.unsqueeze(0))  # frame must be torch.Tensor
        clip = torch.cat(frames, dim=0)

        return {'img': clip, 'pid': pid, 'camid': camid, 'dsetid': dsetid}

    def show_summary(self):
        """Prints per-subset statistics as a table."""
        divider = '  -------------------------------------------'
        rows = (
            ('  train    | {:5d} | {:11d} | {:9d}', self.train),
            ('  query    | {:5d} | {:11d} | {:9d}', self.query),
            ('  gallery  | {:5d} | {:11d} | {:9d}', self.gallery),
        )
        print('=> Loaded {}'.format(self.__class__.__name__))
        print(divider)
        print('  subset   | # ids | # tracklets | # cameras')
        print(divider)
        for fmt, split in rows:
            print(
                fmt.format(
                    self.get_num_pids(split),
                    len(split),
                    self.get_num_cams(split)
                )
            )
        print(divider)