From c586217031a71d30ce0aed5fae1675002ff64473 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sun, 29 Mar 2020 20:20:34 -0400
Subject: [PATCH] remove .item which causes sync issues (#1254)

* remove .item which causes sync issues

* fixed gradient acc sched

* fixed gradient acc sched
---
 .../trainer/supporting_classes.py | 44 ++++++++++++++++++++
 1 file changed, 44 insertions(+)
 create mode 100644 pytorch_lightning/trainer/supporting_classes.py

diff --git a/pytorch_lightning/trainer/supporting_classes.py b/pytorch_lightning/trainer/supporting_classes.py
new file mode 100644
index 0000000000000..7f2b0824a63a6
--- /dev/null
+++ b/pytorch_lightning/trainer/supporting_classes.py
@@ -0,0 +1,44 @@
+import torch
+
+
+class TensorRunningMean(object):
+    """
+    Tracks a running mean without graph references.
+    Round-robin buffer over the last `window_length` values.
+    """
+    def __init__(self, window_length):
+        self.window_length = window_length
+        self.reset()
+        self.last_idx = 0
+
+    def reset(self):
+        self.memory = torch.Tensor(self.window_length)
+        self.current_idx = 0
+        self.rotated = False
+
+    def last(self):
+        return self.memory[self.last_idx]
+
+    def append(self, x):
+        # cast memory to match the incoming tensor's type if they differ
+        if self.memory.type() != x.type():
+            self.memory = self.memory.type_as(x)
+
+        # store without grads
+        with torch.no_grad():
+            self.memory[self.current_idx] = x
+            self.last_idx = self.current_idx
+
+        # increase index
+        self.current_idx += 1
+
+        # wrap the index once the buffer is full
+        if self.current_idx >= self.window_length:
+            self.current_idx = 0
+            self.rotated = True
+
+    def mean(self):
+        # average only the written slots; the buffer starts uninitialized
+        if self.rotated:
+            return self.memory.mean()
+        return self.memory[:self.current_idx].mean()
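
Below is a minimal usage sketch (not part of the patch itself) of how TensorRunningMean might be used to smooth a training loss without calling .item(), which is the GPU-to-CPU sync the patch title refers to. The loss values and window length here are hypothetical stand-ins for real training-step outputs.

import torch

from pytorch_lightning.trainer.supporting_classes import TensorRunningMean

# hypothetical per-batch losses; in practice these come from training steps
losses = [torch.tensor(4.0), torch.tensor(2.0),
          torch.tensor(3.0), torch.tensor(1.0)]

running_loss = TensorRunningMean(window_length=3)
for step, loss in enumerate(losses):
    # append() stores the value under torch.no_grad(), so no graph reference
    # is kept and no device sync is forced
    running_loss.append(loss)
    print(f"step {step}: last={running_loss.last():.1f} "
          f"mean={running_loss.mean():.1f}")

# After the 4th append the window has wrapped and holds [1.0, 2.0, 3.0],
# so mean() reports 2.0 rather than the all-time average of 2.5.

Because the window only retains the most recent window_length values, the trainer gets a smoothed, recent view of the loss while older batches age out of the average.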