Skip to content

Commit

Permalink
remove .item which causes sync issues (Lightning-AI#1254)
Browse files Browse the repository at this point in the history
* remove .item which causes sync issues

* fixed gradient acc sched

* fixed gradient acc sched
  • Loading branch information
williamFalcon authored and akarnachev committed Apr 3, 2020
1 parent c195cc2 commit 48ff0eb
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions pytorch_lightning/trainer/supporting_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch


class TensorRunningMean(object):
"""
Tracks a running mean without graph references.
Round robbin for the mean
"""
def __init__(self, window_length):
self.window_length = window_length
self.reset()
self.last_idx = 0

def reset(self):
self.memory = torch.Tensor(self.window_length)
self.current_idx = 0

def last(self):
return self.memory[self.last_idx]

def append(self, x):
# map proper type for memory if they don't match
if self.memory.type() != x.type():
self.memory.type_as(x)

# store without grads
with torch.no_grad():
self.memory[self.current_idx] = x
self.last_idx = self.current_idx

# increase index
self.current_idx += 1

# reset index when hit limit of tensor
if self.current_idx >= self.window_length:
self.current_idx = 0

def mean(self):
return self.memory.mean()

0 comments on commit 48ff0eb

Please sign in to comment.