-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
119 lines (97 loc) · 4.42 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from metrics.classification.confusion_matrix import ConfusionMatrix
from metrics.classification.F1_score import ClassificationF1Score
from metrics.classification.classification_accuracy import ClassificationAccuracyMetric
from models.timm_models import BaseTimmModel
from utils.getter import *
import argparse
import pprint
# Module-level call: runs once when this file is imported or executed, before
# main() builds any dataset or model. `set_seed` is not defined in this file;
# presumably it comes from the star import of `utils.getter` and fixes the
# random seeds for reproducibility -- TODO confirm against that module.
set_seed()
def main(config, args):
    """Assemble the classification training pipeline and run it.

    The steps are: select the compute device, build the datasets and loaders,
    create the timm-based network and its metrics, wrap everything in a
    ``Classifier``, optionally restore a checkpoint, create the learning-rate
    scheduler, and finally hand control to ``Trainer.fit``.

    Args:
        config: Project configuration object. The attributes read here are
            ``gpu_devices`` (comma-separated string), ``model_arch``,
            ``syncBN``, ``project_name``, ``lr_policy``, ``mixed_precision``,
            ``lr_scheduler`` and ``num_epochs``. It must also support
            ``vars(config)`` because it is pretty-printed.
        args: Parsed command-line namespace. The attributes read here are
            ``saved_path``, ``log_path``, ``resume``, ``save_interval``,
            ``gradcam_visualization``, ``val_interval`` and
            ``print_per_iter``. An optional ``cuda`` attribute is honoured
            when present (``None`` forces CPU). ``saved_path`` and
            ``log_path`` are modified in place: the project name is appended
            to each of them when they are not ``None``.

    Returns:
        None. Progress is reported through ``print`` and through the
        ``Logger``/``CheckPoint`` objects handed to the trainer.
    """
    # Google Colab only use 1 GPU.
    # The visible-device mask is exported before CUDA is queried below so the
    # availability check sees the restricted device list.
    os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu_devices
    num_gpus = len(config.gpu_devices.split(','))

    # The argument parser of this script defines no `--cuda` option, so a
    # direct `args.cuda` lookup raises AttributeError. Treat a missing
    # attribute as "use the GPU", keep the original meaning of an explicit
    # `None` (CPU only), and fall back to the CPU when CUDA is not available
    # instead of failing later when tensors are moved to the device.
    cuda_requested = getattr(args, 'cuda', True) is not None
    use_cuda = cuda_requested and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    pprint.PrettyPrinter(indent=2).pprint(vars(config))

    # ---- data ---------------------------------------------------------------
    trainset, valset, train_loader, val_loader = get_dataset_and_dataloader(
        config)

    # ---- network ------------------------------------------------------------
    net = BaseTimmModel(
        n_classes=trainset.num_classes,
        model_arch=config.model_arch,
        head='mlp',
        syncBN=config.syncBN)

    # Keep the outputs of different projects in separate sub-directories.
    if args.saved_path is not None:
        args.saved_path = os.path.join(args.saved_path, config.project_name)

    if args.log_path is not None:
        args.log_path = os.path.join(args.log_path, config.project_name)

    # ---- metrics, optimizer, loss scaling -----------------------------------
    metric = [
        ClassificationAccuracyMetric(),
        ClassificationF1Score(n_classes=trainset.num_classes),
        ConfusionMatrix(trainset.classes),
    ]

    optimizer, optimizer_params = get_lr_policy(config.lr_policy)

    # A gradient scaler is only created when mixed precision is enabled.
    scaler = NativeScaler() if config.mixed_precision else None

    model = Classifier(
        model=net,
        metrics=metric,
        scaler=scaler,
        criterion=nn.CrossEntropyLoss(),
        optimizer=optimizer,
        optim_params=optimizer_params,
        device=device)

    # ---- resume or start from scratch ---------------------------------------
    if args.resume is not None:
        load(model, args.resume)
        start_epoch, start_iter, best_value = get_epoch_iters(args.resume)
    else:
        print('Not resume. Initialize weights')
        start_epoch, start_iter, best_value = 0, 0, 0.0

    # The scheduler is built after the optional checkpoint load and is bound
    # to the optimizer owned by the Classifier wrapper.
    scheduler, step_per_epoch = get_lr_scheduler(
        model.optimizer,
        lr_config=config.lr_scheduler,
        num_epochs=config.num_epochs)

    trainer = Trainer(
        config,
        model,
        train_loader,
        val_loader,
        checkpoint=CheckPoint(
            save_per_iter=args.save_interval, path=args.saved_path),
        best_value=best_value,
        logger=Logger(log_dir=args.log_path),
        scheduler=scheduler,
        visualize_when_val=args.gradcam_visualization,
        num_evaluate_per_epoch=args.val_interval,
        step_per_epoch=step_per_epoch)

    # ---- run summary --------------------------------------------------------
    print("########## DATASET INFO ##########")
    print("Trainset: ")
    print(trainset)
    print("Valset: ")
    print(valset)
    print()
    print(trainer)
    print()
    print(config)
    print(f'Training with {num_gpus} gpu(s)')
    print(f"Start training at [{start_epoch}|{start_iter}]")
    print(f"Current best MAP: {best_value}")

    trainer.fit(
        start_epoch=start_epoch,
        start_iter=start_iter,
        num_epochs=config.num_epochs,
        print_per_iter=args.print_per_iter)
if __name__ == '__main__':
    # Script entry point: declare the command-line options, load the YAML
    # project configuration named by --config, then start training.
    cli = argparse.ArgumentParser(description='Training custom model')

    # One (flag, add_argument keyword options) pair per option. They are
    # registered in this exact order, which is also the order of --help.
    option_table = (
        ('--config', {
            'default': 'object_retrieval',
            'type': str,
            'help': 'project file that contains parameters',
        }),
        ('--print_per_iter', {
            'type': int,
            'default': 300,
            'help': 'Number of iteration to print',
        }),
        ('--val_interval', {
            'type': int,
            'default': 2,
            'help': 'Number of epoches between valing phases',
        }),
        ('--gradcam_visualization', {
            'action': 'store_true',
            'help': 'whether to visualize box to ./sample when validating (for debug), default=off',
        }),
        ('--save_interval', {
            'type': int,
            'default': 1000,
            'help': 'Number of steps between saving',
        }),
        ('--log_path', {
            'type': str,
            'default': 'loggers/runs',
        }),
        ('--resume', {
            'type': str,
            'default': None,
            'help': 'whether to load weights from a checkpoint, set None to initialize',
        }),
        ('--saved_path', {
            'type': str,
            'default': './weights',
        }),
        # Disabled option, kept for reference:
        # ('--freeze_backbone', {
        #     'action': 'store_true',
        #     'help': 'whether to freeze the backbone',
        # }),
    )
    for option_flag, option_kwargs in option_table:
        cli.add_argument(option_flag, **option_kwargs)

    args = cli.parse_args()

    # The configuration is looked up as configs/<name>.yaml relative to the
    # current working directory.
    config_file = os.path.join('configs', f'{args.config}.yaml')
    config = Config(config_file)
    main(config, args)