-
Notifications
You must be signed in to change notification settings - Fork 25
/
samplers.py
32 lines (24 loc) · 1.1 KB
/
samplers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
from torch.utils.data.sampler import Sampler
class MarketSampler(Sampler):
    """Batch sampler that fills each batch with small groups of same-label samples.

    Every batch is assembled by repeatedly drawing a label that is not yet
    present in the batch and adding a random subset of the dataset indices
    carrying that label, until ``batch_size`` indices have been collected.
    Each iteration yields one batch as a list of dataset indices, so an
    instance is meant to be used where a batch-level sampler is expected.

    Args:
        labels: Sequence with one label per dataset sample; position ``i``
            holds the label of sample ``i``.
        batch_size: Number of indices per yielded batch.
        min_per_label: Smallest group size drawn for a label (default 5).
        max_per_label: Largest group size drawn for a label, inclusive
            (default 10). A label with fewer samples than the drawn size
            contributes all of its samples.

    Raises:
        ValueError: If ``min_per_label < 1`` or
            ``max_per_label < min_per_label``.
    """

    def __init__(self, labels, batch_size, min_per_label=5, max_per_label=10):
        if min_per_label < 1 or max_per_label < min_per_label:
            raise ValueError(
                "expected 1 <= min_per_label <= max_per_label, got "
                "min_per_label={}, max_per_label={}".format(
                    min_per_label, max_per_label
                )
            )
        self.labels = np.array(labels)
        self.labels_unique = np.unique(labels)
        self.batch_size = batch_size
        self.min_per_label = min_per_label
        self.max_per_label = max_per_label

    def __iter__(self):
        """Yield ``len(self)`` batches, each a list of integer dataset indices.

        A batch normally holds exactly ``batch_size`` indices. If every
        distinct label has already been used and the batch is still short,
        the shorter batch is yielded instead of looping forever.
        """
        num_labels = len(self.labels_unique)
        group_sizes = range(self.min_per_label, self.max_per_label + 1)
        for _ in range(len(self)):
            labels_in_batch = set()
            # np.int64 instead of the ``np.int`` alias, which NumPy removed.
            inds = np.array([], dtype=np.int64)
            # The second condition guarantees termination: once all labels
            # are in the batch no further draw could ever be accepted.
            while (inds.shape[0] < self.batch_size
                   and len(labels_in_batch) < num_labels):
                sample_label = np.random.choice(self.labels_unique)
                if sample_label in labels_in_batch:
                    # Each label may contribute to a batch only once.
                    continue
                labels_in_batch.add(sample_label)
                subsample_size = np.random.choice(group_sizes)
                # Positions of all samples carrying the drawn label.
                sample_label_ids = np.flatnonzero(self.labels == sample_label)
                subsample = np.random.permutation(sample_label_ids)[:subsample_size]
                inds = np.append(inds, subsample)
            # The last group may overshoot; trim to the requested size and
            # hand out plain Python ints.
            yield inds[:self.batch_size].tolist()

    def __len__(self):
        """Return the number of batches per epoch (incomplete batch dropped)."""
        return len(self.labels) // self.batch_size