From 0f073819d3e0df8db7602eab489b1bad0fc0949c Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 11:17:10 -0400 Subject: [PATCH] refactored training_batch + tests to verify correctness (#2328) * refactored training_bath * refactored training_bath * refactored training_bath * refactored training_bath * refactored training_bath * refactored training_bath * refactored training_bath * refactored training_bath * refactored training_bath * refactored training_bath * refactored training_bath --- pytorch_lightning/trainer/training_loop.py | 220 ++++++++++++++------- tests/base/deterministic_model.py | 144 ++++++++++++++ tests/trainer/test_trainer_steps.py | 139 +++++++++++++ 3 files changed, 428 insertions(+), 75 deletions(-) create mode 100644 tests/base/deterministic_model.py create mode 100644 tests/trainer/test_trainer_steps.py diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index a7608e95019f3..80dde617a2669 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -162,6 +162,8 @@ def training_step(self, batch, batch_idx): from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.parsing import AttributeDict +from pytorch_lightning.utilities.memory import recursive_detach import subprocess try: @@ -527,8 +529,14 @@ def run_training_epoch(self): _processed_outputs = self.process_output(epoch_output) log_epoch_metrics = _processed_outputs[2] callback_epoch_metrics = _processed_outputs[3] + + # add the metrics to the loggers self.log_metrics(log_epoch_metrics, {}) + + # add metrics to callbacks self.callback_metrics.update(callback_epoch_metrics) + + # add metrics to progress_bar self.add_progress_bar_metrics(_processed_outputs[1]) # when no val loop is present or fast-dev-run still need to call checkpoints @@ -548,10 +556,10 @@ def run_training_batch(self, batch, batch_idx): grad_norm_dic = {} # track all metrics for callbacks - all_callback_metrics = [] + batch_callback_metrics = [] # track metrics to log - all_log_metrics = [] + batch_log_metrics = [] if batch is None: return 0, grad_norm_dic, {}, {} @@ -586,87 +594,42 @@ def run_training_batch(self, batch, batch_idx): for param in group['params']: param.requires_grad = True - # wrap the forward step in a closure so second order methods work - def optimizer_closure(): - # forward pass - with self.profiler.profile('model_forward'): - if self.use_amp and self.use_native_amp: - with torch.cuda.amp.autocast(): - output_dict = self.training_forward(split_batch, batch_idx, - opt_idx, self.hiddens) - else: - output_dict = self.training_forward(split_batch, batch_idx, opt_idx, self.hiddens) - - # format and reduce outputs accordingly - processed_output = self.process_output(output_dict, train=True) - - closure_loss, progress_bar_metrics, log_metrics, callback_metrics, self.hiddens = processed_output - - # accumulate loss - # (if accumulate_grad_batches = 1 no effect) - closure_loss = closure_loss / self.accumulate_grad_batches - - # backward pass - model_ref = self.get_model() - with self.profiler.profile('model_backward'): - # scale loss for 16 bit - if self.precision == 16 and not self.on_tpu: - closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx) - - # do backward pass - model_ref.backward(self, closure_loss, optimizer, opt_idx) - - # 
track metrics for callbacks - all_callback_metrics.append(callback_metrics) - - # track progress bar metrics - self.add_progress_bar_metrics(progress_bar_metrics) - all_log_metrics.append(log_metrics) - - if self.use_horovod: - # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid - optimizer.synchronize() - - # insert after step hook - if self.is_function_implemented('on_after_backward'): - model_ref = self.get_model() - with self.profiler.profile('on_after_backward'): - model_ref.on_after_backward() + # ------------------- + # calculate loss + # ------------------- + opt_closure_result = self.optimizer_closure( + split_batch, + batch_idx, + opt_idx, + optimizer, + self.hiddens + ) - return closure_loss, callback_metrics + # ------------------------------ + # POST forward bookkeeping + # ------------------------------ + batch_callback_metrics.append(opt_closure_result.training_step_output.callback_metrics) + batch_log_metrics.append(opt_closure_result.training_step_output.log_metrics) + self.add_progress_bar_metrics(opt_closure_result.training_step_output.pbar_on_batch_end) - # calculate loss - loss, batch_output = optimizer_closure() + # track hiddens + self.hiddens = opt_closure_result.hiddens # check if loss or model weights are nan if self.terminate_on_nan: - self.detect_nan_tensors(loss) + self.detect_nan_tensors(opt_closure_result.loss) # track total loss for logging (avoid mem leaks) - self.batch_loss_value.append(loss) + self.batch_loss_value.append(opt_closure_result.loss) + # ------------------------------ + # BACKWARD PASS + # ------------------------------ # gradient update with accumulated gradients if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: - # track gradient norms when requested - if batch_idx % self.row_log_interval == 0: - if float(self.track_grad_norm) > 0: - model = self.get_model() - grad_norm_dic = model.grad_norm( - self.track_grad_norm) - - # clip gradients - if self.use_amp and self.use_native_amp: - self.scaler.unscale_(optimizer) - self.clip_gradients() - - # calls .step(), .zero_grad() - # override function to modify this behavior - model = self.get_model() - with self.profiler.profile('optimizer_step'): - model.optimizer_step(self.current_epoch, batch_idx, - optimizer, opt_idx, - lambda: optimizer_closure()[0]) + # backward + grad_norm_dic = self.run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer) # calculate running loss for display self.running_loss.append(self.batch_loss_value.mean()) @@ -683,12 +646,119 @@ def optimizer_closure(): self.get_model().on_batch_end() # collapse all metrics into one dict - all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()} + batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()} # track all metrics for callbacks - self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()}) + self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()}) - return 0, grad_norm_dic, all_log_metrics, batch_output + return 0, grad_norm_dic, batch_log_metrics, opt_closure_result.training_step_output_for_epoch_end + + def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer): + # ------------------ + # GRAD NORMS + # ------------------ + # track gradient norms when requested + grad_norm_dic = {} + if batch_idx % self.row_log_interval == 0: + if float(self.track_grad_norm) > 0: + model = self.get_model() + grad_norm_dic = model.grad_norm( + self.track_grad_norm) + + # 
------------------ + # CLIP GRADS + # ------------------ + if self.use_amp and self.use_native_amp: + self.scaler.unscale_(optimizer) + self.clip_gradients() + + # ------------------ + # .STEP + ZERO_GRAD + # ------------------ + model = self.get_model() + with self.profiler.profile('optimizer_step'): + lambda_closure = lambda: self.optimizer_closure( + split_batch, + batch_idx, + opt_idx, + optimizer, + self.hiddens + ).loss + model.optimizer_step(self.current_epoch, batch_idx, + optimizer, opt_idx, + lambda_closure) + + return grad_norm_dic + + def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): + """ + wrap the forward step in a closure so second order methods work + """ + # --------------------------- + # FORWARD + # --------------------------- + with self.profiler.profile('model_forward'): + if self.use_amp and self.use_native_amp: + with torch.cuda.amp.autocast(): + training_step_output = self.training_forward(split_batch, batch_idx, + opt_idx, hiddens) + else: + training_step_output = self.training_forward(split_batch, batch_idx, opt_idx, + hiddens) + + # ---------------------------- + # PROCESS THE RESULT + # ---------------------------- + # format and reduce outputs accordingly + training_step_output = self.process_output(training_step_output, train=True) + + # TODO: temporary part of structured results PR + training_step_output = AttributeDict( + batch_loss=training_step_output[0], + pbar_on_batch_end=training_step_output[1], + log_metrics=training_step_output[2], + callback_metrics=training_step_output[3], + hiddens=training_step_output[4], + ) + + # if the user decides to finally reduce things in epoch_end, save raw output without graphs + training_step_output_for_epoch_end = recursive_detach(training_step_output) + + # accumulate loss + # (if accumulate_grad_batches = 1 no effect) + closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches + + # backward pass + model_ref = self.get_model() + with self.profiler.profile('model_backward'): + # scale loss for 16 bit + if self.precision == 16 and not self.on_tpu: + closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx) + + # do backward pass + model_ref.backward(self, closure_loss, optimizer, opt_idx) + + # once backward has been applied, release graph + closure_loss = closure_loss.detach() + training_step_output.batch_loss = training_step_output.batch_loss.detach() + + if self.use_horovod: + # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid + optimizer.synchronize() + + # insert after step hook + if self.is_function_implemented('on_after_backward'): + model_ref = self.get_model() + with self.profiler.profile('on_after_backward'): + model_ref.on_after_backward() + + result = AttributeDict( + loss=closure_loss, + training_step_output=training_step_output, + training_step_output_for_epoch_end=training_step_output_for_epoch_end, + hiddens=training_step_output.hiddens, + ) + return result def _get_optimizers_iterable(self): if not self.optimizer_frequencies: diff --git a/tests/base/deterministic_model.py b/tests/base/deterministic_model.py new file mode 100644 index 0000000000000..1ca318ef1fac8 --- /dev/null +++ b/tests/base/deterministic_model.py @@ -0,0 +1,144 @@ +import torch +from pytorch_lightning.core.lightning import LightningModule +from torch.utils.data import Dataset, DataLoader +import numpy as np + + +class DeterministicModel(LightningModule): + + def __init__(self, weights=None): + super().__init__() + + 
self.training_step_called = False + self.training_step_end_called = False + self.training_epoch_end_called = False + + if weights is None: + weights = torch.tensor([ + [4, 3, 5], + [10, 11, 13] + ]).float() + self.l1 = torch.nn.Parameter(weights, requires_grad=True) + + def forward(self, x): + return self.l1.mm(x.float().t()) + + def step(self, batch, batch_idx): + x = batch + y_hat = self(x) + + assert torch.all(y_hat[0, :] == 15.0) + assert torch.all(y_hat[1, :] == 42.0) + out = y_hat.sum() + assert out == (42.0 * 3) + (15.0 * 3) + + return out + + def assert_graph_count(self, result, count=1): + counts = self.count_num_graphs(result) + assert counts == count + + def count_num_graphs(self, result, num_graphs=0): + for k, v in result.items(): + if isinstance(v, torch.Tensor) and v.grad_fn is not None: + num_graphs += 1 + if isinstance(v, dict): + num_graphs += self.count_num_graphs(v) + + return num_graphs + + # -------------------------- + # dictionary returns + # -------------------------- + def training_step_dict_return(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + + logs = {'log_acc1': torch.tensor(12).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)} + pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)} + + self.training_step_called = True + return {'loss': acc, 'log': logs, 'progress_bar': pbar} + + def training_step_for_step_end_dict(self, batch, batch_idx): + """sends outputs to training_batch_end""" + acc = self.step(batch, batch_idx) + + logs = {'log_acc1': torch.tensor(12).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)} + pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)} + + self.training_step_called = True + result = {'loss': acc} + result.update(logs) + result.update(pbar) + return result + + def training_step_end_dict(self, output): + self.training_step_end_called = True + + # make sure loss has the grad + assert 'loss' in output + assert output['loss'].grad_fn is not None + + # make sure nothing else has grads + assert self.count_num_graphs(output) == 1 + + # make sure the other keys are there + assert 'log_acc1' in output + assert 'log_acc2' in output + assert 'pbar_acc1' in output + assert 'pbar_acc2' in output + + logs = {'log_acc1': output['log_acc1'], 'log_acc2': output['log_acc2']} + pbar = {'pbar_acc1': output['pbar_acc1'], 'pbar_acc2': output['pbar_acc2']} + + acc = output['loss'] + return {'loss': acc, 'log': logs, 'progress_bar': pbar} + + def training_epoch_end_dict(self, outputs): + self.training_epoch_end_called = True + + if self.use_dp or self.use_ddp2: + pass + else: + # only saw 4 batches + assert len(outputs) == 4 + for batch_out in outputs: + assert len(batch_out.keys()) == 5 + keys = ['batch_loss', 'pbar_on_batch_end', 'log_metrics', 'callback_metrics'] + for key in keys: + assert key in batch_out + + prototype_loss = outputs[0]['batch_loss'] + logs = {'epoch_end_log_1': torch.tensor(178).type_as(prototype_loss)} + pbar = {'epoch_end_pbar_1': torch.tensor(234).type_as(prototype_loss)} + + return {'log': logs, 'progress_bar': pbar} + + def validation_step_dict_return(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + + logs = {'log_acc1': torch.tensor(12).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)} + pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)} + return {'val_loss': acc, 'log': logs, 'progress_bar': pbar} + + def train_dataloader(self): + return 
DataLoader(DummyDataset(), batch_size=3, shuffle=False) + + def val_dataloader(self): + return DataLoader(DummyDataset(), batch_size=3, shuffle=False) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0) + + def backward(self, trainer, loss, optimizer, optimizer_idx): + assert loss == 171.0 + loss.backward() + + +class DummyDataset(Dataset): + + def __len__(self): + return 12 + + def __getitem__(self, idx): + return np.array([0.5, 1.0, 2.0]) diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps.py new file mode 100644 index 0000000000000..7e23324eed192 --- /dev/null +++ b/tests/trainer/test_trainer_steps.py @@ -0,0 +1,139 @@ +from pytorch_lightning import Trainer +from tests.base.deterministic_model import DeterministicModel + + +def test_trainingstep_dict(tmpdir): + """ + Tests that only training_step can be used + """ + model = DeterministicModel() + model.training_step = model.training_step_dict_return + model.val_dataloader = None + + trainer = Trainer(fast_dev_run=True, weights_summary=None) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert not model.training_epoch_end_called + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.run_training_batch(batch, batch_idx) + signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out + assert signal == 0 + assert all_log_metrics['log_acc1'] == 12.0 + assert all_log_metrics['log_acc2'] == 7.0 + + pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end'] + assert pbar_metrics['pbar_acc1'] == 17.0 + assert pbar_metrics['pbar_acc2'] == 19.0 + + +def training_step_with_step_end(tmpdir): + """ + Checks train_step + training_step_end + """ + model = DeterministicModel() + model.training_step = model.training_step_for_step_end_dict + model.training_step_end = model.training_step_end_dict + model.val_dataloader = None + + trainer = Trainer(fast_dev_run=True, weights_summary=None) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert model.training_step_end_called + assert not model.training_epoch_end_called + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.run_training_batch(batch, batch_idx) + signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out + assert signal == 0 + assert all_log_metrics['log_acc1'] == 12.0 + assert all_log_metrics['log_acc2'] == 7.0 + + pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end'] + assert pbar_metrics['pbar_acc1'] == 17.0 + assert pbar_metrics['pbar_acc2'] == 19.0 + + +def test_full_training_loop_dict(tmpdir): + """ + Checks train_step + training_step_end + training_epoch_end + """ + model = DeterministicModel() + model.training_step = model.training_step_for_step_end_dict + model.training_step_end = model.training_step_end_dict + model.training_epoch_end = model.training_epoch_end_dict + model.val_dataloader = None + + trainer = Trainer(max_epochs=1, weights_summary=None) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert model.training_step_end_called + assert model.training_epoch_end_called + + # assert epoch end metrics were added + assert trainer.callback_metrics['epoch_end_log_1'] == 178 + 
assert trainer.progress_bar_metrics['epoch_end_pbar_1'] == 234 + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.run_training_batch(batch, batch_idx) + signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out + assert signal == 0 + assert all_log_metrics['log_acc1'] == 12.0 + assert all_log_metrics['log_acc2'] == 7.0 + + pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end'] + assert pbar_metrics['pbar_acc1'] == 17.0 + assert pbar_metrics['pbar_acc2'] == 19.0 + + +def test_train_step_epoch_end(tmpdir): + """ + Checks train_step + training_epoch_end (NO training_step_end) + """ + model = DeterministicModel() + model.training_step = model.training_step_dict_return + model.training_step_end = None + model.training_epoch_end = model.training_epoch_end_dict + model.val_dataloader = None + + trainer = Trainer(max_epochs=1, weights_summary=None) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert model.training_epoch_end_called + + # assert epoch end metrics were added + assert trainer.callback_metrics['epoch_end_log_1'] == 178 + assert trainer.progress_bar_metrics['epoch_end_pbar_1'] == 234 + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.run_training_batch(batch, batch_idx) + signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out + assert signal == 0 + assert all_log_metrics['log_acc1'] == 12.0 + assert all_log_metrics['log_acc2'] == 7.0 + + pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end'] + assert pbar_metrics['pbar_acc1'] == 17.0 + assert pbar_metrics['pbar_acc2'] == 19.0
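
Note on the pattern this patch introduces (not part of the patch itself): the forward/backward logic that used to live in an inline closure inside run_training_batch is moved into Trainer.optimizer_closure, which returns an AttributeDict (loss, training_step_output, training_step_output_for_epoch_end, hiddens) that run_training_batch and run_batch_backward_pass then consume. The snippet below is a minimal, standalone sketch of that closure-result pattern in plain PyTorch, under the assumption of a single first-order optimizer; the names ClosureResult, optimizer_closure and run_training_batch here are illustrative only and are not the Lightning implementation, which additionally handles AMP scaling, gradient clipping, truncated-BPTT hiddens, Horovod synchronization and hooks.

# Hypothetical, simplified sketch of the closure-result pattern shown in the diff above.
# None of these names are PyTorch Lightning API.
from dataclasses import dataclass, field
from typing import Dict

import torch
from torch import nn


@dataclass
class ClosureResult:
    loss: torch.Tensor                      # detached loss, safe to accumulate for logging
    log_metrics: Dict[str, float] = field(default_factory=dict)


def optimizer_closure(model, batch, accumulate_grad_batches: int = 1) -> ClosureResult:
    """Forward + backward in one place, loosely mirroring Trainer.optimizer_closure."""
    x, y = batch
    loss = nn.functional.mse_loss(model(x), y)

    # scale the loss when gradients are accumulated over several batches
    # (no effect when accumulate_grad_batches == 1)
    loss = loss / accumulate_grad_batches
    loss.backward()

    # once backward has run, release the graph by detaching, as the patch does
    return ClosureResult(loss=loss.detach(), log_metrics={"train_loss": loss.item()})


def run_training_batch(model, batch, optimizer) -> ClosureResult:
    # evaluate the closure and keep its structured result for bookkeeping
    result = optimizer_closure(model, batch)

    # step/zero_grad stay outside the closure so gradient accumulation and
    # clipping can be layered in between, as run_batch_backward_pass does
    optimizer.step()
    optimizer.zero_grad()
    return result


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    batch = (torch.randn(8, 3), torch.randn(8, 1))
    print(run_training_batch(model, batch, optimizer).log_metrics)

Returning a small structured result instead of a bare loss is what lets the tests above inspect training_step_output_for_epoch_end (e.g. its pbar_on_batch_end entries) directly from trainer.run_training_batch.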