From 1c7cf11cc7ae2e75b8e0a039b1847a8e9e037c54 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 18:36:08 -0400 Subject: [PATCH 1/8] refactoring training epoch --- pytorch_lightning/trainer/training_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 2e30f8043235a..e8e98526c334b 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -439,7 +439,7 @@ def run_training_epoch(self): train_dataloader = train_dataloader.per_device_loader(device) # bookkeeping - outputs = [] + epoch_output = [] # run epoch for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable( @@ -462,7 +462,7 @@ def run_training_epoch(self): # only track outputs when user implements training_epoch_end # otherwise we will build up unnecessary memory if self.is_overridden('training_epoch_end', model=self.get_model()): - outputs.append(batch_output) + epoch_output.append(batch_output) # when returning -1 from train_step, we end epoch early early_stop_epoch = batch_result == -1 @@ -524,7 +524,7 @@ def run_training_epoch(self): # process epoch outputs model = self.get_model() if self.is_overridden('training_epoch_end', model=model): - epoch_output = model.training_epoch_end(outputs) + epoch_output = model.training_epoch_end(epoch_output) _processed_outputs = self.process_output(epoch_output) log_epoch_metrics = _processed_outputs[2] callback_epoch_metrics = _processed_outputs[3] From b9e123b7b1191c2f2b04ce4d83cc75a8ecf6254a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 18:47:20 -0400 Subject: [PATCH 2/8] refactored training epoch --- pytorch_lightning/trainer/training_loop.py | 27 +++++++++++++--------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index e8e98526c334b..11d631fa018a3 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -453,19 +453,18 @@ def run_training_epoch(self): model.global_step = self.global_step - # --------------- - # RUN TRAIN STEP - # --------------- - _outputs = self.run_training_batch(batch, batch_idx) - batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs + # ------------------------------------ + # TRAINING_STEP + TRAINING_STEP_END + # ------------------------------------ + batch_output = self.run_training_batch(batch, batch_idx) # only track outputs when user implements training_epoch_end # otherwise we will build up unnecessary memory if self.is_overridden('training_epoch_end', model=self.get_model()): - epoch_output.append(batch_output) + epoch_output.append(batch_output.training_step_output_for_epoch_end) # when returning -1 from train_step, we end epoch early - early_stop_epoch = batch_result == -1 + early_stop_epoch = batch_output.signal == -1 # TODO: consolidate all actions that need to take place only after # self.accumulate_grad_batches steps (optimizer step, lr update, global step increment) @@ -501,7 +500,7 @@ def run_training_epoch(self): should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch if should_log_metrics or self.fast_dev_run: # logs user requested information to logger - self.log_metrics(batch_step_metrics, grad_norm_dic) + self.log_metrics(batch_output.batch_step_metrics, batch_output.grad_norm_dic) # progress global step according to grads progress if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: @@ -561,7 +560,7 @@ def run_training_batch(self, batch, batch_idx): batch_log_metrics = [] if batch is None: - return 0, grad_norm_dic, {}, {} + return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic) # Batch start events with self.profiler.profile('on_batch_start'): @@ -571,7 +570,7 @@ def run_training_batch(self, batch, batch_idx): if self.is_function_implemented('on_batch_start'): response = self.get_model().on_batch_start(batch) if response == -1: - return -1, grad_norm_dic, {}, {} + return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) splits = [batch] if self.truncated_bptt_steps is not None: @@ -650,7 +649,13 @@ def run_training_batch(self, batch, batch_idx): # track all metrics for callbacks self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()}) - return 0, grad_norm_dic, batch_log_metrics, opt_closure_result.training_step_output_for_epoch_end + result = AttributeDict( + signal=0, + grad_norm_dic=grad_norm_dic, + batch_log_metrics=batch_log_metrics, + training_step_output_for_epoch_end=opt_closure_result.training_step_output_for_epoch_end + ) + return result def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer): # ------------------ From e36e9bf820dbd1e58e03d183a7e9c352acdf3a52 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 19:10:47 -0400 Subject: [PATCH 3/8] refactored training epoch --- pytorch_lightning/trainer/training_loop.py | 130 +++++++++++++-------- 1 file changed, 80 insertions(+), 50 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 11d631fa018a3..4546e8305992c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -450,7 +450,6 @@ def run_training_epoch(self): break self.batch_idx = batch_idx - model.global_step = self.global_step # ------------------------------------ @@ -463,49 +462,29 @@ def run_training_epoch(self): if self.is_overridden('training_epoch_end', model=self.get_model()): epoch_output.append(batch_output.training_step_output_for_epoch_end) + # update LR schedulers + self.update_train_loop_lr_schedulers() + # when returning -1 from train_step, we end epoch early early_stop_epoch = batch_output.signal == -1 - # TODO: consolidate all actions that need to take place only after - # self.accumulate_grad_batches steps (optimizer step, lr update, global step increment) - if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: - # update lr - self.update_learning_rates(interval='step') - - # --------------- - # RUN VAL STEP - # --------------- - is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0 - can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 - can_check_val = not self.disable_validation and can_check_epoch - should_check_val = is_val_check_batch or early_stop_epoch - should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf')) - should_check_val = can_check_val and should_check_val - - # --------------- - # CHECKPOINTING, EARLY STOPPING - # --------------- - # fast_dev_run always forces val checking after train batch - if self.fast_dev_run or should_check_val: - self.run_evaluation(test_mode=self.testing) - self.call_checkpoint_callback() - - # when logs should be saved - should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch - if should_save_log or self.fast_dev_run: - if self.is_global_zero and self.logger is not None: - self.logger.save() - - # when metrics should be logged - should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch - if should_log_metrics or self.fast_dev_run: - # logs user requested information to logger - self.log_metrics(batch_output.batch_step_metrics, batch_output.grad_norm_dic) + # ----------------------------------------- + # VALIDATE IF NEEDED + CHECKPOINT CALLBACK + # ----------------------------------------- + should_check_val = self.check_validation_in_train_loop(batch_idx, early_stop_epoch, is_last_batch) + + # ----------------------------------------- + # SAVE LOGGERS (ie: Tensorboard, etc...) + # ----------------------------------------- + self.save_loggers_in_training_loop(batch_idx, early_stop_epoch) + + # ----------------------------------------- + # SAVE METRICS TO LOGGERS + # ----------------------------------------- + self.save_train_loop_metrics_to_loggers(batch_idx, early_stop_epoch, batch_output) # progress global step according to grads progress - if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: - self.global_step += 1 - self.total_batch_idx += 1 + self.increment_accumulated_grad_global_step() # max steps reached, end training if self.max_steps is not None and self.max_steps == self.global_step: @@ -517,10 +496,33 @@ def run_training_epoch(self): if early_stop_epoch or self.fast_dev_run: break - if self.use_horovod: - hvd.join(hvd.local_rank() if self.on_gpu else -1) + # let ddp devices catch up when using horovod + self.sync_horovod() # process epoch outputs + self.run_training_epoch_end(epoch_output) + + # when no val loop is present or fast-dev-run still need to call checkpoints + if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val): + self.call_checkpoint_callback() + + # epoch end hook + self.run_on_epoch_end_hook(model) + + def update_train_loop_lr_schedulers(self): + if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: + # update lr + self.update_learning_rates(interval='step') + + def run_on_epoch_end_hook(self, model): + with self.profiler.profile('on_epoch_end'): + # callbacks + self.on_epoch_end() + # model hooks + if self.is_function_implemented('on_epoch_end'): + model.on_epoch_end() + + def run_training_epoch_end(self, epoch_output): model = self.get_model() if self.is_overridden('training_epoch_end', model=model): epoch_output = model.training_epoch_end(epoch_output) @@ -537,17 +539,45 @@ def run_training_epoch(self): # add metrics to progress_bar self.add_progress_bar_metrics(_processed_outputs[1]) - # when no val loop is present or fast-dev-run still need to call checkpoints - if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val): + def sync_horovod(self): + if self.use_horovod: + hvd.join(hvd.local_rank() if self.on_gpu else -1) + + def increment_accumulated_grad_global_step(self): + # progress global step according to grads progress + if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: + self.global_step += 1 + self.total_batch_idx += 1 + + def save_train_loop_metrics_to_loggers(self, batch_idx, early_stop_epoch, batch_output): + # when metrics should be logged + should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch + if should_log_metrics or self.fast_dev_run: + # logs user requested information to logger + self.log_metrics(batch_output.batch_step_metrics, batch_output.grad_norm_dic) + + def save_loggers_in_training_loop(self, batch_idx, early_stop_epoch): + # when loggers should save to disk + should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch + if should_save_log or self.fast_dev_run: + if self.is_global_zero and self.logger is not None: + self.logger.save() + + def check_validation_in_train_loop(self, batch_idx, early_stop_epoch, is_last_batch): + # decide if we should run validation + is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0 + can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 + can_check_val = not self.disable_validation and can_check_epoch + should_check_val = is_val_check_batch or early_stop_epoch + should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf')) + should_check_val = can_check_val and should_check_val + + # if we need to run validation, then also call the checkpoint callback + if self.fast_dev_run or should_check_val: + self.run_evaluation(test_mode=self.testing) self.call_checkpoint_callback() - # Epoch end events - with self.profiler.profile('on_epoch_end'): - # callbacks - self.on_epoch_end() - # model hooks - if self.is_function_implemented('on_epoch_end'): - model.on_epoch_end() + return should_check_val def run_training_batch(self, batch, batch_idx): # track grad norms From 593706ce23107da1e623a7e03078d5b0291cbadb Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 19:13:53 -0400 Subject: [PATCH 4/8] refactored training epoch --- pytorch_lightning/trainer/training_loop.py | 27 ++++++++++++++-------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 4546e8305992c..4da21eb6fcb85 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -415,11 +415,16 @@ def train(self): self.run_training_teardown() - def run_training_epoch(self): + def prepare_train_loop_dataloader(self, train_dataloader): + # on TPU we have to wrap it under the ParallelLoader + if self.use_tpu: + device = xm.xla_device(self.tpu_id) + train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device]) + train_dataloader = train_dataloader.per_device_loader(device) - # get model - model = self.get_model() + return train_dataloader + def run_on_epoch_start_hook(self): # Epoch start events with self.profiler.profile('on_epoch_start'): # callbacks @@ -429,14 +434,16 @@ def run_training_epoch(self): if self.is_function_implemented('on_epoch_start'): model.on_epoch_start() - # track local dataloader so TPU can wrap each epoch - train_dataloader = self.train_dataloader + def run_training_epoch(self): - # on TPU we have to wrap it under the ParallelLoader - if self.use_tpu: - device = xm.xla_device(self.tpu_id) - train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device]) - train_dataloader = train_dataloader.per_device_loader(device) + # get model + model = self.get_model() + + # Epoch start events + self.run_on_epoch_start_hook() + + # modify dataloader if needed (ddp, etc...) + train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader) # bookkeeping epoch_output = [] From b8cb395bbf18a9818da8a00d5802b657ef98583a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 19:14:22 -0400 Subject: [PATCH 5/8] refactored training epoch --- pytorch_lightning/trainer/training_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 4da21eb6fcb85..ce6edfbad0a93 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -424,7 +424,7 @@ def prepare_train_loop_dataloader(self, train_dataloader): return train_dataloader - def run_on_epoch_start_hook(self): + def run_on_epoch_start_hook(self, model): # Epoch start events with self.profiler.profile('on_epoch_start'): # callbacks @@ -440,7 +440,7 @@ def run_training_epoch(self): model = self.get_model() # Epoch start events - self.run_on_epoch_start_hook() + self.run_on_epoch_start_hook(model) # modify dataloader if needed (ddp, etc...) train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader) From 136d76255d55696185444a27c819401996958c0a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 20:07:15 -0400 Subject: [PATCH 6/8] refactored training epoch --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ce6edfbad0a93..c113c04c15ab2 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -561,7 +561,7 @@ def save_train_loop_metrics_to_loggers(self, batch_idx, early_stop_epoch, batch_ should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch if should_log_metrics or self.fast_dev_run: # logs user requested information to logger - self.log_metrics(batch_output.batch_step_metrics, batch_output.grad_norm_dic) + self.log_metrics(batch_output.batch_log_metrics, batch_output.grad_norm_dic) def save_loggers_in_training_loop(self, batch_idx, early_stop_epoch): # when loggers should save to disk From 848eb8153075a1096bd92550704a1a01155109ce Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 20:22:44 -0400 Subject: [PATCH 7/8] fixes slurm weights saving --- tests/trainer/test_trainer_steps.py | 36 +++++++++++++---------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps.py index 7e23324eed192..8fd14e281df01 100644 --- a/tests/trainer/test_trainer_steps.py +++ b/tests/trainer/test_trainer_steps.py @@ -23,12 +23,11 @@ def test_trainingstep_dict(tmpdir): break out = trainer.run_training_batch(batch, batch_idx) - signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out - assert signal == 0 - assert all_log_metrics['log_acc1'] == 12.0 - assert all_log_metrics['log_acc2'] == 7.0 + assert out.signal == 0 + assert out.all_log_metrics['log_acc1'] == 12.0 + assert out.all_log_metrics['log_acc2'] == 7.0 - pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end'] + pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end'] assert pbar_metrics['pbar_acc1'] == 17.0 assert pbar_metrics['pbar_acc2'] == 19.0 @@ -55,12 +54,11 @@ def training_step_with_step_end(tmpdir): break out = trainer.run_training_batch(batch, batch_idx) - signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out - assert signal == 0 - assert all_log_metrics['log_acc1'] == 12.0 - assert all_log_metrics['log_acc2'] == 7.0 + assert out.signal == 0 + assert out.all_log_metrics['log_acc1'] == 12.0 + assert out.all_log_metrics['log_acc2'] == 7.0 - pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end'] + pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end'] assert pbar_metrics['pbar_acc1'] == 17.0 assert pbar_metrics['pbar_acc2'] == 19.0 @@ -92,12 +90,11 @@ def test_full_training_loop_dict(tmpdir): break out = trainer.run_training_batch(batch, batch_idx) - signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out - assert signal == 0 - assert all_log_metrics['log_acc1'] == 12.0 - assert all_log_metrics['log_acc2'] == 7.0 + assert out.signal == 0 + assert out.all_log_metrics['log_acc1'] == 12.0 + assert out.all_log_metrics['log_acc2'] == 7.0 - pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end'] + pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end'] assert pbar_metrics['pbar_acc1'] == 17.0 assert pbar_metrics['pbar_acc2'] == 19.0 @@ -129,11 +126,10 @@ def test_train_step_epoch_end(tmpdir): break out = trainer.run_training_batch(batch, batch_idx) - signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out - assert signal == 0 - assert all_log_metrics['log_acc1'] == 12.0 - assert all_log_metrics['log_acc2'] == 7.0 + assert out.signal == 0 + assert out.all_log_metrics['log_acc1'] == 12.0 + assert out.all_log_metrics['log_acc2'] == 7.0 - pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end'] + pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end'] assert pbar_metrics['pbar_acc1'] == 17.0 assert pbar_metrics['pbar_acc2'] == 19.0 From 0da220c345e1f5b60357007fa9293d2fe635825d Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 20:31:33 -0400 Subject: [PATCH 8/8] fixes slurm weights saving --- tests/trainer/test_trainer_steps.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps.py index 8fd14e281df01..88ff4f8c7a16e 100644 --- a/tests/trainer/test_trainer_steps.py +++ b/tests/trainer/test_trainer_steps.py @@ -24,8 +24,8 @@ def test_trainingstep_dict(tmpdir): out = trainer.run_training_batch(batch, batch_idx) assert out.signal == 0 - assert out.all_log_metrics['log_acc1'] == 12.0 - assert out.all_log_metrics['log_acc2'] == 7.0 + assert out.batch_log_metrics['log_acc1'] == 12.0 + assert out.batch_log_metrics['log_acc2'] == 7.0 pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end'] assert pbar_metrics['pbar_acc1'] == 17.0 @@ -55,8 +55,8 @@ def training_step_with_step_end(tmpdir): out = trainer.run_training_batch(batch, batch_idx) assert out.signal == 0 - assert out.all_log_metrics['log_acc1'] == 12.0 - assert out.all_log_metrics['log_acc2'] == 7.0 + assert out.batch_log_metrics['log_acc1'] == 12.0 + assert out.batch_log_metrics['log_acc2'] == 7.0 pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end'] assert pbar_metrics['pbar_acc1'] == 17.0 @@ -91,8 +91,8 @@ def test_full_training_loop_dict(tmpdir): out = trainer.run_training_batch(batch, batch_idx) assert out.signal == 0 - assert out.all_log_metrics['log_acc1'] == 12.0 - assert out.all_log_metrics['log_acc2'] == 7.0 + assert out.batch_log_metrics['log_acc1'] == 12.0 + assert out.batch_log_metrics['log_acc2'] == 7.0 pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end'] assert pbar_metrics['pbar_acc1'] == 17.0 @@ -127,8 +127,8 @@ def test_train_step_epoch_end(tmpdir): out = trainer.run_training_batch(batch, batch_idx) assert out.signal == 0 - assert out.all_log_metrics['log_acc1'] == 12.0 - assert out.all_log_metrics['log_acc2'] == 7.0 + assert out.batch_log_metrics['log_acc1'] == 12.0 + assert out.batch_log_metrics['log_acc2'] == 7.0 pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end'] assert pbar_metrics['pbar_acc1'] == 17.0