remove .item which causes sync issues (#1254)
* remove .item which causes sync issues

* fixed gradient acc sched

williamFalcon committed Mar 30, 2020
1 parent b74a3c5 commit 31b7148
Showing 4 changed files with 59 additions and 10 deletions.
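Why .item() causes sync issues: calling .item() on a CUDA tensor copies the scalar to the host and blocks until the GPU work that produced it has finished, so doing it on every training step (as the old avg_loss bookkeeping did) stalls the pipeline. Below is a minimal, hedged sketch of the pattern this commit moves to, keeping detached losses in a fixed-size buffer and paying a single .item() only when the progress bar is refreshed; the buffer and helper names are illustrative, not part of Lightning's API.

import torch

# rolling window of detached loss values (moved to the loss's device on first use)
window = torch.zeros(20)
idx = 0

def track(loss: torch.Tensor) -> None:
    """Store the latest loss without graph references and without a per-step sync."""
    global window, idx
    if window.device != loss.device or window.dtype != loss.dtype:
        window = window.to(loss)      # keep the buffer on the loss's device/dtype
    with torch.no_grad():
        window[idx] = loss            # no .item() here, so no forced GPU sync
    idx = (idx + 1) % window.numel()

def progress_bar_loss() -> float:
    # a single .item() per progress-bar refresh instead of one per training step
    return window.mean().cpu().item()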
pytorch_lightning/core/lightning.py (3 additions, 1 deletion)
@@ -1524,8 +1524,10 @@ def get_tqdm_dict(self) -> Dict[str, Union[int, str]]:
         Return:
             Dictionary with the items to be displayed in the progress bar.
         """
+        # call .item() only once but store elements without graphs
+        running_training_loss = self.trainer.running_loss.mean().cpu().item()
         tqdm_dict = {
-            'loss': '{:.3f}'.format(self.trainer.avg_loss)
+            'loss': '{:.3f}'.format(running_training_loss)
         }

         if self.trainer.truncated_bptt_steps is not None:
pytorch_lightning/trainer/supporting_classes.py (new file, 39 additions)
@@ -0,0 +1,39 @@
+import torch
+
+
+class TensorRunningMean(object):
+    """
+    Tracks a running mean without graph references.
+    Round robin for the mean.
+    """
+    def __init__(self, window_length):
+        self.window_length = window_length
+        self.reset()
+        self.last_idx = 0
+
+    def reset(self):
+        self.memory = torch.Tensor(self.window_length)
+        self.current_idx = 0
+
+    def last(self):
+        return self.memory[self.last_idx]
+
+    def append(self, x):
+        # cast memory to the incoming tensor's type if they don't match
+        if self.memory.type() != x.type():
+            self.memory = self.memory.type_as(x)
+
+        # store without grads
+        with torch.no_grad():
+            self.memory[self.current_idx] = x
+            self.last_idx = self.current_idx
+
+        # increase index
+        self.current_idx += 1
+
+        # reset index when hit limit of tensor
+        if self.current_idx >= self.window_length:
+            self.current_idx = 0
+
+    def mean(self):
+        return self.memory.mean()
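A short usage sketch of the class above, outside the Trainer; the values are arbitrary and the snippet assumes TensorRunningMean is importable from pytorch_lightning.trainer.supporting_classes:

import torch
from pytorch_lightning.trainer.supporting_classes import TensorRunningMean

running = TensorRunningMean(window_length=3)

# append detached scalar tensors; the window is overwritten round-robin
for value in [0.9, 0.7, 0.5, 0.3]:
    running.append(torch.tensor(value))

print(running.last())   # tensor(0.3000): the most recently appended value
print(running.mean())   # tensor(0.5000): mean of the window [0.3, 0.7, 0.5]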
pytorch_lightning/trainer/trainer.py (2 additions, 2 deletions)
@@ -34,6 +34,7 @@
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.trainer.supporting_classes import TensorRunningMean

 try:
     from apex import amp
@@ -340,8 +341,7 @@ def __init__(

         # training bookkeeping
         self.total_batch_idx = 0
-        self.running_loss = []
-        self.avg_loss = 0
+        self.running_loss = TensorRunningMean(window_length=20)
         self.batch_idx = 0
         self.tqdm_metrics = {}
         self.callback_metrics = {}
pytorch_lightning/trainer/training_loop.py (15 additions, 7 deletions)
@@ -146,6 +146,7 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.trainer.supporting_classes import TensorRunningMean

 try:
     from apex import amp
@@ -324,7 +325,14 @@ def train(self):

         # total batches includes multiple val checks
         self.total_batches = self.num_training_batches + total_val_batches
-        self.batch_loss_value = 0  # accumulated grads
+
+        # change gradient accumulation according to the accumulation_scheduler
+        self.accumulation_scheduler.on_epoch_start(self, self.get_model())
+
+        # stores accumulated grad fractions per batch
+        self.batch_loss_value = TensorRunningMean(
+            window_length=self.accumulate_grad_batches
+        )

         if self.fast_dev_run:
             # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
@@ -380,8 +388,7 @@ def run_training_epoch(self):
         with self.profiler.profile('on_epoch_start'):
             # callbacks
             self.on_epoch_start()
-            # changing gradient according accumulation_scheduler
-            self.accumulation_scheduler.on_epoch_start(self, self.get_model())
+
             # model hooks
             if self.is_function_implemented('on_epoch_start'):
                 self.get_model().on_epoch_start()
@@ -572,7 +579,7 @@ def optimizer_closure():
                 self.detect_nan_tensors(loss)

                 # track total loss for logging (avoid mem leaks)
-                self.batch_loss_value += loss.item()
+                self.batch_loss_value.append(loss)

                 # gradient update with accumulated gradients
                 if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
@@ -595,9 +602,10 @@
                                          optimizer, opt_idx, optimizer_closure)

                     # calculate running loss for display
-                    self.running_loss.append(self.batch_loss_value)
-                    self.batch_loss_value = 0
-                    self.avg_loss = np.mean(self.running_loss[-100:])
+                    self.running_loss.append(self.batch_loss_value.mean())
+
+                    # reset for next set of accumulated grads
+                    self.batch_loss_value.reset()

         # Batch end events
         with self.profiler.profile('on_batch_end'):
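Taken together, the training_loop.py changes set up a two-level bookkeeping scheme: every step's detached loss goes into batch_loss_value, and only at an accumulation boundary does its mean get pushed into running_loss, which the progress bar later reads. A simplified, hedged sketch of that flow (the loop structure and the toy compute_loss below are illustrative, not the actual Trainer code; it assumes the TensorRunningMean class from supporting_classes.py above is importable):

import torch
from pytorch_lightning.trainer.supporting_classes import TensorRunningMean

def compute_loss(batch_idx):
    # stand-in for a real training step's loss tensor
    return torch.tensor(1.0 / (batch_idx + 1))

accumulate_grad_batches = 4
batch_loss_value = TensorRunningMean(window_length=accumulate_grad_batches)
running_loss = TensorRunningMean(window_length=20)

for batch_idx in range(100):
    loss = compute_loss(batch_idx)
    batch_loss_value.append(loss)          # stored without graph refs, no .item()

    if (batch_idx + 1) % accumulate_grad_batches == 0:
        # optimizer.step() / zero_grad() would happen here
        running_loss.append(batch_loss_value.mean())
        batch_loss_value.reset()           # start a fresh set of accumulated grads

# the progress bar converts to a Python float only at display time
print(running_loss.mean().cpu().item())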
