-
Notifications
You must be signed in to change notification settings - Fork 16
/
train.py
54 lines (46 loc) · 1.51 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
'''
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
'''
import sys, shutil, os
from trainers.trainer_unet import train
from trainers.train_framework import TrainFramework
from models.unet import UnetModel
from models.losses import BCELoss, DiceBCELoss, DiceLoss, WeightedBCELoss
from options.train_options import TrainOptions
from data.dataloader import load_dataset
# Read the training configuration from the command line.
parser = TrainOptions()
opts = parser.parse()
# Echo the raw command line (not the parsed options) to help debugging
# and reproducing this run.
invocation = ' '.join(sys.argv)
print(invocation)
# Select the model class from --model. Only the U-Net is wired up here;
# the class itself (not an instance) is bound to `model`.
if opts.model != "unet":
    # Put the reason in the exception itself so it shows up in the
    # traceback (stderr) rather than only on stdout.
    raise NotImplementedError(
        "Option {} not supported. Available options: unet".format(
            opts.model))
model = UnetModel
# Map each supported --loss value to its loss class. The class itself
# (not an instance) is bound to `loss`.
LOSS_CHOICES = {
    "bce": BCELoss,
    "wbce": WeightedBCELoss,
    "dice": DiceLoss,
    "dice_bce": DiceBCELoss,
}
if opts.loss not in LOSS_CHOICES:
    # Build the list of valid options from the table so the error message
    # can never advertise a loss that is not actually implemented.
    raise NotImplementedError(
        "Option {} not supported. Available options: {}".format(
            opts.loss, ", ".join(LOSS_CHOICES)))
loss = LOSS_CHOICES[opts.loss]
# Instantiate the selected network (configured from opts) and a fresh loss
# object, then hand both plus the options to the training framework.
network = model(opts)
criterion = loss()
frame = TrainFramework(network, criterion, opts)
# Output directory for this experiment's training artifacts.
training_dir = os.path.join(opts.save_dir, opts.experiment_name, "training")
if opts.overwrite and os.path.isdir(training_dir):
    # --overwrite discards a previous run's outputs. The isdir guard keeps a
    # first run with --overwrite from crashing on a missing directory.
    shutil.rmtree(training_dir)
# No exist_ok: without --overwrite, an existing directory raises
# FileExistsError instead of silently mixing old and new results.
os.makedirs(training_dir)
# Dispatch to the trainer for the chosen model. Check support first so an
# unsupported model fails before the dataset is loaded.
if opts.model != "unet":
    raise NotImplementedError(
        "Model {} not supported. Available options: unet".format(opts.model))
dataloaders = load_dataset(opts)
# train() returns a 3-tuple; the first element is not used by this script.
_, train_history, val_history = train(frame, dataloaders, opts)