-
Notifications
You must be signed in to change notification settings - Fork 4
/
pretrain.py
132 lines (112 loc) · 5.09 KB
/
pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Copyright 2022 Huawei Technologies Co., Ltd
# Copyright 2022 Aerospace Information Research Institute,
# Chinese Academy of Sciences.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""pretrain of ringmo"""
import os
import argparse
import aicc_tools as ac
from mindspore.train.model import Model
from ringmo_framework.lr import build_lr
from ringmo_framework.arch import build_model
from ringmo_framework.optim import build_optim
from ringmo_framework.datasets import build_dataset
from ringmo_framework.trainer import build_wrapper
from ringmo_framework.parallel_config import build_parallel_config
from ringmo_framework.tools.helper import count_params
from ringmo_framework.monitors.callback import build_pretrain_callback
from ringmo_framework.tools.helper import build_context, str2bool
from ringmo_framework.tools.load_ckpt import load_ckpt
from register.config import RingMoConfig, ActionDict
@ac.aicc_monitor
def main(args):
    """Run RingMo pretraining from a fully merged config object.

    In order: build the runtime context, the pretraining dataset, the
    parallel config, the network, the LR schedule, the optimizer and the
    training wrapper; optionally resume from a checkpoint via ``load_ckpt``;
    then call ``Model.train``.

    Args:
        args: Config object (a ``RingMoConfig`` when launched from this
            script). It is read AND mutated in place:
            ``pretrain_dataset.data_path`` is replaced with the value returned
            by ``cfts.get_dataset``, ``data_size`` is set to the dataset size,
            and ``train_config.per_epoch_size`` is overwritten with the
            dataset size unless both ``per_epoch_size`` and ``sink_mode`` are
            truthy.
    """
    # init context; profile_cb is only attached to the callbacks when args.profile is set
    cfts, profile_cb = build_context(args)

    # build dataset
    args.logger.info(".........Build Dataset..........")
    args.pretrain_dataset.data_path = cfts.get_dataset(args.pretrain_dataset.data_path)
    dataset = build_dataset(args)
    data_size = dataset.get_dataset_size()
    new_epochs = args.train_config.epoch
    if args.train_config.per_epoch_size and args.train_config.sink_mode:
        # Model.train below receives sink_size=per_epoch_size, so rescale the
        # epoch count to keep the total step count (data_size * epoch) fixed.
        # Integer floor division avoids float rounding such as
        # int(9.999999999999998) == 9 silently dropping a whole sink.
        new_epochs = int(data_size * new_epochs // args.train_config.per_epoch_size)
    else:
        args.train_config.per_epoch_size = data_size
    args.data_size = data_size
    args.logger.info("Will be Training epochs:{}, sink_size:{}".format(
        new_epochs, args.train_config.per_epoch_size))
    args.logger.info("Create training dataset finish, data size:{}".format(data_size))

    # build context config
    args.logger.info(".........Build context config..........")
    build_parallel_config(args)
    args.logger.info("context config is:{}".format(args.parallel_config))
    args.logger.info("moe config is:{}".format(args.moe_config))

    # build net
    args.logger.info(".........Build Net..........")
    net = build_model(args)
    # Logs the parameter count of the bare network (message text is Chinese:
    # "network parameter count").
    args.logger.info("网络参数量:{} M.".format(count_params(net)))

    # build lr
    args.logger.info(".........Build LR Schedule..........")
    lr_schedule = build_lr(args)

    # define optimizer
    args.logger.info(".........Build Optimizer..........")
    optimizer = build_optim(args, net, lr_schedule, args.logger)

    # define model
    args.logger.info(".........Build Train Model..........")
    train_model = build_wrapper(args, net, optimizer, log=args.logger)
    # Logs the parameter count of the training wrapper (message text is
    # Chinese: "model parameter count").
    args.logger.info("模型参数量:{} M.".format(count_params(train_model)))

    # define Model and begin training
    args.logger.info(".........Starting Init Train Model..........")
    model = Model(train_model)

    # resume ckpt
    load_ckpt(args, cfts, net, model, train_model, dataset, new_epochs)

    # define callback
    callback = build_pretrain_callback(args, cfts)
    if args.profile:
        callback.append(profile_cb)

    args.logger.info(".........Starting Training Model..........")
    model.train(new_epochs, dataset, callbacks=callback,
                dataset_sink_mode=args.train_config.sink_mode,
                sink_size=args.train_config.per_epoch_size)
if __name__ == "__main__":
    work_path = os.path.dirname(os.path.abspath(__file__))
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        # NOTE(review): "config path" looks like a placeholder file name;
        # in practice --config should be passed explicitly.
        default=os.path.join(work_path, "config path"),
        help='YAML config files')
    parser.add_argument('--device_id', default=None, type=int, help='device id')
    parser.add_argument('--seed', default=None, type=int, help='random seed')
    parser.add_argument('--use_parallel', default=None, type=str2bool, help='whether use parallel mode')
    parser.add_argument('--profile', default=None, type=str2bool, help='whether use profile analysis')
    parser.add_argument(
        '--options',
        nargs='+',
        action=ActionDict,
        # The two adjacent literals are concatenated, so the first must end
        # with a space (previously rendered as "pairin").
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file')
    args_ = parser.parse_args()

    config = RingMoConfig(args_.config)
    # Each CLI flag defaults to None, meaning "keep the YAML value"; only
    # explicitly passed flags override the loaded config.
    if args_.device_id is not None:
        config.context.device_id = args_.device_id
    if args_.seed is not None:
        config.seed = args_.seed
    if args_.use_parallel is not None:
        config.use_parallel = args_.use_parallel
    if args_.profile is not None:
        config.profile = args_.profile
    # --options is applied last, so its key=value pairs win over the flags above.
    if args_.options is not None:
        config.merge_from_dict(args_.options)
    main(config)