class TensorRunningMean(object):
    """
    Tracks a running mean over the most recent ``window_length`` values
    without keeping autograd graph references.

    Values are written round-robin into a fixed-size buffer: once the window
    is full, each new value overwrites the oldest one. Only the slots that
    have actually been written since the last ``reset()`` contribute to
    ``mean()``.

    Args:
        window_length: maximum number of values kept in the window.

    Attributes:
        window_length: size of the buffer.
        memory: 1-D tensor of length ``window_length`` holding the values.
        current_idx: slot the next ``append`` will write to.
        last_idx: slot written by the most recent ``append``.
    """

    def __init__(self, window_length):
        self.window_length = window_length
        self.reset()

    def reset(self):
        """Discard every stored value and start a new, empty window."""
        previous = getattr(self, 'memory', None)
        if previous is None:
            # zero-filled rather than uninitialised, so the buffer content is
            # always deterministic
            self.memory = torch.zeros(self.window_length)
        else:
            # keep the dtype/device the buffer already migrated to, so that
            # the next append does not have to convert it again
            self.memory = torch.zeros(
                self.window_length,
                dtype=previous.dtype,
                device=previous.device,
            )
        self.current_idx = 0
        self.last_idx = 0
        # number of valid slots, capped at window_length
        self._num_stored = 0

    def _empty_value(self):
        """Return the scalar NaN tensor used when nothing has been stored."""
        return self.memory.new_full((), float('nan'))

    def last(self):
        """
        Return the most recently appended value as a 0-dim tensor.

        A NaN scalar tensor is returned when nothing has been appended since
        construction or the last ``reset()``.
        """
        if self._num_stored == 0:
            return self._empty_value()
        return self.memory[self.last_idx]

    def append(self, x):
        """
        Store ``x`` in the window, overwriting the oldest value when full.

        Args:
            x: scalar tensor (or plain Python number) to record. The value is
                copied under ``torch.no_grad()``, so no graph is retained.
        """
        if not torch.is_tensor(x):
            # plain numbers are converted to the buffer's own dtype/device
            x = torch.as_tensor(
                x, dtype=self.memory.dtype, device=self.memory.device
            )
        elif x.is_floating_point() and self.memory.type() != x.type():
            # type_as returns a NEW tensor; it must be assigned back.
            # Non-float inputs are left to the cast done by index assignment,
            # since an integer buffer could not be averaged.
            self.memory = self.memory.type_as(x)

        # store without grads
        with torch.no_grad():
            self.memory[self.current_idx] = x
        self.last_idx = self.current_idx

        # advance round-robin and remember how many slots are valid
        self.current_idx = (self.current_idx + 1) % self.window_length
        self._num_stored = min(self._num_stored + 1, self.window_length)

    def mean(self):
        """
        Return the mean of the stored values as a 0-dim tensor.

        Only slots written since the last ``reset()`` are averaged. A NaN
        scalar tensor is returned when the window is empty.
        """
        if self._num_stored == 0:
            return self._empty_value()
        if self._num_stored < self.window_length:
            # before the first wrap-around the valid slots are the leading ones
            return self.memory[:self._num_stored].mean()
        return self.memory.mean()
from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.trainer.supporting_classes import TensorRunningMean try: from apex import amp @@ -324,7 +325,14 @@ def train(self): # total batches includes multiple val checks self.total_batches = self.num_training_batches + total_val_batches - self.batch_loss_value = 0 # accumulated grads + + # changing gradient according accumulation_scheduler + self.accumulation_scheduler.on_epoch_start(self, self.get_model()) + + # stores accumulated grad fractions per batch + self.batch_loss_value = TensorRunningMean( + window_length=self.accumulate_grad_batches + ) if self.fast_dev_run: # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run @@ -380,8 +388,7 @@ def run_training_epoch(self): with self.profiler.profile('on_epoch_start'): # callbacks self.on_epoch_start() - # changing gradient according accumulation_scheduler - self.accumulation_scheduler.on_epoch_start(self, self.get_model()) + # model hooks if self.is_function_implemented('on_epoch_start'): self.get_model().on_epoch_start() @@ -572,7 +579,7 @@ def optimizer_closure(): self.detect_nan_tensors(loss) # track total loss for logging (avoid mem leaks) - self.batch_loss_value += loss.item() + self.batch_loss_value.append(loss) # gradient update with accumulated gradients if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: @@ -595,9 +602,10 @@ def optimizer_closure(): optimizer, opt_idx, optimizer_closure) # calculate running loss for display - self.running_loss.append(self.batch_loss_value) - self.batch_loss_value = 0 - self.avg_loss = np.mean(self.running_loss[-100:]) + self.running_loss.append(self.batch_loss_value.mean()) + + # reset for next set of accumulated grads + self.batch_loss_value.reset() # Batch end events with self.profiler.profile('on_batch_end'):