This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[Script] Minor enhancement for bert pre-training script #1121

Merged: 4 commits merged on Jan 19, 2020
1 change: 1 addition & 0 deletions scripts/bert/fp16_utils.py
@@ -166,6 +166,7 @@ def step(self, batch_size, max_norm=None):
else:
overflow = is_finite.asscalar() < 1
if not overflow:
+ step_size = step_size.asscalar()
self.fp32_trainer.update(step_size)
else:
# TODO(haibin) optimize the performance when max_norm is not present
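For context, the added line converts the loss-scaled step size from a one-element NDArray into a plain Python number before it is handed to the wrapped FP32 trainer. A minimal sketch of what .asscalar() does, assuming a standard MXNet NDArray (the value below is illustrative only):

import mxnet as mx

# Illustrative value only: a one-element NDArray standing in for the scaled step size.
step_size = mx.nd.array([64.0])
print(type(step_size))             # <class 'mxnet.ndarray.ndarray.NDArray'>

# .asscalar() copies the single value to the host and returns a plain numpy scalar,
# so downstream code that expects a regular number no longer receives an NDArray.
step_size = step_size.asscalar()
print(type(step_size), step_size)  # <class 'numpy.float32'> 64.0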
22 changes: 12 additions & 10 deletions scripts/bert/run_pretraining.py
@@ -79,6 +79,8 @@
help='Number of batches for gradient accumulation. '
'total_batch_size = batch_size_per_worker * num_worker * accumulate.')
parser.add_argument('--num_steps', type=int, default=20, help='Number of optimization steps')
+ parser.add_argument('--optimizer', type=str, default='bertadam',
+ help='The optimization algorithm')
parser.add_argument('--start_step', type=int, default=0,
help='Start optimization step from the checkpoint.')
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
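As a quick illustration of the new flag (a self-contained sketch, not the script's full parser), the argument keeps 'bertadam' as the default and lets another algorithm such as 'lamb' be selected on the command line:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--optimizer', type=str, default='bertadam',
                    help='The optimization algorithm')

# Default behaviour is unchanged; passing the flag switches the algorithm.
print(parser.parse_args([]).optimizer)                       # bertadam
print(parser.parse_args(['--optimizer', 'lamb']).optimizer)  # lamb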
@@ -227,7 +229,7 @@ def train(data_train, data_eval, model):
mlm_metric.reset()
nsp_metric.reset()

- logging.debug('Creating distributed trainer...')
+ logging.info('Creating distributed trainer...')
lr = args.lr
optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
if args.dtype == 'float16':
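The debug-to-info promotions matter because, under a typical INFO-level logging setup, debug messages are filtered out while info messages reach the training log. A minimal sketch (assuming the root logger is configured at INFO, as basicConfig does here):

import logging

logging.basicConfig(level=logging.INFO)

logging.debug('Creating distributed trainer...')  # filtered out at INFO level
logging.info('Creating distributed trainer...')   # emitted to the log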
@@ -241,9 +243,9 @@ def train(data_train, data_eval, model):

# backend specific implementation
if backend == 'horovod':
- trainer = hvd.DistributedTrainer(param_dict, 'bertadam', optim_params)
+ trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optim_params)
else:
- trainer = mx.gluon.Trainer(param_dict, 'bertadam', optim_params,
+ trainer = mx.gluon.Trainer(param_dict, args.optimizer, optim_params,
update_on_kvstore=False)
fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
loss_scaler_params=loss_scale_param)
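A minimal, self-contained sketch of the idea behind this hunk: the optimizer name now flows from the parsed arguments into the trainer instead of being hard-coded. The tiny Dense model, the SimpleNamespace stand-in for args, and the 'adam' value below are illustrative assumptions; in the script the value is args.optimizer ('bertadam' by default, with 'lamb' also exercised by the test).

import mxnet as mx
from types import SimpleNamespace

# Stand-ins for the script's parsed arguments and model parameters.
args = SimpleNamespace(optimizer='adam', lr=1e-4)
net = mx.gluon.nn.Dense(2)
net.initialize()

optim_params = {'learning_rate': args.lr, 'epsilon': 1e-6, 'wd': 0.01}
# Any optimizer name registered with MXNet can now be selected from the command line.
trainer = mx.gluon.Trainer(net.collect_params(), args.optimizer, optim_params,
                           update_on_kvstore=False)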
@@ -273,7 +275,7 @@ def train(data_train, data_eval, model):
batch_num = 0
step_num = args.start_step

- logging.debug('Training started')
+ logging.info('Training started')

# create dummy data loader if needed
parallel_model = DataParallelBERT(model, trainer=fp16_trainer)
@@ -362,11 +364,11 @@ def train(data_train, data_eval, model):
save_states(step_num, trainer, args.ckpt_dir, local_rank)
if local_rank == 0:
save_parameters(step_num, model.bert, args.ckpt_dir)
- if (step_num + 1) % args.eval_interval == 0 and data_eval:
- # eval data is always based on a fixed npz file.
- dataset_eval = get_pretrain_data_npz(data_eval, batch_size_eval,
- 1, False, 1, vocab)
- evaluate(dataset_eval, model, ctxs, args.log_interval, args.dtype)
+ if (step_num + 1) % args.eval_interval == 0 and data_eval:
+ # eval data is always based on a fixed npz file.
+ dataset_eval = get_pretrain_data_npz(data_eval, batch_size_eval,
+ 1, False, 1, vocab)
+ evaluate(dataset_eval, model, ctxs, args.log_interval, args.dtype)

batch_num += 1

@@ -395,7 +397,7 @@ def train(data_train, data_eval, model):
dataset_name, vocab, args.dtype,
ckpt_dir=args.ckpt_dir,
start_step=args.start_step)
- logging.debug('Model created')
+ logging.info('Model created')
data_eval = args.data_eval

if args.raw:
6 changes: 4 additions & 2 deletions scripts/tests/test_scripts.py
@@ -231,7 +231,8 @@ def test_bert_embedding(use_pretrained):
@pytest.mark.remote_required
@pytest.mark.integration
@pytest.mark.parametrize('backend', ['horovod', 'device'])
- def test_bert_pretrain(backend):
+ @pytest.mark.parametrize('optimizer', ['bertadam', 'lamb'])
+ def test_bert_pretrain(backend, optimizer):
# test data creation
process = subprocess.check_call([sys.executable, './scripts/bert/create_pretraining_data.py',
'--input_file', './scripts/bert/sample_text.txt',
@@ -249,7 +250,8 @@ def test_bert_pretrain(backend):
'--ckpt_dir', './test/bert/ckpt',
'--num_steps', '20', '--num_buckets', '1',
'--pretrained',
- '--comm_backend', backend]
+ '--comm_backend', backend,
+ '--optimizer', optimizer]

if backend == 'device':
arguments += ['--gpus', '0']
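For reference, stacking a second parametrize decorator makes pytest run the test for every (backend, optimizer) combination, so both 'bertadam' and 'lamb' are exercised on each backend. A small self-contained sketch of that expansion (a toy test, not the integration test itself):

import pytest

@pytest.mark.parametrize('backend', ['horovod', 'device'])
@pytest.mark.parametrize('optimizer', ['bertadam', 'lamb'])
def test_parametrize_cross_product(backend, optimizer):
    # pytest generates four test cases: the cross product of the two lists.
    assert backend in ('horovod', 'device')
    assert optimizer in ('bertadam', 'lamb')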