-
Notifications
You must be signed in to change notification settings - Fork 11
/
train.py
62 lines (52 loc) · 1.83 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import tensorflow as tf
import tensorflow_estimator as tf_estimator
import json
import os
import gpt2_estimator
# Visible GPU device strings (up to two); train_gpt2(num_gpu=...) mirrors
# across a prefix of this list via DEVICE[:num_gpu].
DEVICE = ["/gpu:0", "/gpu:1"]
def train_gpt2(
        model_dir='models/gpt2',
        pretrained_path='models/117M',
        steps=100000,
        batch_size=2,
        num_gpu=1,
        learning_rate=0.0001):
    """Fine-tune a GPT-2 model with a graph-mode tf.estimator loop.

    Warm-starts from a pretrained checkpoint, then trains on the data
    produced by ``gpt2_estimator.train_input_fn``.

    Args:
        model_dir: Directory where new checkpoints and summaries go.
        pretrained_path: Directory holding the pretrained checkpoint and
            its ``hparams.json`` (e.g. the released 117M model).
        steps: Training step budget (passed as ``max_steps``).
        batch_size: Per-replica batch size.
        num_gpu: Number of GPUs to mirror across; selects a prefix of
            the module-level ``DEVICE`` list, so it must be
            <= len(DEVICE).
        learning_rate: Base learning rate; scaled by 1.5**num_gpu below.
    """
    # Estimator API is graph-mode only under TF2.
    tf.compat.v1.disable_eager_execution()
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
    mirrored_strategy = tf.distribute.MirroredStrategy(
        devices=DEVICE[:num_gpu])
    # Heuristic: grow the learning rate with the mirrored (effective)
    # batch size -- 1.5x per GPU.
    learning_rate = learning_rate * 1.5 ** num_gpu
    session_config = tf.compat.v1.ConfigProto(
        allow_soft_placement=True)
    # Grow GPU memory on demand instead of grabbing it all at startup.
    session_config.gpu_options.allow_growth = True
    config = tf_estimator.estimator.RunConfig(
        session_config=session_config,
        train_distribute=mirrored_strategy,
        eval_distribute=mirrored_strategy,
        log_step_count_steps=5)
    gpt2_model_fn = gpt2_estimator.get_gpt2_model_fn(
        accumulate_gradients=5,
        learning_rate=learning_rate,
        length=512,
        batch_size=batch_size,
        temperature=0.7,
        top_k=0
    )
    # Start from defaults, then override with the released model's
    # architecture hyperparameters.
    hparams = gpt2_estimator.default_hparams()
    # FIX: hparams.json is UTF-8; be explicit rather than relying on the
    # platform default encoding (breaks on e.g. Windows cp1252 locales).
    with open(os.path.join(pretrained_path, 'hparams.json'),
              encoding='utf-8') as f:
        hparams.override_from_dict(json.load(f))
    estimator = tf_estimator.estimator.Estimator(
        gpt2_model_fn,
        model_dir=model_dir,
        params=hparams,
        config=config)
    # Restore the pretrained weights before the first training step.
    restore_hook = gpt2_estimator.RestoreCheckpointHook(pretrained_path)
    estimator.train(
        lambda: gpt2_estimator.train_input_fn(batch_size=batch_size),
        max_steps=steps,
        hooks=[restore_hook])
    # keep as an example
    # pred = estimator.predict(
    #     lambda: gpt2_estimator.predict_input_fn(
    #         'i am sick', batch_size=batch_size)
    # )
if __name__ == "__main__":
    # Script entry point: long fine-tuning run (5M steps), all other
    # settings at their defaults (single GPU, batch_size=2).
    train_gpt2(steps=5000000)