forked from google/flax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
422 lines (349 loc) · 14.7 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sequence tagging example.

This script trains a Transformer to predict XPOS tags on the Universal
Dependencies dataset.
"""
import functools
import itertools
import os
import time
from absl import app
from absl import flags
from absl import logging
from flax import jax_utils
from flax import nn
from flax import optim
import input_pipeline
import models
from flax.metrics import tensorboard
from flax.training import common_utils
import jax
from jax import random
import jax.nn
import jax.numpy as jnp
import numpy as np
import tensorflow.compat.v2 as tf
FLAGS = flags.FLAGS

# Where outputs go and how the run is labelled (used to build file names).
flags.DEFINE_string('model_dir', default='', help='Directory for model data.')
flags.DEFINE_string('experiment', default='xpos', help='Experiment name.')

# Training-loop configuration.
flags.DEFINE_integer('batch_size', default=64, help='Batch size for training.')
flags.DEFINE_integer(
    'eval_frequency',
    default=100,
    help='Frequency of eval during training, e.g. every 1000 steps.')
flags.DEFINE_integer(
    'num_train_steps', default=75000, help='Number of train steps.')
flags.DEFINE_integer(
    'num_eval_steps',
    default=-1,
    help='Number of evaluation steps. If -1 use the whole evaluation set.')

# Optimization. `learning_rate` is the base value fed into the learning-rate
# schedule, not the per-step rate actually applied.
flags.DEFINE_float('learning_rate', default=0.05, help='Learning rate.')
flags.DEFINE_float(
    'weight_decay',
    default=1e-1,
    help='Decay factor for AdamW style weight decay.')

# Model / data configuration.
flags.DEFINE_integer(
    'max_length', default=256, help='Maximum length of examples.')
flags.DEFINE_integer(
    'random_seed', default=0, help='Integer for PRNG random seed.')
flags.DEFINE_string('train', default='', help='Path to training data.')
flags.DEFINE_string('dev', default='', help='Path to development data.')
def create_model(key, input_shape, model_kwargs):
  """Builds the Transformer tagger and initializes its parameters.

  Args:
    key: PRNG key used for parameter initialization.
    input_shape: shape (tuple or list) of the dummy input batch used to
      initialize the parameters, e.g. (batch, length).
    model_kwargs: dict of keyword arguments forwarded to `models.Transformer`.

  Returns:
    An `nn.Model` wrapping the freshly initialized parameters.
  """
  module = models.Transformer.partial(train=False, **model_kwargs)
  shape = tuple(input_shape)

  # Only the PRNG key is traced. The module configuration and the input shape
  # are closed over instead of being passed as `static_argnums`, because
  # `jax.jit` requires static arguments to be hashable and `model_kwargs` is
  # a dict.
  @jax.jit
  def _init_params(init_key):
    _, initial_params = module.init_by_shape(init_key, [(shape, jnp.float32)])
    return initial_params

  return nn.Model(module, _init_params(key))
def create_optimizer(model, learning_rate):
  """Builds an Adam optimizer for `model` and replicates it for `pmap`.

  Adam hyperparameters are fixed (beta1=0.9, beta2=0.98, eps=1e-9); the
  weight decay is taken from the --weight_decay flag.

  Args:
    model: the `nn.Model` whose parameters are optimized.
    learning_rate: initial learning rate handed to `optim.Adam`.

  Returns:
    The optimizer after `jax_utils.replicate`.
  """
  adam_def = optim.Adam(
      learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay,
  )
  return jax_utils.replicate(adam_def.create(model))
def create_learning_rate_scheduler(
    factors='constant * linear_warmup * rsqrt_decay',
    base_learning_rate=0.5,
    warmup_steps=8000,
    decay_factor=0.5,
    steps_per_decay=20000,
    steps_per_cycle=100000):
  """Creates a learning rate schedule.

  Interprets factors in the factors string, which can consist of:
  * constant: interpreted as the constant value,
  * linear_warmup: interpreted as linear warmup until warmup_steps
    (no warmup when warmup_steps <= 0),
  * rsqrt_decay: divide by square root of max(step, warmup_steps),
  * rsqrt_normalized_decay: like rsqrt_decay but multiplied by
    sqrt(warmup_steps), so the factor is 1.0 up to warmup_steps and then
    decays as sqrt(warmup_steps / step),
  * decay_every: every steps_per_decay steps decay the learning rate by
    decay_factor,
  * cosine_decay: cyclic cosine decay, uses the steps_per_cycle parameter.

  All listed factors are multiplied together.

  Args:
    factors: a string with factors separated by '*' that defines the schedule.
    base_learning_rate: float, the starting constant for the lr schedule.
    warmup_steps: how many steps to warm up for in the warmup schedule.
    decay_factor: The amount to decay the learning rate by.
    steps_per_decay: How often to decay the learning rate.
    steps_per_cycle: Steps per cycle when using cosine decay.

  Returns:
    A function learning_rate(step) -> float32 scalar array giving the
    step-dependent learning rate.

  Raises:
    ValueError: if `factors` contains an unknown factor name. This is checked
      when the schedule is created rather than on its first call.
  """
  known_factors = frozenset([
      'constant', 'linear_warmup', 'rsqrt_decay', 'rsqrt_normalized_decay',
      'decay_every', 'cosine_decay'
  ])
  factors = [n.strip() for n in factors.split('*')]
  for name in factors:
    if name not in known_factors:
      raise ValueError('Unknown factor %s.' % name)

  def step_fn(step):
    """Step to learning rate function."""
    ret = 1.0
    for name in factors:
      if name == 'constant':
        ret *= base_learning_rate
      elif name == 'linear_warmup':
        # Without this guard warmup_steps=0 gives 0/0 = NaN at step 0.
        if warmup_steps > 0:
          ret *= jnp.minimum(1.0, step / warmup_steps)
      elif name == 'rsqrt_decay':
        # Clamp to >= 1 so warmup_steps=0 does not divide by zero at step 0;
        # a no-op whenever warmup_steps >= 1.
        ret /= jnp.sqrt(jnp.maximum(jnp.maximum(step, warmup_steps), 1))
      elif name == 'rsqrt_normalized_decay':
        ret *= jnp.sqrt(warmup_steps)
        ret /= jnp.sqrt(jnp.maximum(jnp.maximum(step, warmup_steps), 1))
      elif name == 'decay_every':
        ret *= (decay_factor**(step // steps_per_decay))
      elif name == 'cosine_decay':
        progress = jnp.maximum(0.0,
                               (step - warmup_steps) / float(steps_per_cycle))
        ret *= jnp.maximum(0.0,
                           0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
    return jnp.asarray(ret, dtype=jnp.float32)

  return step_fn
def compute_weighted_cross_entropy(logits, targets, weights=None):
  """Sums the (optionally weighted) per-position cross entropy.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    weights: None or array of shape [batch, length].

  Returns:
    Tuple (summed loss, normalizing factor). The factor is `weights.sum()`
    when weights are given, otherwise the sum of the one-hot targets (one
    per position when every target is a valid class id).

  Raises:
    ValueError: if logits does not have exactly one more dim than targets.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  num_classes = logits.shape[-1]
  onehot_targets = common_utils.onehot(targets, num_classes)
  log_probs = nn.log_softmax(logits)
  per_position_loss = -jnp.sum(onehot_targets * log_probs, axis=-1)
  if weights is None:
    return per_position_loss.sum(), onehot_targets.sum()
  return (per_position_loss * weights).sum(), weights.sum()
def compute_weighted_accuracy(logits, targets, weights=None):
  """Computes weighted accuracy for logits and targets.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    weights: None or array of shape [batch, length].

  Returns:
    Tuple (number of correct argmax predictions, weighted by `weights` when
    given; normalizing factor). The factor is `weights.sum()` when weights
    are given, otherwise the number of positions (product of all but the
    last logits dimension).

  Raises:
    ValueError: if logits does not have exactly one more dim than targets.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  correct = jnp.equal(jnp.argmax(logits, axis=-1), targets)
  # The shape is a tuple of static Python ints, so count positions with NumPy:
  # recent jax.numpy versions reject Python tuples as reduction inputs.
  normalizing_factor = np.prod(logits.shape[:-1])
  if weights is not None:
    correct = correct * weights
    normalizing_factor = weights.sum()
  return correct.sum(), normalizing_factor
def compute_metrics(logits, labels, weights):
  """Computes summed loss/accuracy and their shared normalizer for a batch.

  The values are sums, not means. Callers accumulate them across steps and
  devices and divide by the accumulated 'denominator' to get averages.

  Args:
    logits: [batch, length, num_classes] float array.
    labels: categorical targets [batch, length] int array.
    weights: [batch, length] per-position weights (0 for padding positions).

  Returns:
    Dict with 'loss' (summed weighted cross entropy), 'accuracy' (summed
    weighted correct predictions) and 'denominator' (the cross-entropy
    normalizing factor, i.e. the sum of `weights`).
  """
  loss, weight_sum = compute_weighted_cross_entropy(logits, labels, weights)
  acc, _ = compute_weighted_accuracy(logits, labels, weights)
  # Return the dict as-is. It must stay a mutable dict: callers add extra
  # entries such as 'learning_rate' to it.
  return {
      'loss': loss,
      'accuracy': acc,
      'denominator': weight_sum,
  }
def train_step(optimizer, batch, learning_rate_fn, dropout_rng=None):
  """Applies one gradient update; meant to run under `jax.pmap`.

  Gradients are averaged across devices with `jax.lax.pmean` over the
  'batch' axis, so the step must be pmapped with `axis_name='batch'`.

  Args:
    optimizer: the (replicated) flax optimizer whose `target` is the model.
    batch: dict with 'inputs' and 'targets' arrays for this device.
    learning_rate_fn: maps the optimizer step count to a learning rate.
    dropout_rng: PRNG key for dropout; it is split here and a fresh key is
      returned for the next step.

  Returns:
    Tuple (new_optimizer, metrics, new_dropout_rng), where `metrics` is the
    dict from `compute_metrics` plus a 'learning_rate' entry.
  """
  inputs = batch.get('inputs', None)
  targets = batch.get('targets', None)
  # Target id 0 marks padding; only positions with a positive id are weighted.
  weights = jnp.where(targets > 0, 1, 0).astype(jnp.float32)

  # Split the PRNG inside the pmapped step rather than in the Python training
  # loop: splitting outside can stall the transfer of input data to devices.
  dropout_rng, new_dropout_rng = random.split(dropout_rng)

  def loss_fn(model):
    """Mean per-token cross entropy; logits are returned as auxiliary data."""
    with nn.stochastic(dropout_rng):
      logits = model(inputs, train=True)
    total_loss, total_weight = compute_weighted_cross_entropy(
        logits, targets, weights)
    return total_loss / total_weight, logits

  lr = learning_rate_fn(optimizer.state.step)
  value_and_grad = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grad = value_and_grad(optimizer.target)
  grad = jax.lax.pmean(grad, 'batch')
  new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)

  metrics = compute_metrics(logits, targets, weights)
  metrics['learning_rate'] = lr
  return new_optimizer, metrics, new_dropout_rng
def eval_step(model, batch):
  """Runs the model in inference mode on one batch and returns its metrics.

  Args:
    model: the model to evaluate (called with `train=False`).
    batch: dict with 'inputs' and 'targets' arrays.

  Returns:
    The summed-metrics dict produced by `compute_metrics`.
  """
  targets = batch['targets']
  logits = model(batch['inputs'], train=False)
  # Padding positions (target id 0) get zero weight and so are not counted.
  mask = jnp.where(targets > 0, 1.0, 0.0)
  return compute_metrics(logits, targets, mask)
def pad_examples(x, desired_batch_size):
  """Pads `x` with zero rows along its leading (batch) axis.

  Zeros are used so that padded examples end up with zero weight in
  `compute_metrics` (weights are derived from `targets > 0`).

  Args:
    x: array of any rank whose first axis is the batch axis.
    desired_batch_size: batch size to pad up to.

  Returns:
    Array of shape (desired_batch_size,) + x.shape[1:] with x's dtype.

  Raises:
    ValueError: if x already has more than `desired_batch_size` examples.
  """
  batch_pad = desired_batch_size - x.shape[0]
  if batch_pad < 0:
    raise ValueError('Cannot pad batch of size %d down to %d.' %
                     (x.shape[0], desired_batch_size))
  # Build the padding from x's trailing shape so inputs of any rank (and
  # empty batches) are handled, not just 2-D arrays.
  padding = np.zeros((batch_pad,) + tuple(x.shape[1:]), dtype=x.dtype)
  return np.concatenate([x, padding])
def main(argv):
  """Trains the Transformer tagger and periodically evaluates it on dev.

  Every `eval_frequency` steps the accumulated training metrics are
  summarized and logged, the model is evaluated on the dev set, and host 0
  writes both summaries to TensorBoard under `model_dir`.

  Args:
    argv: positional command-line arguments left after flag parsing; only
      the program name is allowed.

  Raises:
    app.UsageError: on extra positional arguments or a missing --train/--dev.
    ValueError: if `batch_size` is not divisible by `jax.device_count()`.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  tf.enable_v2_behavior()
  batch_size = FLAGS.batch_size
  learning_rate = FLAGS.learning_rate
  num_train_steps = FLAGS.num_train_steps
  num_eval_steps = FLAGS.num_eval_steps
  eval_freq = FLAGS.eval_frequency
  max_length = FLAGS.max_length
  random_seed = FLAGS.random_seed
  if not FLAGS.dev:
    raise app.UsageError('Please provide path to dev set.')
  if not FLAGS.train:
    raise app.UsageError('Please provide path to training set.')
  # NOTE(review): parameter_path is not used below; saving the model is still
  # a TODO in the eval section.
  parameter_path = os.path.join(FLAGS.model_dir, FLAGS.experiment + '.params')
  # Only host 0 writes TensorBoard summaries. The writers exist only there, and
  # every later use is guarded by the same host check.
  if jax.host_id() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval'))
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  device_batch_size = batch_size // jax.device_count()
  # create the training and development dataset
  # Vocabularies are built from the training file only and reused for dev.
  vocabs = input_pipeline.create_vocabs(FLAGS.train)
  # Word forms are the model inputs; XPOS tags are the prediction targets.
  attributes_input = [input_pipeline.CoNLLAttributes.FORM]
  attributes_target = [input_pipeline.CoNLLAttributes.XPOS]
  train_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.train,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=max_length)
  # repeat=1: presumably a single pass over the dev set per evaluation
  # (TODO confirm against input_pipeline).
  eval_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.dev,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=max_length,
      repeat=1)
  train_iter = iter(train_ds)
  bs = device_batch_size * jax.device_count()
  rng = random.PRNGKey(random_seed)
  rng, init_rng = random.split(rng)
  # Dummy global input shape used only to initialize the parameters.
  input_shape = (bs, max_length)
  transformer_kwargs = {
      'vocab_size': len(vocabs['forms']),
      'output_vocab_size': len(vocabs['xpos']),
      'emb_dim': 512,
      'num_heads': 8,
      'num_layers': 6,
      'qkv_dim': 512,
      'mlp_dim': 2048,
      'max_len': max_length,
  }
  model = create_model(init_rng, tuple(input_shape), transformer_kwargs)
  # The learning rate given here is overridden on every step: train_step
  # passes learning_rate_fn(step) to optimizer.apply_gradient.
  optimizer = create_optimizer(model, learning_rate)
  del model  # don't keep a copy of the initial model
  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=learning_rate)
  # axis_name='batch' must match the axis used by jax.lax.pmean in train_step.
  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')
  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, jax.local_device_count())
  metrics_all = []
  tick = time.time()
  best_dev_score = 0
  for step, batch in zip(range(num_train_steps), train_iter):
    # Convert TF tensors to NumPy, then split the batch across local devices
    # (assumes common_utils.shard adds a leading device axis — TODO confirm).
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)
    if (step + 1) % eval_freq == 0:
      # Metrics are sums, so add everything up and divide by the summed
      # 'denominator' to get per-token averages.
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
      logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
      if jax.host_id() == 0:
        # tick is reset here, before the evaluation below runs, so the next
        # interval's steps/sec also includes this evaluation's time.
        tock = time.time()
        steps_per_sec = eval_freq / (tock - tick)
        tick = tock
        train_summary_writer.scalar('steps per second', steps_per_sec, step)
        for key, val in summary.items():
          train_summary_writer.scalar(key, val, step)
        train_summary_writer.flush()
      # reset metric accumulation for next evaluation cycle.
      metrics_all = []
      eval_metrics = []
      eval_iter = iter(eval_ds)
      # itertools.repeat(1) is endless, so zip() stops only when eval_iter
      # is exhausted (i.e. the whole dev set is used).
      if num_eval_steps == -1:
        num_iter = itertools.repeat(1)
      else:
        num_iter = range(num_eval_steps)
      for _, eval_batch in zip(num_iter, eval_iter):
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = eval_batch['inputs'].shape[0]
        if cur_pred_batch_size != batch_size:
          logging.info('Uneven batch size %d.', cur_pred_batch_size)
          eval_batch = jax.tree_map(
              lambda x: pad_examples(x, batch_size), eval_batch)
        eval_batch = common_utils.shard(eval_batch)
        metrics = p_eval_step(optimizer.target, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)
      # Calculate (clipped) perplexity after averaging log-perplexities:
      eval_summary['perplexity'] = jnp.clip(
          jnp.exp(eval_summary['loss']), a_max=1.0e4)
      logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step,
                   eval_summary['loss'], eval_summary['accuracy'])
      # Track the best dev accuracy seen so far (reported, not checkpointed).
      if best_dev_score < eval_summary['accuracy']:
        best_dev_score = eval_summary['accuracy']
        # TODO: save model.
      eval_summary['best_dev_score'] = best_dev_score
      logging.info('best development model score %.4f', best_dev_score)
      if jax.host_id() == 0:
        for key, val in eval_summary.items():
          eval_summary_writer.scalar(key, val, step)
        eval_summary_writer.flush()


# Let absl parse the command-line flags before invoking main.
if __name__ == '__main__':
  app.run(main)