From 882265874520981bc704a3540a3cad0b7e809d44 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 26 Mar 2020 19:50:23 -0400
Subject: [PATCH 1/3] remove .item which causes sync issues

---
 pytorch_lightning/core/lightning.py        |  4 ++-
 .../trainer/supporting_classes.py          | 34 +++++++++++++++++++
 pytorch_lightning/trainer/trainer.py       |  4 +--
 pytorch_lightning/trainer/training_loop.py | 16 ++++++---
 4 files changed, 50 insertions(+), 8 deletions(-)
 create mode 100644 pytorch_lightning/trainer/supporting_classes.py

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 2a67d32748c42..d9decbf266187 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1520,8 +1520,10 @@ def get_tqdm_dict(self) -> Dict[str, Union[int, str]]:
         Return:
             Dictionary with the items to be displayed in the progress bar.
         """
+        # call .item() only once but store elements without graphs
+        running_training_loss = self.trainer.running_loss.mean().cpu().item()
         tqdm_dict = {
-            'loss': '{:.3f}'.format(self.trainer.avg_loss)
+            'loss': '{:.3f}'.format(running_training_loss)
         }
 
         if self.trainer.truncated_bptt_steps is not None:
diff --git a/pytorch_lightning/trainer/supporting_classes.py b/pytorch_lightning/trainer/supporting_classes.py
new file mode 100644
index 0000000000000..e5b811e5c2d07
--- /dev/null
+++ b/pytorch_lightning/trainer/supporting_classes.py
@@ -0,0 +1,34 @@
+import torch
+
+
+class TensorRunningMean(object):
+    """
+    Tracks a running mean without graph references.
+    Round robbin for the mean
+    """
+    def __init__(self, window_length):
+        self.window_length = window_length
+        self.reset()
+
+    def reset(self):
+        self.memory = torch.Tensor(self.window_length)
+        self.current_idx = 0
+
+    def append(self, x):
+        # map proper type for memory if they don't match
+        if self.memory.type() != x.type():
+            self.memory.type_as(x)
+
+        # store without grads
+        with torch.no_grad():
+            self.memory[self.current_idx] = x
+
+        # increase index
+        self.current_idx += 1
+
+        # reset index when hit limit of tensor
+        if self.current_idx >= self.window_length:
+            self.current_idx = 0
+
+    def mean(self):
+        return self.memory.mean()
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 881a2e9103301..3d31ef82c7be9 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -34,6 +34,7 @@
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.trainer.supporting_classes import TensorRunningMean
 
 try:
     from apex import amp
@@ -340,8 +341,7 @@ def __init__(
 
         # training bookeeping
         self.total_batch_idx = 0
-        self.running_loss = []
-        self.avg_loss = 0
+        self.running_loss = TensorRunningMean(window_length=20)
         self.batch_idx = 0
         self.tqdm_metrics = {}
         self.callback_metrics = {}
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index a90a43b0d8beb..a0da0f2f398fc 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -146,6 +146,7 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.trainer.supporting_classes import TensorRunningMean
 
 try:
     from apex import amp
@@ -324,7 +325,11 @@ def train(self):
 
             # total batches includes multiple val checks
             self.total_batches = self.num_training_batches + total_val_batches
-            self.batch_loss_value = 0  # accumulated grads
+
+            # stores accumulated grad fractions per batch
+            self.batch_loss_value = TensorRunningMean(
+                window_length=self.accumulate_grad_batches
+            )
 
             if self.fast_dev_run:
                 # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
@@ -570,7 +575,7 @@ def optimizer_closure():
                     self.detect_nan_tensors(loss)
 
                 # track total loss for logging (avoid mem leaks)
-                self.batch_loss_value += loss.item()
+                self.batch_loss_value.append(loss)
 
                 # gradient update with accumulated gradients
                 if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
@@ -593,9 +598,10 @@ def optimizer_closure():
                                          optimizer, opt_idx, optimizer_closure)
 
                     # calculate running loss for display
-                    self.running_loss.append(self.batch_loss_value)
-                    self.batch_loss_value = 0
-                    self.avg_loss = np.mean(self.running_loss[-100:])
+                    self.running_loss.append(self.batch_loss_value.mean())
+
+                    # reset for next set of accumulated grads
+                    self.batch_loss_value.reset()
 
         # Batch end events
         with self.profiler.profile('on_batch_end'):

From 97ba22d8ab12cdd6d8cd3767a03e441af3aee9a3 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sun, 29 Mar 2020 15:47:22 -0400
Subject: [PATCH 2/3] fixed gradient acc sched

---
 pytorch_lightning/trainer/training_loop.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index a0da0f2f398fc..9153dc29c3358 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -326,6 +326,9 @@ def train(self):
             # total batches includes multiple val checks
             self.total_batches = self.num_training_batches + total_val_batches
 
+            # changing gradient according accumulation_scheduler
+            self.accumulation_scheduler.on_epoch_start(self, self.get_model())
+
             # stores accumulated grad fractions per batch
             self.batch_loss_value = TensorRunningMean(
                 window_length=self.accumulate_grad_batches
@@ -385,8 +388,7 @@ def run_training_epoch(self):
         with self.profiler.profile('on_epoch_start'):
             # callbacks
             self.on_epoch_start()
-            # changing gradient according accumulation_scheduler
-            self.accumulation_scheduler.on_epoch_start(self, self.get_model())
+
             # model hooks
             if self.is_function_implemented('on_epoch_start'):
                 self.get_model().on_epoch_start()

From 747b0e584cfd402393b92e9954a70d5ddc061662 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sun, 29 Mar 2020 16:06:14 -0400
Subject: [PATCH 3/3] fixed gradient acc sched

---
 pytorch_lightning/trainer/supporting_classes.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/pytorch_lightning/trainer/supporting_classes.py b/pytorch_lightning/trainer/supporting_classes.py
index e5b811e5c2d07..7f2b0824a63a6 100644
--- a/pytorch_lightning/trainer/supporting_classes.py
+++ b/pytorch_lightning/trainer/supporting_classes.py
@@ -9,11 +9,15 @@ class TensorRunningMean(object):
     def __init__(self, window_length):
         self.window_length = window_length
         self.reset()
+        self.last_idx = 0
 
     def reset(self):
         self.memory = torch.Tensor(self.window_length)
         self.current_idx = 0
 
+    def last(self):
+        return self.memory[self.last_idx]
+
     def append(self, x):
         # map proper type for memory if they don't match
         if self.memory.type() != x.type():
@@ -22,6 +26,7 @@ def append(self, x):
         # store without grads
         with torch.no_grad():
             self.memory[self.current_idx] = x
+            self.last_idx = self.current_idx
 
         # increase index
         self.current_idx += 1
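
A minimal standalone sketch of how the TensorRunningMean buffer introduced by this series is meant to be used: losses are appended as tensors (no per-step .item() call, so the training loop no longer forces a host-device sync every batch), and .item() is taken once only when a value is needed for display, as get_tqdm_dict now does. The class body mirrors the new supporting_classes.py, except that the buffer is zero-initialized for determinism and the type_as result is reassigned (type_as is not in-place); the dummy losses and the 5-step loop are made up for illustration, while window_length=20 matches the trainer default.

import torch


class TensorRunningMean:
    """Round-robin buffer tracking a running mean without autograd graph references."""

    def __init__(self, window_length):
        self.window_length = window_length
        self.reset()

    def reset(self):
        # zero-init for a deterministic sketch; the patch allocates an uninitialized torch.Tensor
        self.memory = torch.zeros(self.window_length)
        self.current_idx = 0

    def append(self, x):
        # keep the buffer dtype/device in sync with the incoming tensor (reassignment is required)
        if self.memory.type() != x.type():
            self.memory = self.memory.type_as(x)

        # store the value detached from the autograd graph
        with torch.no_grad():
            self.memory[self.current_idx] = x

        # advance and wrap the write index (round robin)
        self.current_idx = (self.current_idx + 1) % self.window_length

    def mean(self):
        # averages the whole window, including slots not yet written, as in the patch
        return self.memory.mean()


# track a running training loss without holding graph references
running_loss = TensorRunningMean(window_length=20)
for step in range(5):
    loss = (torch.randn(8, requires_grad=True) ** 2).mean()  # stand-in for a training loss
    running_loss.append(loss)                                 # tensor in, no .item() per step
    print('step {}: running loss {:.3f}'.format(step, running_loss.mean().item()))  # .item() only for display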