-
Notifications
You must be signed in to change notification settings - Fork 9
/
train.py
83 lines (66 loc) · 2.25 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import numpy as np
import os, sys
os.environ['TORCH_MODEL_ZOO'] = '/mnt/projects/counting/pretrained/resnet'
import models, datasets, metrics
import utils as ut
import tqdm, time
import pandas as pd
def main():
    """Train an MRCNN model on Pascal VOC 2012 and save a final checkpoint.

    Builds the model and dataloader from a small config dict, runs
    `train_epoch` for `max_epochs` epochs while accumulating per-epoch
    scores in `model.history`, prints a rolling report, and writes the
    final state dict + history to ./checkpoints/checkpoint.pth.
    """
    exp_dict = {"model": "MRCNN",
                "max_epochs": 10,
                "batch_size": 1}

    model = models.mrcnn.MRCNN(exp_dict).cuda()

    path_base = "checkpoints"
    # model_state_dict = torch.load(path_base + "/model_state_dict.pth")
    # model.load_state_dict(model_state_dict)

    train_set = datasets.pascal2012.Pascal2012(
        split="train", exp_dict=exp_dict,
        root="/mnt/datasets/public/issam/VOCdevkit/VOC2012/")

    train_loader = torch.utils.data.DataLoader(
        train_set,
        collate_fn=train_set.collate_fn,
        # Keep the loader in sync with the experiment config instead of
        # hard-coding 1 (same value today, but one source of truth).
        batch_size=exp_dict["batch_size"],
        shuffle=True,
        num_workers=1,
        pin_memory=True)

    # Main loop
    model.history = {"score_list": []}
    for e in range(exp_dict["max_epochs"]):
        # Train for one epoch
        score_dict = train_epoch(model, train_loader)
        score_dict["epoch"] = e

        # Update history
        model.history["score_list"] += [score_dict]

        # Report the last few epochs
        results_df = pd.DataFrame(model.history["score_list"]).tail()
        print(results_df[["epoch", "train_loss"]])

    # Experiment completed.
    # BUG FIX: torch.save's signature is torch.save(obj, f) — the original
    # passed (path, obj), which would fail at the very end of training.
    # Also make sure the checkpoint directory exists before writing.
    os.makedirs(path_base, exist_ok=True)
    torch.save({"model_state_dict": model.state_dict(),
                "history": model.history},
               os.path.join(path_base, "checkpoint.pth"))
def train_epoch(model, train_loader):
    """Run one pass over `train_loader` and return summary statistics.

    Calls `model.train_step(batch)` on every batch, aborts on a NaN loss,
    and reports a running mean loss via a tqdm progress bar.

    Returns:
        dict with "train_loss" (mean loss over the epoch) and
        "train_time_taken" (wall-clock seconds for the epoch).
    """
    model.train()

    batch_count = len(train_loader)
    total_loss = 0.
    start = time.time()

    # Progress bar shows the running mean loss as batches complete.
    progress = tqdm.tqdm(total=batch_count)
    for step, batch in enumerate(train_loader, start=1):
        # One optimization step; train_step is assumed to return a scalar
        # loss value — TODO confirm against the model implementation.
        batch_loss = model.train_step(batch)
        if np.isnan(batch_loss):
            raise ValueError('loss=NaN ...')

        total_loss += batch_loss
        progress.set_description("Training loss: %.4f" % (total_loss / step))
        progress.update(1)
    progress.close()

    return {
        "train_loss": total_loss / batch_count,
        "train_time_taken": time.time() - start,
    }
# Script entry point: run training only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()