refactor training loop (#2336)
* refactoring training epoch

* refactored training epoch

* refactored training epoch

* refactored training epoch

* refactored training epoch

* refactored training epoch

* fixes slurm weights saving

* fixes slurm weights saving
williamFalcon committed Jun 24, 2020
1 parent c09b2ff commit 598f514
Showing 2 changed files with 131 additions and 93 deletions.
188 changes: 115 additions & 73 deletions pytorch_lightning/trainer/training_loop.py
@@ -415,11 +415,16 @@ def train(self):

self.run_training_teardown()

def run_training_epoch(self):
def prepare_train_loop_dataloader(self, train_dataloader):
# on TPU we have to wrap it under the ParallelLoader
if self.use_tpu:
device = xm.xla_device(self.tpu_id)
train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
train_dataloader = train_dataloader.per_device_loader(device)

# get model
model = self.get_model()
return train_dataloader

def run_on_epoch_start_hook(self, model):
# Epoch start events
with self.profiler.profile('on_epoch_start'):
# callbacks
@@ -429,17 +434,19 @@ def run_training_epoch(self):
if self.is_function_implemented('on_epoch_start'):
model.on_epoch_start()

# track local dataloader so TPU can wrap each epoch
train_dataloader = self.train_dataloader
def run_training_epoch(self):

# on TPU we have to wrap it under the ParallelLoader
if self.use_tpu:
device = xm.xla_device(self.tpu_id)
train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
train_dataloader = train_dataloader.per_device_loader(device)
# get model
model = self.get_model()

# Epoch start events
self.run_on_epoch_start_hook(model)

# modify dataloader if needed (ddp, etc...)
train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader)

# bookkeeping
outputs = []
epoch_output = []

# run epoch
for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
@@ -450,63 +457,41 @@ def run_training_epoch(self):
break

self.batch_idx = batch_idx

model.global_step = self.global_step

# ---------------
# RUN TRAIN STEP
# ---------------
_outputs = self.run_training_batch(batch, batch_idx)
batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs
# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
# ------------------------------------
batch_output = self.run_training_batch(batch, batch_idx)

# only track outputs when user implements training_epoch_end
# otherwise we will build up unnecessary memory
if self.is_overridden('training_epoch_end', model=self.get_model()):
outputs.append(batch_output)
epoch_output.append(batch_output.training_step_output_for_epoch_end)

# update LR schedulers
self.update_train_loop_lr_schedulers()

# when returning -1 from train_step, we end epoch early
early_stop_epoch = batch_result == -1

# TODO: consolidate all actions that need to take place only after
# self.accumulate_grad_batches steps (optimizer step, lr update, global step increment)
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
# update lr
self.update_learning_rates(interval='step')

# ---------------
# RUN VAL STEP
# ---------------
is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
can_check_val = not self.disable_validation and can_check_epoch
should_check_val = is_val_check_batch or early_stop_epoch
should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
should_check_val = can_check_val and should_check_val

# ---------------
# CHECKPOINTING, EARLY STOPPING
# ---------------
# fast_dev_run always forces val checking after train batch
if self.fast_dev_run or should_check_val:
self.run_evaluation(test_mode=self.testing)
self.call_checkpoint_callback()

# when logs should be saved
should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
if should_save_log or self.fast_dev_run:
if self.is_global_zero and self.logger is not None:
self.logger.save()

# when metrics should be logged
should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
if should_log_metrics or self.fast_dev_run:
# logs user requested information to logger
self.log_metrics(batch_step_metrics, grad_norm_dic)
early_stop_epoch = batch_output.signal == -1

# -----------------------------------------
# VALIDATE IF NEEDED + CHECKPOINT CALLBACK
# -----------------------------------------
should_check_val = self.check_validation_in_train_loop(batch_idx, early_stop_epoch, is_last_batch)

# -----------------------------------------
# SAVE LOGGERS (ie: Tensorboard, etc...)
# -----------------------------------------
self.save_loggers_in_training_loop(batch_idx, early_stop_epoch)

# -----------------------------------------
# SAVE METRICS TO LOGGERS
# -----------------------------------------
self.save_train_loop_metrics_to_loggers(batch_idx, early_stop_epoch, batch_output)

# progress global step according to grads progress
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
self.global_step += 1
self.total_batch_idx += 1
self.increment_accumulated_grad_global_step()

# max steps reached, end training
if self.max_steps is not None and self.max_steps == self.global_step:
@@ -518,13 +503,36 @@ def run_training_epoch(self):
if early_stop_epoch or self.fast_dev_run:
break

if self.use_horovod:
hvd.join(hvd.local_rank() if self.on_gpu else -1)
# let ddp devices catch up when using horovod
self.sync_horovod()

# process epoch outputs
self.run_training_epoch_end(epoch_output)

# when no val loop is present or fast-dev-run still need to call checkpoints
if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
self.call_checkpoint_callback()

# epoch end hook
self.run_on_epoch_end_hook(model)

def update_train_loop_lr_schedulers(self):
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
# update lr
self.update_learning_rates(interval='step')

def run_on_epoch_end_hook(self, model):
with self.profiler.profile('on_epoch_end'):
# callbacks
self.on_epoch_end()
# model hooks
if self.is_function_implemented('on_epoch_end'):
model.on_epoch_end()

def run_training_epoch_end(self, epoch_output):
model = self.get_model()
if self.is_overridden('training_epoch_end', model=model):
epoch_output = model.training_epoch_end(outputs)
epoch_output = model.training_epoch_end(epoch_output)
_processed_outputs = self.process_output(epoch_output)
log_epoch_metrics = _processed_outputs[2]
callback_epoch_metrics = _processed_outputs[3]
@@ -538,17 +546,45 @@ def run_training_epoch(self):
# add metrics to progress_bar
self.add_progress_bar_metrics(_processed_outputs[1])

# when no val loop is present or fast-dev-run still need to call checkpoints
if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
def sync_horovod(self):
if self.use_horovod:
hvd.join(hvd.local_rank() if self.on_gpu else -1)

def increment_accumulated_grad_global_step(self):
# progress global step according to grads progress
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
self.global_step += 1
self.total_batch_idx += 1

def save_train_loop_metrics_to_loggers(self, batch_idx, early_stop_epoch, batch_output):
# when metrics should be logged
should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
if should_log_metrics or self.fast_dev_run:
# logs user requested information to logger
self.log_metrics(batch_output.batch_log_metrics, batch_output.grad_norm_dic)

def save_loggers_in_training_loop(self, batch_idx, early_stop_epoch):
# when loggers should save to disk
should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
if should_save_log or self.fast_dev_run:
if self.is_global_zero and self.logger is not None:
self.logger.save()

def check_validation_in_train_loop(self, batch_idx, early_stop_epoch, is_last_batch):
# decide if we should run validation
is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
can_check_val = not self.disable_validation and can_check_epoch
should_check_val = is_val_check_batch or early_stop_epoch
should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
should_check_val = can_check_val and should_check_val

# if we need to run validation, then also call the checkpoint callback
if self.fast_dev_run or should_check_val:
self.run_evaluation(test_mode=self.testing)
self.call_checkpoint_callback()

# Epoch end events
with self.profiler.profile('on_epoch_end'):
# callbacks
self.on_epoch_end()
# model hooks
if self.is_function_implemented('on_epoch_end'):
model.on_epoch_end()
return should_check_val

def run_training_batch(self, batch, batch_idx):
# track grad norms
@@ -561,7 +597,7 @@ def run_training_batch(self, batch, batch_idx):
batch_log_metrics = []

if batch is None:
return 0, grad_norm_dic, {}, {}
return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)

# Batch start events
with self.profiler.profile('on_batch_start'):
@@ -571,7 +607,7 @@
if self.is_function_implemented('on_batch_start'):
response = self.get_model().on_batch_start(batch)
if response == -1:
return -1, grad_norm_dic, {}, {}
return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

splits = [batch]
if self.truncated_bptt_steps is not None:
@@ -650,7 +686,13 @@ def run_training_batch(self, batch, batch_idx):
# track all metrics for callbacks
self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()})

return 0, grad_norm_dic, batch_log_metrics, opt_closure_result.training_step_output_for_epoch_end
result = AttributeDict(
signal=0,
grad_norm_dic=grad_norm_dic,
batch_log_metrics=batch_log_metrics,
training_step_output_for_epoch_end=opt_closure_result.training_step_output_for_epoch_end
)
return result

def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
# ------------------
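
Among the extracted helpers, check_validation_in_train_loop concentrates the conditions that decide when validation runs mid-epoch. Below is a standalone restatement of that trigger logic so it can be exercised with plain numbers; the function name should_run_validation and its free-function, keyword-argument form are illustrative only and not part of Lightning's API.

# Illustrative restatement of the validation-trigger logic moved into
# check_validation_in_train_loop. Variable names mirror the diff above;
# the free-function form is only for demonstration.
def should_run_validation(batch_idx, early_stop_epoch, is_last_batch,
                          val_check_batch, current_epoch,
                          check_val_every_n_epoch, disable_validation):
    is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
    can_check_epoch = (current_epoch + 1) % check_val_every_n_epoch == 0
    can_check_val = not disable_validation and can_check_epoch
    should_check_val = is_val_check_batch or early_stop_epoch
    should_check_val = should_check_val or (is_last_batch and val_check_batch == float('inf'))
    return can_check_val and should_check_val

# With val_check_batch=50, the 50th batch of an eligible epoch triggers validation ...
assert should_run_validation(49, False, False, 50, 0, 1, False)
# ... while an intermediate batch does not.
assert not should_run_validation(10, False, False, 50, 0, 1, False)
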
36 changes: 16 additions & 20 deletions tests/trainer/test_trainer_steps.py
@@ -23,12 +23,11 @@ def test_trainingstep_dict(tmpdir):
break

out = trainer.run_training_batch(batch, batch_idx)
signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
assert signal == 0
assert all_log_metrics['log_acc1'] == 12.0
assert all_log_metrics['log_acc2'] == 7.0
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0

pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0

@@ -55,12 +54,11 @@ def training_step_with_step_end(tmpdir):
break

out = trainer.run_training_batch(batch, batch_idx)
signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
assert signal == 0
assert all_log_metrics['log_acc1'] == 12.0
assert all_log_metrics['log_acc2'] == 7.0
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0

pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0

@@ -92,12 +90,11 @@ def test_full_training_loop_dict(tmpdir):
break

out = trainer.run_training_batch(batch, batch_idx)
signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
assert signal == 0
assert all_log_metrics['log_acc1'] == 12.0
assert all_log_metrics['log_acc2'] == 7.0
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0

pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0

@@ -129,11 +126,10 @@ def test_train_step_epoch_end(tmpdir):
break

out = trainer.run_training_batch(batch, batch_idx)
signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
assert signal == 0
assert all_log_metrics['log_acc1'] == 12.0
assert all_log_metrics['log_acc2'] == 7.0
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0

pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0
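
The updated assertions above read fields off the returned object by attribute (out.signal, out.batch_log_metrics) instead of unpacking a positional tuple. Here is a minimal, self-contained sketch of that dict-with-attribute-access pattern, with hypothetical values mirrored from the tests; Lightning ships its own AttributeDict, and the local class below exists only so the snippet runs without the library.

class AttributeDict(dict):
    # Local stand-in for Lightning's AttributeDict: a plain dict whose keys
    # are also readable and writable as attributes.
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as err:
            raise AttributeError(key) from err

    def __setattr__(self, key, value):
        self[key] = value

# Hypothetical batch output shaped like the one run_training_batch now returns.
out = AttributeDict(
    signal=0,
    grad_norm_dic={},
    batch_log_metrics={'log_acc1': 12.0, 'log_acc2': 7.0},
    training_step_output_for_epoch_end={'pbar_on_batch_end': {'pbar_acc1': 17.0}},
)

assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
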
