Skip to content

Commit

Permalink
add function to reduce tensors (similar to reduction in torch.nn)
Browse files Browse the repository at this point in the history
  • Loading branch information
justusschock committed Apr 13, 2020
1 parent 79f0731 commit b193059
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions pytorch_lightning/metrics/functional/reduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch


def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
"""
reduces a given tensor by a given reduction method
Parameters
----------
to_reduce : torch.Tensor
the tensor, which shall be reduced
reduction : str
a string specifying the reduction method.
should be one of 'elementwise_mean' | 'none' | 'sum'
Returns
-------
torch.Tensor
reduced Tensor
Raises
------
ValueError
if an invalid reduction parameter was given
"""
if reduction == 'elementwise_mean':
return torch.mean(to_reduce)
if reduction == 'none':
return to_reduce
if reduction == 'sum':
return torch.sum(to_reduce)
raise ValueError('Reduction parameter unknown.')

0 comments on commit b193059

Please sign in to comment.