Skip to content

Commit

Permalink
fix docs
Browse files Browse the repository at this point in the history
  • Loading branch information
cuent committed May 26, 2020
1 parent 3854e0c commit 1ba3dd3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
2 changes: 2 additions & 0 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import Sequence
from typing import Optional, Tuple, Callable
from functools import wraps

import torch

Expand Down Expand Up @@ -326,6 +327,7 @@ def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True):

def auc_decorator(reorder: bool = True) -> Callable:
def wrapper(func_to_decorate: Callable) -> Callable:
@wraps(func_to_decorate)
def new_func(*args, **kwargs) -> torch.Tensor:
x, y = func_to_decorate(*args, **kwargs)[:2]

Expand Down
24 changes: 9 additions & 15 deletions pytorch_lightning/metrics/functional/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,16 @@

def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
"""
reduces a given tensor by a given reduction method
Parameters
----------
to_reduce : torch.Tensor
the tensor, which shall be reduced
reduction : str
a string specifying the reduction method.
should be one of 'elementwise_mean' | 'none' | 'sum'
Returns
-------
torch.Tensor
Reduces a given tensor by a given reduction method
Args:
to_reduce : the tensor, which shall be reduced
reduction : a string specifying the reduction method ('elementwise_mean', 'none', 'sum')
Returns:
reduced Tensor
Raises
------
ValueError
if an invalid reduction parameter was given
Raises: ValueError if an invalid reduction parameter was given
"""
if reduction == 'elementwise_mean':
return torch.mean(to_reduce)
Expand Down

0 comments on commit 1ba3dd3

Please sign in to comment.