-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_decoder.py
80 lines (61 loc) · 2.05 KB
/
train_decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import random as rn
from argparse import ArgumentParser
from os import mkdir
from os.path import exists
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping, ReduceLROnPlateau
# Seed the random number generators used by this script (Python's `random`,
# NumPy and TensorFlow) with one fixed value. This runs before the
# project-local `impl` modules are imported further down.
seed = 0
# NOTE(review): PYTHONHASHSEED is normally read by the interpreter at startup,
# so assigning it here presumably only affects child processes — confirm
# whether it should be exported before launching the script instead.
os.environ["PYTHONHASHSEED"] = "0"
np.random.seed(seed)
rn.seed(seed)
tf.random.set_seed(seed)
from impl.architectures import simple_decoder
from impl.datagen import decoder_gen
def main():
    """Train the marker decoding model and save it under ``./models``.

    Command-line arguments:
        source_dir: directory that contains ``train.csv`` and ``valid.csv``;
            it is also passed to ``decoder_gen`` together with each table.
        run_name: name used for the loss log
            (``./models/losses/loss_<run_name>.csv``) and for the saved model
            (``./models/<run_name>.h5``).
    """
    parser = ArgumentParser(description="DeepArUco v2 marker decoding model trainer.")
    parser.add_argument("source_dir", help="where to find source images")
    parser.add_argument("run_name", help="where to store the resulting model")
    args = parser.parse_args()

    # Control parameters
    batch_size = 32
    epochs = 1000  # upper bound; the EarlyStopping callback may stop sooner
    patience = 20  # epochs without val_loss improvement before stopping
    reduce_after = 10  # epochs without val_loss improvement before halving LR

    # Model
    model = simple_decoder()
    model.compile(loss="mae", optimizer="adam")

    # Load dataset
    train_df = pd.read_csv(os.path.join(args.source_dir, "train.csv"))
    train_generator = decoder_gen(train_df, args.source_dir, batch_size, True, True)

    valid_df = pd.read_csv(os.path.join(args.source_dir, "valid.csv"))
    # NOTE(review): the validation generator receives the same two boolean
    # flags as the training one; their meaning is defined in impl.datagen —
    # confirm that identical settings are intended for validation.
    valid_generator = decoder_gen(valid_df, args.source_dir, batch_size, True, True)

    # Callbacks
    stop = EarlyStopping(
        monitor="val_loss",
        patience=patience,
        verbose=1,
        restore_best_weights=True,
        min_delta=1e-4,
    )
    reduce_lr = ReduceLROnPlateau(monitor="val_loss", patience=reduce_after, factor=0.5)

    # Training
    # Creates ./models and ./models/losses in one call and tolerates
    # directories that already exist (no check-then-create race).
    os.makedirs("./models/losses", exist_ok=True)
    csv_logger = CSVLogger(f"./models/losses/loss_{args.run_name}.csv")

    # The generators were already given batch_size above, so it is not passed
    # to fit() again: Keras documents that batch_size must not be specified
    # for generator / Sequence / dataset inputs.
    model.fit(
        train_generator,
        epochs=epochs,
        validation_data=valid_generator,
        callbacks=[stop, reduce_lr, csv_logger],
        verbose=1,
    )

    model.save(f"./models/{args.run_name}.h5")


if __name__ == "__main__":
    main()