diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 2e30f8043235a..c113c04c15ab2 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -415,11 +415,16 @@ def train(self):
 
         self.run_training_teardown()
 
-    def run_training_epoch(self):
+    def prepare_train_loop_dataloader(self, train_dataloader):
+        # on TPU we have to wrap it under the ParallelLoader
+        if self.use_tpu:
+            device = xm.xla_device(self.tpu_id)
+            train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
+            train_dataloader = train_dataloader.per_device_loader(device)
 
-        # get model
-        model = self.get_model()
+        return train_dataloader
 
+    def run_on_epoch_start_hook(self, model):
         # Epoch start events
         with self.profiler.profile('on_epoch_start'):
             # callbacks
@@ -429,17 +434,19 @@ def run_training_epoch(self):
             if self.is_function_implemented('on_epoch_start'):
                 model.on_epoch_start()
 
-        # track local dataloader so TPU can wrap each epoch
-        train_dataloader = self.train_dataloader
+    def run_training_epoch(self):
 
-        # on TPU we have to wrap it under the ParallelLoader
-        if self.use_tpu:
-            device = xm.xla_device(self.tpu_id)
-            train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
-            train_dataloader = train_dataloader.per_device_loader(device)
+        # get model
+        model = self.get_model()
+
+        # Epoch start events
+        self.run_on_epoch_start_hook(model)
+
+        # modify dataloader if needed (ddp, etc...)
+        train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader)
 
         # bookkeeping
-        outputs = []
+        epoch_output = []
 
         # run epoch
         for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
@@ -450,63 +457,41 @@ def run_training_epoch(self):
                 break
 
             self.batch_idx = batch_idx
-
             model.global_step = self.global_step
 
-            # ---------------
-            # RUN TRAIN STEP
-            # ---------------
-            _outputs = self.run_training_batch(batch, batch_idx)
-            batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs
+            # ------------------------------------
+            # TRAINING_STEP + TRAINING_STEP_END
+            # ------------------------------------
+            batch_output = self.run_training_batch(batch, batch_idx)
 
             # only track outputs when user implements training_epoch_end
             # otherwise we will build up unnecessary memory
             if self.is_overridden('training_epoch_end', model=self.get_model()):
-                outputs.append(batch_output)
+                epoch_output.append(batch_output.training_step_output_for_epoch_end)
+
+            # update LR schedulers
+            self.update_train_loop_lr_schedulers()
 
             # when returning -1 from train_step, we end epoch early
-            early_stop_epoch = batch_result == -1
-
-            # TODO: consolidate all actions that need to take place only after
-            # self.accumulate_grad_batches steps (optimizer step, lr update, global step increment)
-            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
-                # update lr
-                self.update_learning_rates(interval='step')
-
-            # ---------------
-            # RUN VAL STEP
-            # ---------------
-            is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
-            can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-            can_check_val = not self.disable_validation and can_check_epoch
-            should_check_val = is_val_check_batch or early_stop_epoch
-            should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
-            should_check_val = can_check_val and should_check_val
-
-            # ---------------
-            # CHECKPOINTING, EARLY STOPPING
-            # ---------------
-            # fast_dev_run always forces val checking after train batch
-            if self.fast_dev_run or should_check_val:
-                self.run_evaluation(test_mode=self.testing)
-                self.call_checkpoint_callback()
-
-            # when logs should be saved
-            should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
-            if should_save_log or self.fast_dev_run:
-                if self.is_global_zero and self.logger is not None:
-                    self.logger.save()
-
-            # when metrics should be logged
-            should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
-            if should_log_metrics or self.fast_dev_run:
-                # logs user requested information to logger
-                self.log_metrics(batch_step_metrics, grad_norm_dic)
+            early_stop_epoch = batch_output.signal == -1
+
+            # -----------------------------------------
+            # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
+            # -----------------------------------------
+            should_check_val = self.check_validation_in_train_loop(batch_idx, early_stop_epoch, is_last_batch)
+
+            # -----------------------------------------
+            # SAVE LOGGERS (ie: Tensorboard, etc...)
+            # -----------------------------------------
+            self.save_loggers_in_training_loop(batch_idx, early_stop_epoch)
+
+            # -----------------------------------------
+            # SAVE METRICS TO LOGGERS
+            # -----------------------------------------
+            self.save_train_loop_metrics_to_loggers(batch_idx, early_stop_epoch, batch_output)
 
             # progress global step according to grads progress
-            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
-                self.global_step += 1
-            self.total_batch_idx += 1
+            self.increment_accumulated_grad_global_step()
 
             # max steps reached, end training
             if self.max_steps is not None and self.max_steps == self.global_step:
@@ -518,13 +503,36 @@ def run_training_epoch(self):
             if early_stop_epoch or self.fast_dev_run:
                 break
 
-        if self.use_horovod:
-            hvd.join(hvd.local_rank() if self.on_gpu else -1)
+        # let ddp devices catch up when using horovod
+        self.sync_horovod()
 
         # process epoch outputs
+        self.run_training_epoch_end(epoch_output)
+
+        # when no val loop is present or fast-dev-run still need to call checkpoints
+        if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
+            self.call_checkpoint_callback()
+
+        # epoch end hook
+        self.run_on_epoch_end_hook(model)
+
+    def update_train_loop_lr_schedulers(self):
+        if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
+            # update lr
+            self.update_learning_rates(interval='step')
+
+    def run_on_epoch_end_hook(self, model):
+        with self.profiler.profile('on_epoch_end'):
+            # callbacks
+            self.on_epoch_end()
+            # model hooks
+            if self.is_function_implemented('on_epoch_end'):
+                model.on_epoch_end()
+
+    def run_training_epoch_end(self, epoch_output):
         model = self.get_model()
         if self.is_overridden('training_epoch_end', model=model):
-            epoch_output = model.training_epoch_end(outputs)
+            epoch_output = model.training_epoch_end(epoch_output)
             _processed_outputs = self.process_output(epoch_output)
             log_epoch_metrics = _processed_outputs[2]
             callback_epoch_metrics = _processed_outputs[3]
@@ -538,17 +546,45 @@ def run_training_epoch(self):
             # add metrics to progress_bar
             self.add_progress_bar_metrics(_processed_outputs[1])
 
-        # when no val loop is present or fast-dev-run still need to call checkpoints
-        if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
+    def sync_horovod(self):
+        if self.use_horovod:
+            hvd.join(hvd.local_rank() if self.on_gpu else -1)
+
+    def increment_accumulated_grad_global_step(self):
+        # progress global step according to grads progress
+        if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
+            self.global_step += 1
+        self.total_batch_idx += 1
+
+    def save_train_loop_metrics_to_loggers(self, batch_idx, early_stop_epoch, batch_output):
+        # when metrics should be logged
+        should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
+        if should_log_metrics or self.fast_dev_run:
+            # logs user requested information to logger
+            self.log_metrics(batch_output.batch_log_metrics, batch_output.grad_norm_dic)
+
+    def save_loggers_in_training_loop(self, batch_idx, early_stop_epoch):
+        # when loggers should save to disk
+        should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
+        if should_save_log or self.fast_dev_run:
+            if self.is_global_zero and self.logger is not None:
+                self.logger.save()
+
+    def check_validation_in_train_loop(self, batch_idx, early_stop_epoch, is_last_batch):
+        # decide if we should run validation
+        is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
+        can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
+        can_check_val = not self.disable_validation and can_check_epoch
+        should_check_val = is_val_check_batch or early_stop_epoch
+        should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
+        should_check_val = can_check_val and should_check_val
+
+        # if we need to run validation, then also call the checkpoint callback
+        if self.fast_dev_run or should_check_val:
+            self.run_evaluation(test_mode=self.testing)
             self.call_checkpoint_callback()
 
-        # Epoch end events
-        with self.profiler.profile('on_epoch_end'):
-            # callbacks
-            self.on_epoch_end()
-            # model hooks
-            if self.is_function_implemented('on_epoch_end'):
-                model.on_epoch_end()
+        return should_check_val
 
     def run_training_batch(self, batch, batch_idx):
         # track grad norms
@@ -561,7 +597,7 @@ def run_training_batch(self, batch, batch_idx):
         batch_log_metrics = []
 
         if batch is None:
-            return 0, grad_norm_dic, {}, {}
+            return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)
 
         # Batch start events
         with self.profiler.profile('on_batch_start'):
@@ -571,7 +607,7 @@ def run_training_batch(self, batch, batch_idx):
             if self.is_function_implemented('on_batch_start'):
                 response = self.get_model().on_batch_start(batch)
                 if response == -1:
-                    return -1, grad_norm_dic, {}, {}
+                    return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)
 
         splits = [batch]
         if self.truncated_bptt_steps is not None:
@@ -650,7 +686,13 @@ def run_training_batch(self, batch, batch_idx):
         # track all metrics for callbacks
         self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()})
 
-        return 0, grad_norm_dic, batch_log_metrics, opt_closure_result.training_step_output_for_epoch_end
+        result = AttributeDict(
+            signal=0,
+            grad_norm_dic=grad_norm_dic,
+            batch_log_metrics=batch_log_metrics,
+            training_step_output_for_epoch_end=opt_closure_result.training_step_output_for_epoch_end
+        )
+        return result
 
     def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
         # ------------------
diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps.py
index 7e23324eed192..88ff4f8c7a16e 100644
--- a/tests/trainer/test_trainer_steps.py
+++ b/tests/trainer/test_trainer_steps.py
@@ -23,12 +23,11 @@ def test_trainingstep_dict(tmpdir):
         break
 
     out = trainer.run_training_batch(batch, batch_idx)
-    signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
-    assert signal == 0
-    assert all_log_metrics['log_acc1'] == 12.0
-    assert all_log_metrics['log_acc2'] == 7.0
+    assert out.signal == 0
+    assert out.batch_log_metrics['log_acc1'] == 12.0
+    assert out.batch_log_metrics['log_acc2'] == 7.0
 
-    pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
+    pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
     assert pbar_metrics['pbar_acc1'] == 17.0
     assert pbar_metrics['pbar_acc2'] == 19.0
 
@@ -55,12 +54,11 @@ def training_step_with_step_end(tmpdir):
         break
 
     out = trainer.run_training_batch(batch, batch_idx)
-    signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
-    assert signal == 0
-    assert all_log_metrics['log_acc1'] == 12.0
-    assert all_log_metrics['log_acc2'] == 7.0
+    assert out.signal == 0
+    assert out.batch_log_metrics['log_acc1'] == 12.0
+    assert out.batch_log_metrics['log_acc2'] == 7.0
 
-    pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
+    pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
     assert pbar_metrics['pbar_acc1'] == 17.0
     assert pbar_metrics['pbar_acc2'] == 19.0
 
@@ -92,12 +90,11 @@ def test_full_training_loop_dict(tmpdir):
         break
 
    out = trainer.run_training_batch(batch, batch_idx)
-    signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
-    assert signal == 0
-    assert all_log_metrics['log_acc1'] == 12.0
-    assert all_log_metrics['log_acc2'] == 7.0
+    assert out.signal == 0
+    assert out.batch_log_metrics['log_acc1'] == 12.0
+    assert out.batch_log_metrics['log_acc2'] == 7.0
 
-    pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
+    pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
     assert pbar_metrics['pbar_acc1'] == 17.0
     assert pbar_metrics['pbar_acc2'] == 19.0
 
@@ -129,11 +126,10 @@ def test_train_step_epoch_end(tmpdir):
         break
 
    out = trainer.run_training_batch(batch, batch_idx)
-    signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
-    assert signal == 0
-    assert all_log_metrics['log_acc1'] == 12.0
-    assert all_log_metrics['log_acc2'] == 7.0
+    assert out.signal == 0
+    assert out.batch_log_metrics['log_acc1'] == 12.0
+    assert out.batch_log_metrics['log_acc2'] == 7.0
 
-    pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
+    pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
     assert pbar_metrics['pbar_acc1'] == 17.0
     assert pbar_metrics['pbar_acc2'] == 19.0
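
Note on the new return contract: run_training_batch now hands back a single AttributeDict instead of the positional 4-tuple (signal, grad_norm_dic, batch_log_metrics, training_step_output_for_epoch_end), and the early returns keep the same shape as the full result, so callers read fields by name. The snippet below is a minimal, self-contained sketch of that pattern; Lightning ships its own AttributeDict in pytorch_lightning.utilities, and the stand-alone class plus the sample values here are illustrative assumptions, not part of the diff.

    # Minimal sketch of the AttributeDict pattern: a dict whose keys are
    # also readable as attributes (d.x == d['x']). Illustrative only.
    class AttributeDict(dict):
        def __getattr__(self, key):
            try:
                return self[key]
            except KeyError as exc:
                raise AttributeError(key) from exc

        def __setattr__(self, key, value):
            self[key] = value

    # before: signal, grad_norm_dic, batch_log_metrics, step_output = out
    # after: named fields, so adding or removing a field no longer breaks
    # every unpacking site
    batch_output = AttributeDict(
        signal=0,
        grad_norm_dic={},
        batch_log_metrics={'log_acc1': 12.0},
        training_step_output_for_epoch_end={'pbar_on_batch_end': {'pbar_acc1': 17.0}},
    )

    assert batch_output.signal == 0
    assert batch_output.batch_log_metrics['log_acc1'] == 12.0

    # the early-stop check in run_training_epoch reduces to a named-field read
    early_stop_epoch = batch_output.signal == -1
    assert early_stop_epoch is False

This is also why the test assertions above switch from tuple unpacking to out.signal and out.batch_log_metrics.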