-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
28 lines (22 loc) · 912 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import argparse
import numpy as np
import mmcv
from CloudMatting.models.builder import builder_models
from CloudMatting.utils import utils
# Command-line interface for the training script.
parse = argparse.ArgumentParser()

# Config file handed to mmcv.Config.fromfile in the entry point below.
# An alternative config that was previously selected here:
#   configs/vr_resnet50_inapinting_agr_cfg.py
parse.add_argument(
    "--config_file",
    type=str,
    default="configs/cloud_LSGAN_resnet50_cfg.py",
)

# Both paths are optional and default to None when not supplied.
parse.add_argument("--checkpoints_path", type=str, default=None)
parse.add_argument("--log_path", type=str, default=None)

# Boolean switch: False unless the flag is present on the command line.
parse.add_argument("--with_matting", action="store_true")
if __name__ == '__main__':
    # Parse the command-line options and echo them so the settings used for
    # this run are visible in the console output.
    args = parse.parse_args()
    print(args)

    # Load the config file; its 'config' entry is expanded into keyword
    # arguments for the model builder, alongside the --with_matting switch.
    cfg = mmcv.Config.fromfile(args.config_file)
    models = builder_models(**cfg['config'], with_matting=args.with_matting)

    # NOTE: the CLI flag is --checkpoints_path (plural) while the keyword
    # passed to run_train_interface is checkpoint_path (singular).
    models.run_train_interface(
        checkpoint_path=args.checkpoints_path,
        log_path=args.log_path,
    )