Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove .item which causes sync issues #1254

Merged
merged 3 commits into from
Mar 30, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1520,8 +1520,10 @@ def get_tqdm_dict(self) -> Dict[str, Union[int, str]]:
Return:
Dictionary with the items to be displayed in the progress bar.
"""
# call .item() only once but store elements without graphs
running_training_loss = self.trainer.running_loss.mean().cpu().item()
tqdm_dict = {
'loss': '{:.3f}'.format(self.trainer.avg_loss)
'loss': '{:.3f}'.format(running_training_loss)
}

if self.trainer.truncated_bptt_steps is not None:
Expand Down
34 changes: 34 additions & 0 deletions pytorch_lightning/trainer/supporting_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import torch
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about creating a metrics package?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not relevant to this



class TensorRunningMean(object):
"""
Tracks a running mean without graph references.
Round robbin for the mean
"""
def __init__(self, window_length):
self.window_length = window_length
self.reset()

def reset(self):
self.memory = torch.Tensor(self.window_length)
self.current_idx = 0

def append(self, x):
# map proper type for memory if they don't match
if self.memory.type() != x.type():
self.memory.type_as(x)

# store without grads
with torch.no_grad():
self.memory[self.current_idx] = x

# increase index
self.current_idx += 1

# reset index when hit limit of tensor
if self.current_idx >= self.window_length:
self.current_idx = 0

def mean(self):
return self.memory.mean()
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.trainer.supporting_classes import TensorRunningMean

try:
from apex import amp
Expand Down Expand Up @@ -340,8 +341,7 @@ def __init__(

# training bookeeping
self.total_batch_idx = 0
self.running_loss = []
self.avg_loss = 0
self.running_loss = TensorRunningMean(window_length=20)
self.batch_idx = 0
self.tqdm_metrics = {}
self.callback_metrics = {}
Expand Down
22 changes: 15 additions & 7 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.trainer.supporting_classes import TensorRunningMean

try:
from apex import amp
Expand Down Expand Up @@ -324,7 +325,14 @@ def train(self):

# total batches includes multiple val checks
self.total_batches = self.num_training_batches + total_val_batches
self.batch_loss_value = 0 # accumulated grads

# changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_start(self, self.get_model())

# stores accumulated grad fractions per batch
self.batch_loss_value = TensorRunningMean(
window_length=self.accumulate_grad_batches
)

if self.fast_dev_run:
# limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
Expand Down Expand Up @@ -380,8 +388,7 @@ def run_training_epoch(self):
with self.profiler.profile('on_epoch_start'):
# callbacks
self.on_epoch_start()
# changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_start(self, self.get_model())

# model hooks
if self.is_function_implemented('on_epoch_start'):
self.get_model().on_epoch_start()
Expand Down Expand Up @@ -570,7 +577,7 @@ def optimizer_closure():
self.detect_nan_tensors(loss)

# track total loss for logging (avoid mem leaks)
self.batch_loss_value += loss.item()
self.batch_loss_value.append(loss)

# gradient update with accumulated gradients
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
Expand All @@ -593,9 +600,10 @@ def optimizer_closure():
optimizer, opt_idx, optimizer_closure)

# calculate running loss for display
self.running_loss.append(self.batch_loss_value)
self.batch_loss_value = 0
self.avg_loss = np.mean(self.running_loss[-100:])
self.running_loss.append(self.batch_loss_value.mean())

# reset for next set of accumulated grads
self.batch_loss_value.reset()

# Batch end events
with self.profiler.profile('on_batch_end'):
Expand Down