Skip to content

Commit

Permalink
assert best possible recall > 0.9 before training
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Jun 13, 2020
1 parent 19e68e8 commit 31f3310
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 12 deletions.
5 changes: 4 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,17 @@ def train(hyp):
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model.names = data_dict['names']

# class frequency
# Class frequency
labels = np.concatenate(dataset.labels, 0)
c = torch.tensor(labels[:, 0]) # classes
# cf = torch.bincount(c.long(), minlength=nc) + 1.
# model._initialize_biases(cf.to(device))
plot_labels(labels)
tb_writer.add_histogram('classes', c, 0)

# Check anchors
check_best_possible_recall(dataset, anchors=model.model[-1].anchor_grid, thr=hyp['anchor_t'])

# Exponential moving average
ema = torch_utils.ModelEMA(model)

Expand Down
24 changes: 13 additions & 11 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,20 +291,22 @@ def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, r
self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
for x in self.img_files]

# Read image shapes (wh)
sp = path.replace('.txt', '') + '.shapes' # shapefile path
try:
with open(sp, 'r') as f: # read existing shapefile
s = [x.split() for x in f.read().splitlines()]
assert len(s) == n, 'Shapefile out of sync'
except:
s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')]
np.savetxt(sp, s, fmt='%g') # overwrites existing (if any)

self.shapes = np.array(s, dtype=np.float64)

# Rectangular Training https://github.com/ultralytics/yolov3/issues/232
if self.rect:
# Read image shapes (wh)
sp = path.replace('.txt', '') + '.shapes' # shapefile path
try:
with open(sp, 'r') as f: # read existing shapefile
s = [x.split() for x in f.read().splitlines()]
assert len(s) == n, 'Shapefile out of sync'
except:
s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')]
np.savetxt(sp, s, fmt='%g') # overwrites existing (if any)

# Sort by aspect ratio
s = np.array(s, dtype=np.float64)
s = self.shapes # wh
ar = s[:, 1] / s[:, 0] # aspect ratio
irect = ar.argsort()
self.img_files = [self.img_files[i] for i in irect]
Expand Down
13 changes: 13 additions & 0 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ def check_img_size(img_size, s=32):
return make_divisible(img_size, s) # nearest gs-multiple


def check_best_possible_recall(dataset, anchors, thr):
    """Check the best possible recall (BPR) of the dataset labels with the current anchors.

    A label counts as matched to an anchor when, in both dimensions, the
    label/anchor size ratio and its inverse are below ``thr``. BPR is the
    fraction of labels matched by at least one anchor.

    Args:
        dataset: Object exposing ``shapes`` (per-image width-height) and
            ``labels`` (per-image arrays whose columns 3:5 hold label
            width-height, scaled here by the image shape).
        anchors: Anchor sizes; any tensor/array-like that reshapes to (-1, 2).
            It may live on a different device or use a different dtype than
            the labels.
        thr: Maximum allowed label/anchor ratio for a match.

    Returns:
        The best possible recall as a Python float.

    Raises:
        AssertionError: If the best possible recall is not above 0.9.
    """
    # Label width-height, scaled by each image's shape entry
    wh = torch.tensor(np.concatenate([lb[:, 3:5] * s for s, lb in zip(dataset.shapes, dataset.labels)]))

    # Anchors usually come from the model and may sit on another device (e.g. CUDA)
    # or use another dtype; bring them alongside wh so the broadcasted division works.
    anchors = torch.as_tensor(anchors).detach().to(device=wh.device, dtype=wh.dtype).reshape(-1, 2)

    ratio = wh[:, None] / anchors[None]  # (n_labels, n_anchors, 2) ratio
    m = torch.max(ratio, 1. / ratio).max(2)[0]  # worst-dimension ratio per label-anchor pair
    bpr = (m.min(1)[0] < thr).float().mean().item()  # best possible recall
    mr = (m < thr).float().mean().item()  # match ratio over all label-anchor pairs

    print(('Label width-height:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
    print((' ' + '%10.4g' * 6) % (wh.shape[0], wh.mean().item(), wh.min().item(), wh.max().item(), mr, bpr))

    # Raise explicitly (same exception type as before) so the gate survives `python -O`,
    # which strips assert statements. `not bpr > 0.9` also rejects NaN.
    if not bpr > 0.9:
        raise AssertionError('Best possible recall %.3g (BPR) below 0.9 threshold. Training cancelled. '
                             'Compute new anchors with utils.utils.kmeans_anchors() and update model '
                             'before training.' % bpr)
    return bpr


def make_divisible(x, divisor):
    """Return x rounded up to the nearest multiple of divisor."""
    steps = math.ceil(x / divisor)  # whole divisor-sized steps needed to reach or pass x
    return steps * divisor
Expand Down

0 comments on commit 31f3310

Please sign in to comment.