From 6d2e3e5cf8a994e6fd4b3f2f4b7f2c467e790b1d Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 08:09:36 -0400 Subject: [PATCH 01/11] refactored training_bath --- pytorch_lightning/trainer/training_loop.py | 209 ++++++++++++++------- 1 file changed, 136 insertions(+), 73 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index a7608e95019f3..10c533edc9922 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -162,6 +162,8 @@ def training_step(self, batch, batch_idx): from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.parsing import AttributeDict +from pytorch_lightning.utilities.memory import recursive_detach import subprocess try: @@ -548,7 +550,7 @@ def run_training_batch(self, batch, batch_idx): grad_norm_dic = {} # track all metrics for callbacks - all_callback_metrics = [] + batch_callback_metrics = [] # track metrics to log all_log_metrics = [] @@ -586,87 +588,41 @@ def run_training_batch(self, batch, batch_idx): for param in group['params']: param.requires_grad = True - # wrap the forward step in a closure so second order methods work - def optimizer_closure(): - # forward pass - with self.profiler.profile('model_forward'): - if self.use_amp and self.use_native_amp: - with torch.cuda.amp.autocast(): - output_dict = self.training_forward(split_batch, batch_idx, - opt_idx, self.hiddens) - else: - output_dict = self.training_forward(split_batch, batch_idx, opt_idx, self.hiddens) - - # format and reduce outputs accordingly - processed_output = self.process_output(output_dict, train=True) - - closure_loss, progress_bar_metrics, log_metrics, callback_metrics, self.hiddens = processed_output - - # accumulate loss - # (if accumulate_grad_batches = 1 no effect) - closure_loss = closure_loss / self.accumulate_grad_batches - - # backward pass - model_ref = self.get_model() - with self.profiler.profile('model_backward'): - # scale loss for 16 bit - if self.precision == 16 and not self.on_tpu: - closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx) - - # do backward pass - model_ref.backward(self, closure_loss, optimizer, opt_idx) - - # track metrics for callbacks - all_callback_metrics.append(callback_metrics) - - # track progress bar metrics - self.add_progress_bar_metrics(progress_bar_metrics) - all_log_metrics.append(log_metrics) - - if self.use_horovod: - # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid - optimizer.synchronize() - - # insert after step hook - if self.is_function_implemented('on_after_backward'): - model_ref = self.get_model() - with self.profiler.profile('on_after_backward'): - model_ref.on_after_backward() + # ------------------- + # calculate loss + # ------------------- + opt_closure_result = self.optimizer_closure( + split_batch, + batch_idx, + opt_idx, + optimizer, + self.hiddens + ) - return closure_loss, callback_metrics + # ------------------------------ + # POST forward bookkeeping + # ------------------------------ + batch_callback_metrics.append(opt_closure_result.training_step_output.callback_metrics) + self.add_progress_bar_metrics(opt_closure_result.training_step_output.pbar_on_batch_end) - # calculate loss - loss, batch_output = optimizer_closure() + # loss, training_step_output, 
training_step_output_for_epoch_end, hiddens + self.hiddens = opt_closure_result.hiddens # check if loss or model weights are nan if self.terminate_on_nan: - self.detect_nan_tensors(loss) + self.detect_nan_tensors(opt_closure_result.loss) # track total loss for logging (avoid mem leaks) - self.batch_loss_value.append(loss) + self.batch_loss_value.append(opt_closure_result.loss) + # ------------------------------ + # BACKWARD PASS + # ------------------------------ # gradient update with accumulated gradients if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: - # track gradient norms when requested - if batch_idx % self.row_log_interval == 0: - if float(self.track_grad_norm) > 0: - model = self.get_model() - grad_norm_dic = model.grad_norm( - self.track_grad_norm) - - # clip gradients - if self.use_amp and self.use_native_amp: - self.scaler.unscale_(optimizer) - self.clip_gradients() - - # calls .step(), .zero_grad() - # override function to modify this behavior - model = self.get_model() - with self.profiler.profile('optimizer_step'): - model.optimizer_step(self.current_epoch, batch_idx, - optimizer, opt_idx, - lambda: optimizer_closure()[0]) + # backward + grad_norm_dic = self.run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer) # calculate running loss for display self.running_loss.append(self.batch_loss_value.mean()) @@ -686,9 +642,116 @@ def optimizer_closure(): all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()} # track all metrics for callbacks - self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()}) + self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()}) - return 0, grad_norm_dic, all_log_metrics, batch_output + return 0, grad_norm_dic, all_log_metrics, opt_closure_result.training_step_output_for_epoch_end + + def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer): + # ------------------ + # GRAD NORMS + # ------------------ + # track gradient norms when requested + grad_norm_dic = {} + if batch_idx % self.row_log_interval == 0: + if float(self.track_grad_norm) > 0: + model = self.get_model() + grad_norm_dic = model.grad_norm( + self.track_grad_norm) + + # ------------------ + # CLIP GRADS + # ------------------ + if self.use_amp and self.use_native_amp: + self.scaler.unscale_(optimizer) + self.clip_gradients() + + # ------------------ + # .STEP + ZERO_GRAD + # ------------------ + model = self.get_model() + with self.profiler.profile('optimizer_step'): + lambda_closure = lambda: self.optimizer_closure( + split_batch, + batch_idx, + opt_idx, + optimizer, + self.hiddens + )[0] + model.optimizer_step(self.current_epoch, batch_idx, + optimizer, opt_idx, + lambda: lambda_closure) + + return grad_norm_dic + + def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): + """ + wrap the forward step in a closure so second order methods work + """ + # --------------------------- + # FORWARD + # --------------------------- + with self.profiler.profile('model_forward'): + if self.use_amp and self.use_native_amp: + with torch.cuda.amp.autocast(): + training_step_output = self.training_forward(split_batch, batch_idx, + opt_idx, hiddens) + else: + training_step_output = self.training_forward(split_batch, batch_idx, opt_idx, + hiddens) + + # ---------------------------- + # PROCESS THE RESULT + # ---------------------------- + # format and reduce outputs accordingly + training_step_output = self.process_output(training_step_output, 
train=True) + + # TODO: temporary part of structured results PR + training_step_output = AttributeDict( + batch_loss=training_step_output[0], + pbar_on_batch_end=training_step_output[1], + log_metrics=training_step_output[2], + callback_metrics=training_step_output[3], + hiddens=training_step_output[4], + ) + + # if the user decides to finally reduce things in epoch_end, save raw output without graphs + training_step_output_for_epoch_end = recursive_detach(training_step_output) + + # accumulate loss + # (if accumulate_grad_batches = 1 no effect) + closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches + + # backward pass + model_ref = self.get_model() + with self.profiler.profile('model_backward'): + # scale loss for 16 bit + if self.precision == 16 and not self.on_tpu: + closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx) + + # do backward pass + model_ref.backward(self, closure_loss, optimizer, opt_idx) + + # once backward has been applied, release graph + closure_loss = closure_loss.detach() + training_step_output.batch_loss = training_step_output.batch_loss.detach() + + if self.use_horovod: + # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid + optimizer.synchronize() + + # insert after step hook + if self.is_function_implemented('on_after_backward'): + model_ref = self.get_model() + with self.profiler.profile('on_after_backward'): + model_ref.on_after_backward() + + result = AttributeDict( + loss=closure_loss, + training_step_output=training_step_output, + training_step_output_for_epoch_end=training_step_output_for_epoch_end, + hiddens=training_step_output.hiddens, + ) + return result def _get_optimizers_iterable(self): if not self.optimizer_frequencies: From 12466eb9323ec0857a885d0f554039916874c72c Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 08:42:08 -0400 Subject: [PATCH 02/11] refactored training_bath --- pytorch_lightning/core/step_result.py | 302 +++++++++++++++++++++ pytorch_lightning/trainer/training_loop.py | 7 +- tests/base/deterministic_model.py | 223 +++++++++++++++ tests/trainer/test_trainer_steps.py | 28 ++ 4 files changed, 557 insertions(+), 3 deletions(-) create mode 100644 pytorch_lightning/core/step_result.py create mode 100644 tests/base/deterministic_model.py create mode 100644 tests/trainer/test_trainer_steps.py diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py new file mode 100644 index 0000000000000..c3f3ee02ba124 --- /dev/null +++ b/pytorch_lightning/core/step_result.py @@ -0,0 +1,302 @@ +from typing import Optional, Dict +from torch import Tensor +import torch + + +class Result(Dict): + + def __init__(self, + minimize: Optional[Tensor] = None, + early_stop_on: Tensor = None, + checkpoint_on: Tensor = None, + hiddens: Optional[Tensor] = None): + """ + WIP! Split over many PRs... DO NOT USE YET + + TrainResult is an OrderedDict that gives type hints, allowed fields and validation for bad user input. + + Use as the return value for: + - training_step + + .. note:: Plain dictionary returns are supported but are more prone to errors + + We automatically detach anything here for you to avoid holding references to graphs + + Args: + minimize: Metric to minimize + early_stop_on: Metric for early stopping. Ignored with a validation loop. + checkpoint_on: Metric for checkpointing. Ignored with a validation loop otherwise defaults to `minimize` value. + hiddens: tensor of hiddens to pass to next step when using TBPTT + + .. 
code-block: python + + # all options: + def training_step(...): + return TrainResult( + minimize=loss, + checkpoint_on=loss, + ) + + # equivalent + return TrainResult(loss) + + # if you have no validation loop, you can still early_stop and/or checkpoint on a metric + # only checkpointing is applied by default here + return TrainResult(loss, early_stop_on=accuracy, checkpoint_on=bleu_score) + + result = TrainResult(loss) + + # logging will log to your logger(s) at the end of the batch + result.log('train_nce_loss', loss) + + # you can log at the end of the batch, or epoch or both + result.log('train_nce_loss', loss, on_batch_end=True, on_epoch_end=False) + + # same thing for the progress bar + result.to_pbar(train_nce_loss', loss) + result.to_pbar('train_nce_loss', loss, on_batch_end=True, on_epoch_end=False) + + Although 99% of the time we are interested in a metric for each training batch, (ie: loss decrease over the epoch), + sometimes you may want to know something like the average loss for the full epoch. You can either + define the `training_epoch_end` method for something fancy, or use the `on_epoch_end` argument with your custom + reduce function + + .. code-block: python + + # maybe sum `log_probs` across all the training batches + result.log('log_probs', log_probs, reduce_fx=torch.sum) + + # or do something weird to `log_probs` across all the training batches + def my_weird_reduction(all_log_probs): + all_log_probs = F.softmax(torch.cat(all_log_probs), dim=1) + return all_log_probs + + result.log('log_probs', log_probs, reduce_fx=my_weird_reduction) + """ + + super().__init__() + + self.early_stop_on = early_stop_on + self.checkpoint_on = checkpoint_on + + # TODO: should hiddens detach? + self.hiddens = hiddens + self.minimize = minimize + + @classmethod + def union(cls, outputs, result=None): + if result is None: + result = Result() + + for out in outputs: + for k, v in out.items(): + if k in ['reduce_fx_on_epoch_end']: + continue + + if k not in result and isinstance(v, (dict, Result)): + result[k] = Result() + + if isinstance(v, dict): + v = cls.union([v], result[k]) + + if isinstance(v, list) and len(v) == 1: + v = v[0] + result[k] = v + + return result + + @classmethod + def from_result_dict(cls, dict_result, trainer): + result = Result() + + if 'log' in dict_result: + result.log_metrics(dict_result['log']) + if 'progress_bar' in dict_result: + result.pbar_metrics(dict_result['progress_bar']) + + # add the early stop metric + if trainer.early_stop_callback is not None: + early_stop_metric = trainer.early_stop_callback.monitor + if early_stop_metric in dict_result: + result.early_stop_on = dict_result[early_stop_metric] + + # add the checkpoint metric + if trainer.checkpoint_callback is not None: + checkpoint_metric = trainer.checkpoint_callback.monitor + if checkpoint_metric in dict_result: + result.checkpoint_on = dict_result[checkpoint_metric] + + return result + + def __reduce_on_callback(self, callback_name, name, metric, log, pbar, reduce_fx): + assert isinstance(metric, torch.Tensor), f'{name} must be a torch.Tensor' + + keys = [f'reduce_{callback_name}'] + if log: + keys.append(f'log_{callback_name}') + if pbar: + keys.append(f'pbar_{callback_name}') + + for key in keys: + if key not in self: + self[key] = {} + + if 'log' in key or 'pbar' in key: + metric = metric.detach() + + metrics = self[key] + metrics[name] = metric + + key = f'reduce_fx_{callback_name}' + if key not in self: + self[key] = {} + + metrics = self[key] + metrics[name] = reduce_fx + + def 
pbar_metric(self, name: str, value: Tensor, on_batch_end=False, on_epoch_end=True, reduce_fx=torch.mean): + if on_batch_end: + self.__reduce_on_callback('on_batch_end', name, value, log=False, pbar=True, reduce_fx=reduce_fx) + if on_epoch_end: + self.__reduce_on_callback('on_epoch_end', name, value, log=False, pbar=True, reduce_fx=reduce_fx) + + def pbar_metrics(self, values: dict, on_batch_end=False, on_epoch_end=True, reduce_fx=torch.mean): + for name, value in values.items(): + if on_batch_end: + self.__reduce_on_callback('on_batch_end', name, value, log=False, pbar=True, reduce_fx=reduce_fx) + if on_epoch_end: + self.__reduce_on_callback('on_epoch_end', name, value, log=False, pbar=True, reduce_fx=reduce_fx) + + def log_metric(self, name: str, value: Tensor, on_batch_end=False, on_epoch_end=True, reduce_fx=torch.mean): + if on_batch_end: + self.__reduce_on_callback('on_batch_end', name, value, log=True, pbar=False, reduce_fx=reduce_fx) + if on_epoch_end: + self.__reduce_on_callback('on_epoch_end', name, value, log=True, pbar=False, reduce_fx=reduce_fx) + + def log_metrics(self, values: dict, on_batch_end=False, on_epoch_end=True, reduce_fx=torch.mean): + for name, value in values.items(): + if on_batch_end: + self.__reduce_on_callback('on_batch_end', name, value, log=True, pbar=False, reduce_fx=reduce_fx) + if on_epoch_end: + self.__reduce_on_callback('on_epoch_end', name, value, log=True, pbar=False, reduce_fx=reduce_fx) + + @property + def log_on_batch_end(self): + return self.__getitem__('log_on_batch_end') + + @log_on_batch_end.setter + def log_on_batch_end(self, x): + if x is not None: + assert isinstance(x, dict), 'log_on_batch_end must be a dict' + self.__setitem__('log_on_batch_end', x) + + @property + def pbar_on_batch_end(self): + return self.__getitem__('pbar_on_batch_end') + + @pbar_on_batch_end.setter + def pbar_on_batch_end(self, x): + if x is not None: + assert isinstance(x, dict), 'pbar_on_batch_end must be a dict' + self.__setitem__('pbar_on_batch_end', x) + + @property + def log_on_epoch_end(self): + return self.__getitem__('log_on_epoch_end') + + @log_on_epoch_end.setter + def log_on_epoch_end(self, x): + if x is not None: + assert isinstance(x, dict), 'log_on_epoch_end must be a dict' + self.__setitem__('log_on_epoch_end', x) + + @property + def pbar_on_epoch_end(self): + return self.__getitem__('pbar_on_epoch_end') + + @pbar_on_epoch_end.setter + def pbar_on_epoch_end(self, x): + if x is not None: + assert isinstance(x, dict), 'pbar_on_epoch_end must be a dict' + self.__setitem__('pbar_on_epoch_end', x) + + @property + def progress_bar(self): + return self.__getitem__('progress_bar') + + @progress_bar.setter + def progress_bar(self, x): + if x is not None: + assert isinstance(x, dict), 'progress_bar_logs must be a dict' + self.__setitem__('progress_bar', x) + + @property + def logs(self): + return self.__getitem__('logs') + + @logs.setter + def logs(self, x): + if x is not None: + assert isinstance(x, dict), 'logs must be a dict' + self.__setitem__('logs', x) + + @property + def hiddens(self): + return self._hiddens + + @hiddens.setter + def hiddens(self, x): + if x is not None: + assert isinstance(x, Tensor), 'hiddens must be a torch.Tensor' + self._hiddens = x + self.__setitem__('hiddens', x) + + @property + def checkpoint_on(self): + # use minimize as default if no checkpoint_on is passed + if 'checkpoint_on' not in self: + minimize = self.__getitem__('minimize') + self.__setitem__('checkpoint_on', minimize) + + return self.__getitem__('checkpoint_on') + + 
@checkpoint_on.setter + def checkpoint_on(self, x): + if x is not None: + assert isinstance(x, Tensor), 'checkpoint_on must be a torch.Tensor' + self.__setitem__('checkpoint_on', x.detach()) + + @property + def early_stop_on(self): + # use minimize as default if no checkpoint_on is passed + if 'early_stop_on' not in self: + minimize = self.__getitem__('minimize') + self.__setitem__('early_stop_on', minimize) + + return self.__getitem__('early_stop_on') + + @early_stop_on.setter + def early_stop_on(self, x): + if x is not None: + assert isinstance(x, Tensor), 'early_stop_on must be a torch.Tensor' + self.__setitem__('early_stop_on', x.detach()) + + @property + def minimize(self): + return self.__getitem__('minimize') + + @minimize.setter + def minimize(self, x): + if x is not None: + assert isinstance(x, Tensor), 'metric to minimize must be a torch.Tensor' + m = 'the metric to minimize must have a computational graph. Minimize ' \ + 'can only be used in training_end, training_step_end, training_epoch_end' + assert x.grad_fn is not None, m + self.__setitem__('minimize', x) + + +if __name__ == '__main__': + import torch + result = Result() + result.log_metrics({'a': 2}) + result.minimize = torch.tensor(1) \ No newline at end of file diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 10c533edc9922..c8a6f62e25da4 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -553,7 +553,7 @@ def run_training_batch(self, batch, batch_idx): batch_callback_metrics = [] # track metrics to log - all_log_metrics = [] + batch_log_metrics = [] if batch is None: return 0, grad_norm_dic, {}, {} @@ -603,6 +603,7 @@ def run_training_batch(self, batch, batch_idx): # POST forward bookkeeping # ------------------------------ batch_callback_metrics.append(opt_closure_result.training_step_output.callback_metrics) + batch_log_metrics.append(opt_closure_result.training_step_output.log_metrics) self.add_progress_bar_metrics(opt_closure_result.training_step_output.pbar_on_batch_end) # loss, training_step_output, training_step_output_for_epoch_end, hiddens @@ -639,12 +640,12 @@ def run_training_batch(self, batch, batch_idx): self.get_model().on_batch_end() # collapse all metrics into one dict - all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()} + batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()} # track all metrics for callbacks self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()}) - return 0, grad_norm_dic, all_log_metrics, opt_closure_result.training_step_output_for_epoch_end + return 0, grad_norm_dic, batch_log_metrics, opt_closure_result.training_step_output_for_epoch_end def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer): # ------------------ diff --git a/tests/base/deterministic_model.py b/tests/base/deterministic_model.py new file mode 100644 index 0000000000000..b5f422ea5d97a --- /dev/null +++ b/tests/base/deterministic_model.py @@ -0,0 +1,223 @@ +import torch +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.step_result import Result +from torch.utils.data import Dataset, DataLoader +import numpy as np + + +class DeterministicModel(LightningModule): + + def __init__(self, weights=None): + super().__init__() + if weights is None: + weights = torch.tensor([ + [4, 3, 5], + [10, 11, 13] + ]).float() + self.l1 = torch.nn.Parameter(weights, requires_grad=True) + + 
def forward(self, x): + return self.l1.mm(x.float().t()) + + def base_train_result(self, acc): + x = acc + result = Result( + minimize=acc, + early_stop_on=torch.tensor(1.4).type_as(x), + checkpoint_on=torch.tensor(1.5).type_as(x) + ) + + result.log_metric('log_acc1', torch.tensor(12).type_as(x)) + result.log_metrics({'log_acc2': torch.tensor(7).type_as(x)}) + result.pbar_metric('pbar_acc1', torch.tensor(17).type_as(x)) + result.pbar_metrics({'pbar_acc2': torch.tensor(19).type_as(x)}) + + # make sure minimize is the only thing with a graph + self.assert_graph_count(result, 1) + return result + + def base_eval_result(self, acc): + x = acc + result = Result( + early_stop_on=torch.tensor(1.4).type_as(x), + checkpoint_on=torch.tensor(1.5).type_as(x) + ) + result.log_metrics({ + 'log_acc1': torch.tensor(12).type_as(x), + 'log_acc2': torch.tensor(7).type_as(x) + }) + result.pbar_metrics({ + 'pbar_acc1': torch.tensor(17).type_as(x), + 'pbar_acc2': torch.tensor(19).type_as(x) + }) + return result + + def step(self, batch, batch_idx): + x = batch + y_hat = self(x) + + assert torch.all(y_hat[0, :] == 15.0) + assert torch.all(y_hat[1, :] == 42.0) + out = y_hat.sum() + assert out == (42.0*3) + (15.0*3) + + return out + + def assert_graph_count(self, result, count=1): + counts = self.count_num_graphs(result) + assert counts == count + + def count_num_graphs(self, result: Result, num_graphs=0): + for k, v in result.items(): + if isinstance(v, torch.Tensor) and v.grad_fn is not None: + num_graphs += 1 + if isinstance(v, dict): + num_graphs += self.count_num_graphs(v) + + return num_graphs + + def training_step_only(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + + result = self.base_train_result(acc) + return result + + def training_step_with_batch_end(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + + result = self.base_train_result(acc) + + return result + + def training_step_with_epoch_end(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + + result = self.base_train_result(acc) + result.pass_to_epoch_end('to_epoch_end_1', torch.tensor([-3, -2, -3]).type_as(acc)) + + return result + + def training_step_with_batch_and_epoch_end(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + + result = self.base_train_result(acc) + result.pass_to_batch_end('to_batch_end_1', torch.tensor([-1, -2, -3]).type_as(acc)) + result.pass_to_epoch_end('to_epoch_end_1', torch.tensor([-3, -2, -3]).type_as(acc)) + + return result + + def training_step_dict_return(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + + logs = {'log_acc1': torch.tensor(12).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)} + pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)} + return {'loss': acc, 'log': logs, 'progress_bar': pbar} + + def training_step_end_basic(self, outputs): + # make sure only the expected keys are here + keys = set(outputs.keys()) + assert keys == {'to_batch_end_1', 'minimize'} + + result = Result() + result.pass_to_epoch_end('from_train_step_end', torch.tensor(19)) + return result + + def training_epoch_end_basic(self, outputs): + if self.use_dp or self.use_ddp2: + pass + else: + # only saw 3 batches + assert len(outputs) == 3 + for batch_out in outputs: + assert len(batch_out.keys()) == 2 + keys = ['to_batch_end_1', 'to_batch_end_2'] + for key in keys: + assert key in batch_out + + def validation_step_only(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + + result = self.base_eval_result(acc) + + return 
result + + def validation_step_with_batch_end(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + + result = self.base_eval_result(acc) + result.pass_to_batch_end('to_batch_end_1', torch.tensor([-1, -2, -3]).type_as(acc)) + + return result + + def validation_step_with_epoch_end(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + + result = self.base_eval_result(acc) + result.pass_to_epoch_end('to_epoch_end_1', torch.tensor([-3, -2, -3]).type_as(acc)) + + return result + + def validation_step_with_batch_and_epoch_end(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + + result = self.base_eval_result(acc) + result.pass_to_batch_end('to_batch_end_1', torch.tensor([-1, -2, -3]).type_as(acc)) + result.pass_to_epoch_end('to_epoch_end_1', torch.tensor([-3, -2, -3]).type_as(acc)) + + return result + + def validation_step_dict_return(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + + logs = {'log_acc1': torch.tensor(12).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)} + pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)} + return {'val_loss': acc, 'log': logs, 'progress_bar': pbar} + + def validation_step_end_basic(self, outputs): + if self.use_dp or self.use_ddp2: + pass + else: + # only saw 3 batches + assert len(outputs) == 3 + for batch_out in outputs: + assert len(batch_out.keys()) == 2 + keys = ['to_batch_end_1', 'to_batch_end_2', 'minimize'] + for key in keys: + assert key in batch_out + + result = TrainResult() + result.pass_to_epoch_end('from_train_step_end', torch.tensor(19)) + + def validation_epoch_end_basic(self, outputs): + if self.use_dp or self.use_ddp2: + pass + else: + # only saw 3 batches + assert len(outputs) == 3 + for batch_out in outputs: + assert len(batch_out.keys()) == 2 + keys = ['to_batch_end_1', 'to_batch_end_2'] + for key in keys: + assert key in batch_out + + def train_dataloader(self): + return DataLoader(DummyDataset(), batch_size=3, shuffle=False) + + def val_dataloader(self): + return DataLoader(DummyDataset(), batch_size=3, shuffle=False) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0) + + def backward(self, trainer, loss, optimizer, optimizer_idx): + assert loss == 171.0 + loss.backward() + + +class DummyDataset(Dataset): + + def __len__(self): + return 12 + + def __getitem__(self, idx): + return np.array([0.5, 1.0, 2.0]) \ No newline at end of file diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps.py new file mode 100644 index 0000000000000..ada9481ecf84d --- /dev/null +++ b/tests/trainer/test_trainer_steps.py @@ -0,0 +1,28 @@ +from pytorch_lightning import Trainer +from tests.base.deterministic_model import DeterministicModel + + +def test_trainingstep_dict(tmpdir): + """ + Tests that only training_step can be used + """ + model = DeterministicModel() + model.training_step = model.training_step_dict_return + model.val_dataloader = None + + trainer = Trainer(fast_dev_run=True, weights_summary=None) + trainer.fit(model) + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.run_training_batch(batch, batch_idx) + signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out + assert signal == 0 + assert all_log_metrics['log_acc1'] == 12.0 + assert all_log_metrics['log_acc2'] == 7.0 + + pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end'] + assert pbar_metrics['pbar_acc1'] == 
17.0 + assert pbar_metrics['pbar_acc2'] == 19.0 From 2ea8e7f8db64171c6f8d865e5ecaa278f24dcab3 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 09:12:16 -0400 Subject: [PATCH 03/11] refactored training_bath --- tests/base/deterministic_model.py | 160 +++++++--------------------- tests/trainer/test_trainer_steps.py | 39 +++++++ 2 files changed, 76 insertions(+), 123 deletions(-) diff --git a/tests/base/deterministic_model.py b/tests/base/deterministic_model.py index b5f422ea5d97a..c63650caade8e 100644 --- a/tests/base/deterministic_model.py +++ b/tests/base/deterministic_model.py @@ -9,6 +9,11 @@ class DeterministicModel(LightningModule): def __init__(self, weights=None): super().__init__() + + self.training_step_called = False + self.training_step_end_called = False + self.training_epoch_end_called = False + if weights is None: weights = torch.tensor([ [4, 3, 5], @@ -19,39 +24,6 @@ def __init__(self, weights=None): def forward(self, x): return self.l1.mm(x.float().t()) - def base_train_result(self, acc): - x = acc - result = Result( - minimize=acc, - early_stop_on=torch.tensor(1.4).type_as(x), - checkpoint_on=torch.tensor(1.5).type_as(x) - ) - - result.log_metric('log_acc1', torch.tensor(12).type_as(x)) - result.log_metrics({'log_acc2': torch.tensor(7).type_as(x)}) - result.pbar_metric('pbar_acc1', torch.tensor(17).type_as(x)) - result.pbar_metrics({'pbar_acc2': torch.tensor(19).type_as(x)}) - - # make sure minimize is the only thing with a graph - self.assert_graph_count(result, 1) - return result - - def base_eval_result(self, acc): - x = acc - result = Result( - early_stop_on=torch.tensor(1.4).type_as(x), - checkpoint_on=torch.tensor(1.5).type_as(x) - ) - result.log_metrics({ - 'log_acc1': torch.tensor(12).type_as(x), - 'log_acc2': torch.tensor(7).type_as(x) - }) - result.pbar_metrics({ - 'pbar_acc1': torch.tensor(17).type_as(x), - 'pbar_acc2': torch.tensor(19).type_as(x) - }) - return result - def step(self, batch, batch_idx): x = batch y_hat = self(x) @@ -76,51 +48,52 @@ def count_num_graphs(self, result: Result, num_graphs=0): return num_graphs - def training_step_only(self, batch, batch_idx): - acc = self.step(batch, batch_idx) - - result = self.base_train_result(acc) - return result - - def training_step_with_batch_end(self, batch, batch_idx): + # -------------------------- + # dictionary returns + # -------------------------- + def training_step_dict_return(self, batch, batch_idx): acc = self.step(batch, batch_idx) - result = self.base_train_result(acc) + logs = {'log_acc1': torch.tensor(12).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)} + pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)} - return result + self.training_step_called = True + return {'loss': acc, 'log': logs, 'progress_bar': pbar} - def training_step_with_epoch_end(self, batch, batch_idx): + def training_step_for_step_end_dict(self, batch, batch_idx): + """sends outputs to training_batch_end""" acc = self.step(batch, batch_idx) - result = self.base_train_result(acc) - result.pass_to_epoch_end('to_epoch_end_1', torch.tensor([-3, -2, -3]).type_as(acc)) + logs = {'log_acc1': torch.tensor(12).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)} + pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)} + self.training_step_called = True + result = {'loss': acc} + result.update(logs) + result.update(pbar) return result - def training_step_with_batch_and_epoch_end(self, batch, batch_idx): - acc = 
self.step(batch, batch_idx) + def training_step_end_dict(self, output): + self.training_step_end_called = True - result = self.base_train_result(acc) - result.pass_to_batch_end('to_batch_end_1', torch.tensor([-1, -2, -3]).type_as(acc)) - result.pass_to_epoch_end('to_epoch_end_1', torch.tensor([-3, -2, -3]).type_as(acc)) + # make sure loss has the grad + assert 'loss' in output + assert output['loss'].grad_fn is not None - return result + # make sure nothing else has grads + assert self.count_num_graphs(output) == 1 - def training_step_dict_return(self, batch, batch_idx): - acc = self.step(batch, batch_idx) + # make sure the other keys are there + assert 'log_acc1' in output + assert 'log_acc2' in output + assert 'pbar_acc1' in output + assert 'pbar_acc2' in output - logs = {'log_acc1': torch.tensor(12).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)} - pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)} - return {'loss': acc, 'log': logs, 'progress_bar': pbar} - - def training_step_end_basic(self, outputs): - # make sure only the expected keys are here - keys = set(outputs.keys()) - assert keys == {'to_batch_end_1', 'minimize'} + logs = {'log_acc1': output['log_acc1'], 'log_acc2': output['log_acc2']} + pbar = {'pbar_acc1': output['pbar_acc1'], 'pbar_acc2': output['pbar_acc2']} - result = Result() - result.pass_to_epoch_end('from_train_step_end', torch.tensor(19)) - return result + acc = output['loss'] + return {'loss': acc, 'log': logs, 'progress_bar': pbar} def training_epoch_end_basic(self, outputs): if self.use_dp or self.use_ddp2: @@ -134,38 +107,6 @@ def training_epoch_end_basic(self, outputs): for key in keys: assert key in batch_out - def validation_step_only(self, batch, batch_idx): - acc = self.step(batch, batch_idx) - - result = self.base_eval_result(acc) - - return result - - def validation_step_with_batch_end(self, batch, batch_idx): - acc = self.step(batch, batch_idx) - - result = self.base_eval_result(acc) - result.pass_to_batch_end('to_batch_end_1', torch.tensor([-1, -2, -3]).type_as(acc)) - - return result - - def validation_step_with_epoch_end(self, batch, batch_idx): - acc = self.step(batch, batch_idx) - - result = self.base_eval_result(acc) - result.pass_to_epoch_end('to_epoch_end_1', torch.tensor([-3, -2, -3]).type_as(acc)) - - return result - - def validation_step_with_batch_and_epoch_end(self, batch, batch_idx): - acc = self.step(batch, batch_idx) - - result = self.base_eval_result(acc) - result.pass_to_batch_end('to_batch_end_1', torch.tensor([-1, -2, -3]).type_as(acc)) - result.pass_to_epoch_end('to_epoch_end_1', torch.tensor([-3, -2, -3]).type_as(acc)) - - return result - def validation_step_dict_return(self, batch, batch_idx): acc = self.step(batch, batch_idx) @@ -173,33 +114,6 @@ def validation_step_dict_return(self, batch, batch_idx): pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)} return {'val_loss': acc, 'log': logs, 'progress_bar': pbar} - def validation_step_end_basic(self, outputs): - if self.use_dp or self.use_ddp2: - pass - else: - # only saw 3 batches - assert len(outputs) == 3 - for batch_out in outputs: - assert len(batch_out.keys()) == 2 - keys = ['to_batch_end_1', 'to_batch_end_2', 'minimize'] - for key in keys: - assert key in batch_out - - result = TrainResult() - result.pass_to_epoch_end('from_train_step_end', torch.tensor(19)) - - def validation_epoch_end_basic(self, outputs): - if self.use_dp or self.use_ddp2: - pass - else: - # only saw 3 batches 
- assert len(outputs) == 3 - for batch_out in outputs: - assert len(batch_out.keys()) == 2 - keys = ['to_batch_end_1', 'to_batch_end_2'] - for key in keys: - assert key in batch_out - def train_dataloader(self): return DataLoader(DummyDataset(), batch_size=3, shuffle=False) diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps.py index ada9481ecf84d..0024e5d04f9af 100644 --- a/tests/trainer/test_trainer_steps.py +++ b/tests/trainer/test_trainer_steps.py @@ -13,6 +13,11 @@ def test_trainingstep_dict(tmpdir): trainer = Trainer(fast_dev_run=True, weights_summary=None) trainer.fit(model) + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert not model.training_epoch_end_called + # make sure training outputs what is expected for batch_idx, batch in enumerate(model.train_dataloader()): break @@ -26,3 +31,37 @@ def test_trainingstep_dict(tmpdir): pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end'] assert pbar_metrics['pbar_acc1'] == 17.0 assert pbar_metrics['pbar_acc2'] == 19.0 + + +def training_step_with_step_end(tmpdir): + """ + Checks train_step + training_step_end + """ + model = DeterministicModel() + model.training_step = model.training_step_for_step_end_dict + model.training_step_end = model.training_step_end_dict + model.val_dataloader = None + + trainer = Trainer(fast_dev_run=True, weights_summary=None) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert model.training_step_end_called + assert not model.training_epoch_end_called + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.run_training_batch(batch, batch_idx) + signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out + assert signal == 0 + assert all_log_metrics['log_acc1'] == 12.0 + assert all_log_metrics['log_acc2'] == 7.0 + + pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end'] + assert pbar_metrics['pbar_acc1'] == 17.0 + assert pbar_metrics['pbar_acc2'] == 19.0 + + From 0345b202c2cc8fc6bc1d440bebc617b5454fb4f1 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 09:12:51 -0400 Subject: [PATCH 04/11] refactored training_bath --- tests/trainer/test_trainer_steps.py | 30 +++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps.py index 0024e5d04f9af..8ff73df0eb11e 100644 --- a/tests/trainer/test_trainer_steps.py +++ b/tests/trainer/test_trainer_steps.py @@ -65,3 +65,33 @@ def training_step_with_step_end(tmpdir): assert pbar_metrics['pbar_acc2'] == 19.0 +def test_full_training_loop_dict(tmpdir): + """ + Checks train_step + training_step_end + training_epoch_end + """ + model = DeterministicModel() + model.training_step = model.training_step_for_step_end_dict + model.training_step_end = model.training_step_end_dict + model.val_dataloader = None + + trainer = Trainer(fast_dev_run=True, weights_summary=None) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert model.training_step_end_called + assert not model.training_epoch_end_called + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.run_training_batch(batch, batch_idx) + signal, grad_norm_dic, all_log_metrics, 
training_step_output_for_epoch_end = out + assert signal == 0 + assert all_log_metrics['log_acc1'] == 12.0 + assert all_log_metrics['log_acc2'] == 7.0 + + pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end'] + assert pbar_metrics['pbar_acc1'] == 17.0 + assert pbar_metrics['pbar_acc2'] == 19.0 \ No newline at end of file From 0ef94a63f9e9d627bd79e1cef411836c91174f62 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 09:37:58 -0400 Subject: [PATCH 05/11] refactored training_bath --- pytorch_lightning/core/step_result.py | 2 +- pytorch_lightning/trainer/training_loop.py | 6 ++++++ tests/base/deterministic_model.py | 22 +++++++++++++++------- tests/trainer/test_trainer_steps.py | 14 +++++++++++--- 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index c3f3ee02ba124..a4641a54d4a9f 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -299,4 +299,4 @@ def minimize(self, x): import torch result = Result() result.log_metrics({'a': 2}) - result.minimize = torch.tensor(1) \ No newline at end of file + result.minimize = torch.tensor(1) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c8a6f62e25da4..fd9e7f365a611 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -529,8 +529,14 @@ def run_training_epoch(self): _processed_outputs = self.process_output(epoch_output) log_epoch_metrics = _processed_outputs[2] callback_epoch_metrics = _processed_outputs[3] + + # add the metrics to the loggers self.log_metrics(log_epoch_metrics, {}) + + # add metrics to callbacks self.callback_metrics.update(callback_epoch_metrics) + + # add metrics to progress_bar self.add_progress_bar_metrics(_processed_outputs[1]) # when no val loop is present or fast-dev-run still need to call checkpoints diff --git a/tests/base/deterministic_model.py b/tests/base/deterministic_model.py index c63650caade8e..dc4972ab2b224 100644 --- a/tests/base/deterministic_model.py +++ b/tests/base/deterministic_model.py @@ -31,7 +31,7 @@ def step(self, batch, batch_idx): assert torch.all(y_hat[0, :] == 15.0) assert torch.all(y_hat[1, :] == 42.0) out = y_hat.sum() - assert out == (42.0*3) + (15.0*3) + assert out == (42.0 * 3) + (15.0 * 3) return out @@ -95,18 +95,26 @@ def training_step_end_dict(self, output): acc = output['loss'] return {'loss': acc, 'log': logs, 'progress_bar': pbar} - def training_epoch_end_basic(self, outputs): + def training_epoch_end_dict(self, outputs): + self.training_epoch_end_called = True + if self.use_dp or self.use_ddp2: pass else: - # only saw 3 batches - assert len(outputs) == 3 + # only saw 4 batches + assert len(outputs) == 4 for batch_out in outputs: - assert len(batch_out.keys()) == 2 - keys = ['to_batch_end_1', 'to_batch_end_2'] + assert len(batch_out.keys()) == 5 + keys = ['batch_loss', 'pbar_on_batch_end', 'log_metrics', 'callback_metrics'] for key in keys: assert key in batch_out + prototype_loss = outputs[0]['batch_loss'] + logs = {'epoch_end_log_1': torch.tensor(178).type_as(prototype_loss)} + pbar = {'epoch_end_pbar_1': torch.tensor(234).type_as(prototype_loss)} + + return {'log': logs, 'progress_bar': pbar} + def validation_step_dict_return(self, batch, batch_idx): acc = self.step(batch, batch_idx) @@ -134,4 +142,4 @@ def __len__(self): return 12 def __getitem__(self, idx): - return np.array([0.5, 1.0, 2.0]) \ No newline at 
end of file + return np.array([0.5, 1.0, 2.0]) diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps.py index 8ff73df0eb11e..c8a1f8349d58d 100644 --- a/tests/trainer/test_trainer_steps.py +++ b/tests/trainer/test_trainer_steps.py @@ -72,15 +72,20 @@ def test_full_training_loop_dict(tmpdir): model = DeterministicModel() model.training_step = model.training_step_for_step_end_dict model.training_step_end = model.training_step_end_dict + model.training_epoch_end = model.training_epoch_end_dict model.val_dataloader = None - trainer = Trainer(fast_dev_run=True, weights_summary=None) + trainer = Trainer(max_epochs=1, weights_summary=None) trainer.fit(model) # make sure correct steps were called assert model.training_step_called assert model.training_step_end_called - assert not model.training_epoch_end_called + assert model.training_epoch_end_called + + # assert epoch end metrics were added + assert trainer.callback_metrics['epoch_end_log_1'] == 178 + assert trainer.progress_bar_metrics['epoch_end_pbar_1'] == 234 # make sure training outputs what is expected for batch_idx, batch in enumerate(model.train_dataloader()): @@ -94,4 +99,7 @@ def test_full_training_loop_dict(tmpdir): pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end'] assert pbar_metrics['pbar_acc1'] == 17.0 - assert pbar_metrics['pbar_acc2'] == 19.0 \ No newline at end of file + assert pbar_metrics['pbar_acc2'] == 19.0 + + +test_full_training_loop_dict('') From 697fc6b3bbbd7f047213306e8dd55cf8daca5e86 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 09:44:37 -0400 Subject: [PATCH 06/11] refactored training_bath --- tests/trainer/test_trainer_steps.py | 36 ++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps.py index c8a1f8349d58d..7e23324eed192 100644 --- a/tests/trainer/test_trainer_steps.py +++ b/tests/trainer/test_trainer_steps.py @@ -102,4 +102,38 @@ def test_full_training_loop_dict(tmpdir): assert pbar_metrics['pbar_acc2'] == 19.0 -test_full_training_loop_dict('') +def test_train_step_epoch_end(tmpdir): + """ + Checks train_step + training_epoch_end (NO training_step_end) + """ + model = DeterministicModel() + model.training_step = model.training_step_dict_return + model.training_step_end = None + model.training_epoch_end = model.training_epoch_end_dict + model.val_dataloader = None + + trainer = Trainer(max_epochs=1, weights_summary=None) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert model.training_epoch_end_called + + # assert epoch end metrics were added + assert trainer.callback_metrics['epoch_end_log_1'] == 178 + assert trainer.progress_bar_metrics['epoch_end_pbar_1'] == 234 + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.run_training_batch(batch, batch_idx) + signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out + assert signal == 0 + assert all_log_metrics['log_acc1'] == 12.0 + assert all_log_metrics['log_acc2'] == 7.0 + + pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end'] + assert pbar_metrics['pbar_acc1'] == 17.0 + assert pbar_metrics['pbar_acc2'] == 19.0 From a56d9f63093008e23f95e963be652c09c420e62a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 09:45:38 -0400 Subject: [PATCH 07/11] 
refactored training_bath --- pytorch_lightning/core/step_result.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index a4641a54d4a9f..15f9979ee50d1 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -25,7 +25,8 @@ def __init__(self, Args: minimize: Metric to minimize early_stop_on: Metric for early stopping. Ignored with a validation loop. - checkpoint_on: Metric for checkpointing. Ignored with a validation loop otherwise defaults to `minimize` value. + checkpoint_on: Metric for checkpointing. Ignored with a validation loop otherwise defaults + to `minimize` value. hiddens: tensor of hiddens to pass to next step when using TBPTT .. code-block: python @@ -56,10 +57,10 @@ def training_step(...): result.to_pbar(train_nce_loss', loss) result.to_pbar('train_nce_loss', loss, on_batch_end=True, on_epoch_end=False) - Although 99% of the time we are interested in a metric for each training batch, (ie: loss decrease over the epoch), - sometimes you may want to know something like the average loss for the full epoch. You can either - define the `training_epoch_end` method for something fancy, or use the `on_epoch_end` argument with your custom - reduce function + Although 99% of the time we are interested in a metric for each training batch, + (ie: loss decrease over the epoch), sometimes you may want to know something like the average loss + for the full epoch. You can either define the `training_epoch_end` method for something fancy, + or use the `on_epoch_end` argument with your custom reduce function .. code-block: python From 012dbe320d3c94fe622d5225df2ee91c33830f3e Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 09:46:35 -0400 Subject: [PATCH 08/11] refactored training_bath --- pytorch_lightning/core/step_result.py | 303 -------------------------- tests/base/deterministic_model.py | 3 +- 2 files changed, 1 insertion(+), 305 deletions(-) delete mode 100644 pytorch_lightning/core/step_result.py diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py deleted file mode 100644 index 15f9979ee50d1..0000000000000 --- a/pytorch_lightning/core/step_result.py +++ /dev/null @@ -1,303 +0,0 @@ -from typing import Optional, Dict -from torch import Tensor -import torch - - -class Result(Dict): - - def __init__(self, - minimize: Optional[Tensor] = None, - early_stop_on: Tensor = None, - checkpoint_on: Tensor = None, - hiddens: Optional[Tensor] = None): - """ - WIP! Split over many PRs... DO NOT USE YET - - TrainResult is an OrderedDict that gives type hints, allowed fields and validation for bad user input. - - Use as the return value for: - - training_step - - .. note:: Plain dictionary returns are supported but are more prone to errors - - We automatically detach anything here for you to avoid holding references to graphs - - Args: - minimize: Metric to minimize - early_stop_on: Metric for early stopping. Ignored with a validation loop. - checkpoint_on: Metric for checkpointing. Ignored with a validation loop otherwise defaults - to `minimize` value. - hiddens: tensor of hiddens to pass to next step when using TBPTT - - .. 
code-block: python - - # all options: - def training_step(...): - return TrainResult( - minimize=loss, - checkpoint_on=loss, - ) - - # equivalent - return TrainResult(loss) - - # if you have no validation loop, you can still early_stop and/or checkpoint on a metric - # only checkpointing is applied by default here - return TrainResult(loss, early_stop_on=accuracy, checkpoint_on=bleu_score) - - result = TrainResult(loss) - - # logging will log to your logger(s) at the end of the batch - result.log('train_nce_loss', loss) - - # you can log at the end of the batch, or epoch or both - result.log('train_nce_loss', loss, on_batch_end=True, on_epoch_end=False) - - # same thing for the progress bar - result.to_pbar(train_nce_loss', loss) - result.to_pbar('train_nce_loss', loss, on_batch_end=True, on_epoch_end=False) - - Although 99% of the time we are interested in a metric for each training batch, - (ie: loss decrease over the epoch), sometimes you may want to know something like the average loss - for the full epoch. You can either define the `training_epoch_end` method for something fancy, - or use the `on_epoch_end` argument with your custom reduce function - - .. code-block: python - - # maybe sum `log_probs` across all the training batches - result.log('log_probs', log_probs, reduce_fx=torch.sum) - - # or do something weird to `log_probs` across all the training batches - def my_weird_reduction(all_log_probs): - all_log_probs = F.softmax(torch.cat(all_log_probs), dim=1) - return all_log_probs - - result.log('log_probs', log_probs, reduce_fx=my_weird_reduction) - """ - - super().__init__() - - self.early_stop_on = early_stop_on - self.checkpoint_on = checkpoint_on - - # TODO: should hiddens detach? - self.hiddens = hiddens - self.minimize = minimize - - @classmethod - def union(cls, outputs, result=None): - if result is None: - result = Result() - - for out in outputs: - for k, v in out.items(): - if k in ['reduce_fx_on_epoch_end']: - continue - - if k not in result and isinstance(v, (dict, Result)): - result[k] = Result() - - if isinstance(v, dict): - v = cls.union([v], result[k]) - - if isinstance(v, list) and len(v) == 1: - v = v[0] - result[k] = v - - return result - - @classmethod - def from_result_dict(cls, dict_result, trainer): - result = Result() - - if 'log' in dict_result: - result.log_metrics(dict_result['log']) - if 'progress_bar' in dict_result: - result.pbar_metrics(dict_result['progress_bar']) - - # add the early stop metric - if trainer.early_stop_callback is not None: - early_stop_metric = trainer.early_stop_callback.monitor - if early_stop_metric in dict_result: - result.early_stop_on = dict_result[early_stop_metric] - - # add the checkpoint metric - if trainer.checkpoint_callback is not None: - checkpoint_metric = trainer.checkpoint_callback.monitor - if checkpoint_metric in dict_result: - result.checkpoint_on = dict_result[checkpoint_metric] - - return result - - def __reduce_on_callback(self, callback_name, name, metric, log, pbar, reduce_fx): - assert isinstance(metric, torch.Tensor), f'{name} must be a torch.Tensor' - - keys = [f'reduce_{callback_name}'] - if log: - keys.append(f'log_{callback_name}') - if pbar: - keys.append(f'pbar_{callback_name}') - - for key in keys: - if key not in self: - self[key] = {} - - if 'log' in key or 'pbar' in key: - metric = metric.detach() - - metrics = self[key] - metrics[name] = metric - - key = f'reduce_fx_{callback_name}' - if key not in self: - self[key] = {} - - metrics = self[key] - metrics[name] = reduce_fx - - def 
pbar_metric(self, name: str, value: Tensor, on_batch_end=False, on_epoch_end=True, reduce_fx=torch.mean): - if on_batch_end: - self.__reduce_on_callback('on_batch_end', name, value, log=False, pbar=True, reduce_fx=reduce_fx) - if on_epoch_end: - self.__reduce_on_callback('on_epoch_end', name, value, log=False, pbar=True, reduce_fx=reduce_fx) - - def pbar_metrics(self, values: dict, on_batch_end=False, on_epoch_end=True, reduce_fx=torch.mean): - for name, value in values.items(): - if on_batch_end: - self.__reduce_on_callback('on_batch_end', name, value, log=False, pbar=True, reduce_fx=reduce_fx) - if on_epoch_end: - self.__reduce_on_callback('on_epoch_end', name, value, log=False, pbar=True, reduce_fx=reduce_fx) - - def log_metric(self, name: str, value: Tensor, on_batch_end=False, on_epoch_end=True, reduce_fx=torch.mean): - if on_batch_end: - self.__reduce_on_callback('on_batch_end', name, value, log=True, pbar=False, reduce_fx=reduce_fx) - if on_epoch_end: - self.__reduce_on_callback('on_epoch_end', name, value, log=True, pbar=False, reduce_fx=reduce_fx) - - def log_metrics(self, values: dict, on_batch_end=False, on_epoch_end=True, reduce_fx=torch.mean): - for name, value in values.items(): - if on_batch_end: - self.__reduce_on_callback('on_batch_end', name, value, log=True, pbar=False, reduce_fx=reduce_fx) - if on_epoch_end: - self.__reduce_on_callback('on_epoch_end', name, value, log=True, pbar=False, reduce_fx=reduce_fx) - - @property - def log_on_batch_end(self): - return self.__getitem__('log_on_batch_end') - - @log_on_batch_end.setter - def log_on_batch_end(self, x): - if x is not None: - assert isinstance(x, dict), 'log_on_batch_end must be a dict' - self.__setitem__('log_on_batch_end', x) - - @property - def pbar_on_batch_end(self): - return self.__getitem__('pbar_on_batch_end') - - @pbar_on_batch_end.setter - def pbar_on_batch_end(self, x): - if x is not None: - assert isinstance(x, dict), 'pbar_on_batch_end must be a dict' - self.__setitem__('pbar_on_batch_end', x) - - @property - def log_on_epoch_end(self): - return self.__getitem__('log_on_epoch_end') - - @log_on_epoch_end.setter - def log_on_epoch_end(self, x): - if x is not None: - assert isinstance(x, dict), 'log_on_epoch_end must be a dict' - self.__setitem__('log_on_epoch_end', x) - - @property - def pbar_on_epoch_end(self): - return self.__getitem__('pbar_on_epoch_end') - - @pbar_on_epoch_end.setter - def pbar_on_epoch_end(self, x): - if x is not None: - assert isinstance(x, dict), 'pbar_on_epoch_end must be a dict' - self.__setitem__('pbar_on_epoch_end', x) - - @property - def progress_bar(self): - return self.__getitem__('progress_bar') - - @progress_bar.setter - def progress_bar(self, x): - if x is not None: - assert isinstance(x, dict), 'progress_bar_logs must be a dict' - self.__setitem__('progress_bar', x) - - @property - def logs(self): - return self.__getitem__('logs') - - @logs.setter - def logs(self, x): - if x is not None: - assert isinstance(x, dict), 'logs must be a dict' - self.__setitem__('logs', x) - - @property - def hiddens(self): - return self._hiddens - - @hiddens.setter - def hiddens(self, x): - if x is not None: - assert isinstance(x, Tensor), 'hiddens must be a torch.Tensor' - self._hiddens = x - self.__setitem__('hiddens', x) - - @property - def checkpoint_on(self): - # use minimize as default if no checkpoint_on is passed - if 'checkpoint_on' not in self: - minimize = self.__getitem__('minimize') - self.__setitem__('checkpoint_on', minimize) - - return self.__getitem__('checkpoint_on') - - 
@checkpoint_on.setter - def checkpoint_on(self, x): - if x is not None: - assert isinstance(x, Tensor), 'checkpoint_on must be a torch.Tensor' - self.__setitem__('checkpoint_on', x.detach()) - - @property - def early_stop_on(self): - # use minimize as default if no checkpoint_on is passed - if 'early_stop_on' not in self: - minimize = self.__getitem__('minimize') - self.__setitem__('early_stop_on', minimize) - - return self.__getitem__('early_stop_on') - - @early_stop_on.setter - def early_stop_on(self, x): - if x is not None: - assert isinstance(x, Tensor), 'early_stop_on must be a torch.Tensor' - self.__setitem__('early_stop_on', x.detach()) - - @property - def minimize(self): - return self.__getitem__('minimize') - - @minimize.setter - def minimize(self, x): - if x is not None: - assert isinstance(x, Tensor), 'metric to minimize must be a torch.Tensor' - m = 'the metric to minimize must have a computational graph. Minimize ' \ - 'can only be used in training_end, training_step_end, training_epoch_end' - assert x.grad_fn is not None, m - self.__setitem__('minimize', x) - - -if __name__ == '__main__': - import torch - result = Result() - result.log_metrics({'a': 2}) - result.minimize = torch.tensor(1) diff --git a/tests/base/deterministic_model.py b/tests/base/deterministic_model.py index dc4972ab2b224..1ca318ef1fac8 100644 --- a/tests/base/deterministic_model.py +++ b/tests/base/deterministic_model.py @@ -1,6 +1,5 @@ import torch from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.core.step_result import Result from torch.utils.data import Dataset, DataLoader import numpy as np @@ -39,7 +38,7 @@ def assert_graph_count(self, result, count=1): counts = self.count_num_graphs(result) assert counts == count - def count_num_graphs(self, result: Result, num_graphs=0): + def count_num_graphs(self, result, num_graphs=0): for k, v in result.items(): if isinstance(v, torch.Tensor) and v.grad_fn is not None: num_graphs += 1 From ba602d03ed3636173c33529e5dc93b72decf87c4 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 09:54:52 -0400 Subject: [PATCH 09/11] refactored training_bath --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index fd9e7f365a611..27c4b6c19c4ac 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -612,7 +612,7 @@ def run_training_batch(self, batch, batch_idx): batch_log_metrics.append(opt_closure_result.training_step_output.log_metrics) self.add_progress_bar_metrics(opt_closure_result.training_step_output.pbar_on_batch_end) - # loss, training_step_output, training_step_output_for_epoch_end, hiddens + # track hiddens self.hiddens = opt_closure_result.hiddens # check if loss or model weights are nan From de98ce6b291d2128c0f8e7eb909d260ac5e68117 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 10:29:26 -0400 Subject: [PATCH 10/11] refactored training_bath --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 27c4b6c19c4ac..7408f6cbc1518 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -683,7 +683,7 @@ def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer): opt_idx, optimizer, 
self.hiddens - )[0] + ).loss model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda: lambda_closure) From a64fdfab92aa01bc48d083ac03a0746dd70d1c26 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 10:51:31 -0400 Subject: [PATCH 11/11] refactored training_bath --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 7408f6cbc1518..80dde617a2669 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -686,7 +686,7 @@ def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer): ).loss model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, - lambda: lambda_closure) + lambda_closure) return grad_norm_dic
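
For context, a minimal sketch of the pattern this series extracts into optimizer_closure() and run_batch_backward_pass(): the forward and backward passes are wrapped in a closure, anything kept for logging is detached so no graph is retained, and the closure itself is handed to optimizer.step() so second-order optimizers such as LBFGS can re-evaluate the loss as often as they need. This is plain PyTorch with illustrative names (run_training_batch_sketch, the MSE toy model); it is a simplified illustration, not part of the diffs above, and it folds zero_grad into the closure rather than into a model.optimizer_step hook.

.. code-block:: python

    import torch
    import torch.nn.functional as F

    def run_training_batch_sketch(model, optimizer, batch):
        """Simplified stand-in for run_training_batch / optimizer_closure:
        wrap forward + backward in a closure, keep only detached bookkeeping,
        and let the optimizer drive the closure."""
        bookkeeping = {}

        def optimizer_closure():
            optimizer.zero_grad()
            x, y = batch
            loss = F.mse_loss(model(x), y)
            loss.backward()
            # detach anything kept for logging so the graph is released
            bookkeeping['loss'] = loss.detach()
            bookkeeping['log_metrics'] = {'train_loss': loss.item()}
            return loss

        # passing the closure (not its result) is what allows second-order
        # optimizers such as LBFGS to re-evaluate the loss internally
        optimizer.step(optimizer_closure)
        return bookkeeping

    model = torch.nn.Linear(3, 1)
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
    batch = (torch.randn(8, 3), torch.randn(8, 1))
    print(run_training_batch_sketch(model, optimizer, batch)['log_metrics'])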