diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index bbe81853d06e6..7ac92bc585155 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -361,6 +361,14 @@ def reduce_across_time(cls, time_outputs):
         result['meta'] = meta
         return result
 
+    def dp_reduce(self):
+        for k, value in self.items():
+            if k == 'meta':
+                continue
+            if isinstance(value, list):
+                value = torch.tensor(value)
+            self[k] = value.mean(dim=-1)
+
     @property
     def should_reduce_on_epoch_end(self) -> bool:
         return self['meta']['_internal']['_reduce_on_epoch']
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index d7f5f3eb0fd69..4ef1b34b3c3fa 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -343,17 +343,20 @@ def _evaluate(
                     m = 'only EvalResults or dicts are allowed from validation_step'
                     raise MisconfigurationException(m)
 
+                # ------------------
+                # EVAL STEP END
+                # ------------------
                 # on dp / ddp2 might still want to do something with the batch parts
-                if test_mode:
-                    if self.is_overridden('test_step_end'):
-                        model_ref = self.get_model()
-                        with self.profiler.profile('test_step_end'):
-                            output = model_ref.test_step_end(output)
-                else:
-                    if self.is_overridden('validation_step_end'):
-                        model_ref = self.get_model()
-                        with self.profiler.profile('validation_step_end'):
-                            output = model_ref.validation_step_end(output)
+                eval_step_end_hook_name = 'test_step_end' if test_mode else 'validation_step_end'
+                if self.is_overridden(eval_step_end_hook_name):
+                    model_ref = self.get_model()
+                    with self.profiler.profile(eval_step_end_hook_name):
+                        eval_step_end = getattr(model_ref, eval_step_end_hook_name)
+                        output = eval_step_end(output)
+
+                elif is_result_obj and (self.use_dp or self.use_ddp2):
+                    # result auto reduce
+                    output.dp_reduce()
 
                 # callbacks (on __batch_end)
                 if test_mode:
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index fed64c2e09aaf..8577f98f9d65f 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -1221,6 +1221,8 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):
         else:
             output = self.model.training_step(*args)
 
+        is_result_obj = isinstance(output, Result)
+
         # allow any mode to define training_step_end
         # do something will all the dp outputs (like softmax)
         if self.is_overridden('training_step_end'):
@@ -1229,6 +1231,9 @@
                 # TODO: modify when using result obj
                 output = model_ref.training_step_end(output)
 
+        elif is_result_obj and (self.use_dp or self.use_ddp2):
+            output.dp_reduce()
+
         # allow any mode to define training_end
         # TODO: remove in 1.0.0
         if self.is_overridden('training_end'):
diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py
index b6fc748037af3..6d7cd365d8c25 100644
--- a/tests/base/model_train_steps.py
+++ b/tests/base/model_train_steps.py
@@ -79,6 +79,28 @@ def training_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=
         self.training_step_called = True
         return result
 
+    def training_step_result_obj_dp(self, batch, batch_idx, optimizer_idx=None):
+        # forward pass
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        y_hat = self(x.to(self.device))
+
+        # calculate loss
+        loss_val = self.loss(y.to(y_hat.device), y_hat)
+        log_val = loss_val
+
+        # alternate between tensors and scalars for "log" and "progress_bar"
+        if batch_idx % 2 == 0:
+            log_val = log_val.item()
+
+        result = TrainResult(loss_val)
+        result.log('some_val', log_val * log_val, prog_bar=True, logger=False)
+        result.log('train_some_val', log_val * log_val)
+
+        self.training_step_called = True
+
+        return result
+
     def training_step_end_full_loop_result_obj_dp(self, result):
         """
         Full loop flow train step (result obj + dp)
diff --git a/tests/base/model_valid_steps.py b/tests/base/model_valid_steps.py
index 74694d365545c..94b843e96a02a 100644
--- a/tests/base/model_valid_steps.py
+++ b/tests/base/model_valid_steps.py
@@ -52,6 +52,28 @@ def validation_step_result_obj(self, batch, batch_idx, *args, **kwargs):
         })
         return result
 
+    def validation_step_result_obj_dp(self, batch, batch_idx, *args, **kwargs):
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        y_hat = self(x.to(self.device))
+
+        y = y.to(y_hat.device)
+        loss_val = self.loss(y, y_hat)
+
+        # acc
+        labels_hat = torch.argmax(y_hat, dim=1)
+        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
+        val_acc = torch.tensor(val_acc).type_as(x)
+
+        result = EvalResult(checkpoint_on=loss_val, early_stop_on=loss_val)
+        result.log_dict({
+            'val_loss': loss_val,
+            'val_acc': val_acc,
+        })
+
+        self.validation_step_called = True
+        return result
+
     def validation_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
         """
         Lightning calls this inside the validation loop
diff --git a/tests/trainer/test_trainer_steps_result_return.py b/tests/trainer/test_trainer_steps_result_return.py
index ed3632b968c9f..50ddc48b41251 100644
--- a/tests/trainer/test_trainer_steps_result_return.py
+++ b/tests/trainer/test_trainer_steps_result_return.py
@@ -535,6 +535,41 @@ def test_full_train_loop_with_results_obj_dp(tmpdir):
     assert 'epoch_train_epoch_end_metric' in seen_keys
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_loop_steps_only_dp(tmpdir):
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    batches = 10
+    epochs = 3
+
+    model = EvalModelTemplate()
+    model.validation_step = None
+    model.test_step = None
+    model.training_step = model.training_step_result_obj_dp
+    model.training_step_end = None
+    model.training_epoch_end = None
+    model.validation_step = model.validation_step_result_obj_dp
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+    model.test_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        distributed_backend='dp',
+        gpus=[0, 1],
+        max_epochs=epochs,
+        early_stop_callback=True,
+        row_log_interval=2,
+        limit_train_batches=batches,
+        weights_summary=None,
+    )
+
+    trainer.fit(model)
+
+    assert model.training_step_called
+    assert model.validation_step_called
+
+
 def test_result_map(tmpdir):
     result = TrainResult()
     result.log_dict({'x1': torch.tensor(1), 'x2': torch.tensor(2)})
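For reference, here is a minimal standalone sketch (illustrative only, not part of the patch) of the reduction that the new Result.dp_reduce hook performs: under dp/ddp2 each logged value arrives with one entry per GPU, and the hook averages over that trailing device dimension, converting plain Python lists to tensors first and leaving the 'meta' bookkeeping entry untouched. The dictionary and its values below are made up for the example and simply stand in for a Result object.

    import torch

    # stand-in for a Result mapping after a dp forward pass over 2 GPUs (hypothetical values)
    per_gpu = {
        'checkpoint_on': torch.tensor([0.90, 1.10]),  # one loss value per GPU
        'some_val': [0.40, 0.60],                     # lists are converted to tensors first
        'meta': {'_internal': {}},                    # 'meta' is skipped by dp_reduce
    }

    # same logic as Result.dp_reduce in the diff above
    for k, value in per_gpu.items():
        if k == 'meta':
            continue
        if isinstance(value, list):
            value = torch.tensor(value)
        per_gpu[k] = value.mean(dim=-1)  # average across the per-GPU dimension

    print(per_gpu['checkpoint_on'])  # averages to 1.0
    print(per_gpu['some_val'])       # averages to 0.5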