-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
56 lines (44 loc) · 1.26 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""
Utility functions.
"""
import os
import random
import numpy as np
import torch
# Module-level switch read by seed_everything(): when True, seeding also puts
# cuDNN into deterministic mode and turns off its benchmark autotuner.
cudnn_deterministic = True


def seed_everything(seed=0):
    """Seed every random number generator used here, for reproducible runs.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU and
    CUDA generators, exports ``PYTHONHASHSEED`` and configures cuDNN according
    to the module-level ``cudnn_deterministic`` flag.

    Args:
        seed: Seed applied to every generator. Defaults to 0.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # NOTE: Python fixes its hash randomisation at interpreter start-up, so
    # this only takes effect in child processes launched after this call,
    # not in the current process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = cudnn_deterministic
    if cudnn_deterministic:
        # Benchmark mode autotunes convolution algorithms per input shape and
        # may select different ones between runs, breaking reproducibility.
        torch.backends.cudnn.benchmark = False
def compute_topk_acc(pred, targets, topk):
    """Count the samples whose label appears among the top-k predictions.

    Despite the name, the value returned is a raw hit count, not a ratio;
    divide by the number of samples to obtain an accuracy.

    Args:
        pred: Network scores, indexed as ``pred[sample, class]``.
        targets: Ground-truth class index for each sample.
        topk: Requested k; clamped to the number of classes in ``pred``.

    Returns:
        0-dim float tensor holding the number of hits.
    """
    k = min(topk, pred.shape[1])
    # Indices of the k highest scores per sample, transposed so that row i
    # holds every sample's i-th best guess.
    ranked = pred.topk(k, dim=1, largest=True, sorted=True).indices.t()
    expected = targets.view(1, -1).expand_as(ranked)
    hits = ranked.eq(expected)
    return hits.flatten().float().sum(dim=0)
def calculate_metrics(outputs, targets):
    """Compute top-5 and top-1 hit counts for a batch of predictions.

    Args:
        outputs: Network scores tensor, as accepted by ``compute_topk_acc``.
        targets: Ground-truth class indices.

    Returns:
        Tuple ``(top5_hits, top1_hits)`` of Python scalars (via ``.item()``).
        Note the order: the top-5 count comes first.
    """
    top5_hits, top1_hits = (
        compute_topk_acc(outputs, targets, k).item() for k in (5, 1)
    )
    return top5_hits, top1_hits