From b193059032eb0591c44f5dde997fa66f41bf93ad Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Mon, 13 Apr 2020 16:45:09 +0200 Subject: [PATCH] add function to reduce tensors (similar to reduction in torch.nn) --- .../metrics/functional/reduction.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 pytorch_lightning/metrics/functional/reduction.py diff --git a/pytorch_lightning/metrics/functional/reduction.py b/pytorch_lightning/metrics/functional/reduction.py new file mode 100644 index 0000000000000..d889f7013a34f --- /dev/null +++ b/pytorch_lightning/metrics/functional/reduction.py @@ -0,0 +1,29 @@ +import torch + + +def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: + """ + reduces a given tensor by a given reduction method + Parameters + ---------- + to_reduce : torch.Tensor + the tensor, which shall be reduced + reduction : str + a string specifying the reduction method. + should be one of 'elementwise_mean' | 'none' | 'sum' + Returns + ------- + torch.Tensor + reduced Tensor + Raises + ------ + ValueError + if an invalid reduction parameter was given + """ + if reduction == 'elementwise_mean': + return torch.mean(to_reduce) + if reduction == 'none': + return to_reduce + if reduction == 'sum': + return torch.sum(to_reduce) + raise ValueError('Reduction parameter unknown.')