Skip to content

Commit

Permalink
fix reduction docstring and clean tests (#2885)
Browse files Browse the repository at this point in the history
* fix reduction docstring

* Update docstring and some cleanup

* miss

* suggestion from code review

Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai>

Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai>
  • Loading branch information
rohitgr7 and ananyahjha93 authored Aug 9, 2020
1 parent 6ad2718 commit 983c030
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 81 deletions.
55 changes: 24 additions & 31 deletions pytorch_lightning/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@

from pytorch_lightning.metrics.functional.classification import (
accuracy,
confusion_matrix,
precision_recall_curve,
precision,
recall,
average_precision,
auroc,
fbeta_score,
f1_score,
roc,
multiclass_roc,
multiclass_precision_recall_curve,
average_precision,
confusion_matrix,
dice_score,
f1_score,
fbeta_score,
iou,
multiclass_precision_recall_curve,
multiclass_roc,
precision,
precision_recall_curve,
recall,
roc
)
from pytorch_lightning.metrics.metric import TensorMetric, TensorCollectionMetric
from pytorch_lightning.metrics.metric import TensorCollectionMetric, TensorMetric


class Accuracy(TensorMetric):
Expand All @@ -45,7 +45,7 @@ def __init__(
"""
Args:
num_classes: number of classes
reduction: a method for reducing accuracies over labels (default: takes the mean)
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
Expand Down Expand Up @@ -208,7 +208,7 @@ def __init__(
"""
Args:
num_classes: number of classes
reduction: a method for reducing accuracies over labels (default: takes the mean)
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
Expand Down Expand Up @@ -262,7 +262,7 @@ def __init__(
"""
Args:
num_classes: number of classes
reduction: a method for reducing accuracies over labels (default: takes the mean)
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
Expand Down Expand Up @@ -428,7 +428,7 @@ def __init__(
Args:
beta: determines the weight of recall in the combined score.
num_classes: number of classes
reduction: a method for reducing accuracies over labels (default: takes the mean)
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
Expand Down Expand Up @@ -484,7 +484,7 @@ def __init__(
"""
Args:
num_classes: number of classes
reduction: a method for reducing accuracies over labels (default: takes the mean)
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
Expand Down Expand Up @@ -605,11 +605,6 @@ def __init__(
"""
Args:
num_classes: number of classes
reduction: a method for reducing accuracies over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
- sum: add elements
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
Expand Down Expand Up @@ -669,11 +664,6 @@ def __init__(
"""
Args:
num_classes: number of classes
reduction: a method for reducing accuracies over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
- sum: add elements
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
Expand Down Expand Up @@ -737,7 +727,7 @@ def __init__(
include_background: whether to also compute dice for the background
nan_score: score to return, if a NaN occurs during computation (denom zero)
no_fg_score: score to return, if no foreground pixel was found in target
reduction: a method for reducing accuracies over labels (default: takes the mean)
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
Expand Down Expand Up @@ -790,16 +780,19 @@ class IoU(TensorMetric):
tensor(0.7045)
"""
def __init__(self,
remove_bg: bool = False,
reduction: str = 'elementwise_mean'):

def __init__(
self,
remove_bg: bool = False,
reduction: str = 'elementwise_mean'
):
"""
Args:
remove_bg: Flag to state whether a background class has been included
within input parameters. If true, will remove background class. If
false, return IoU over all classes.
Assumes that background is '0' class in input tensor
reduction: a method for reducing IoU over labels (default: takes the mean)
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import numbers
from typing import Union, Any, Callable, Optional
from typing import Any, Callable, Optional, Union

import numpy as np
import torch
Expand Down
21 changes: 11 additions & 10 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.nn import functional as F

from pytorch_lightning.metrics.functional.reduction import reduce
from pytorch_lightning.utilities import rank_zero_warn, FLOAT16_EPSILON
from pytorch_lightning.utilities import FLOAT16_EPSILON, rank_zero_warn


def to_onehot(
Expand Down Expand Up @@ -149,7 +149,7 @@ def stat_scores_multiple_classes(
num_classes: number of classes if known
argmax_dim: if pred is a tensor of probabilities, this indicates the
axis the argmax transformation will be applied over
reduction: method for reducing result values (default: none)
reduction: a method to reduce metric score over labels (default: none)
Available reduction methods:
- elementwise_mean: takes the mean
Expand All @@ -174,6 +174,7 @@ def stat_scores_multiple_classes(
tensor([1., 0., 0., 0.])
>>> sups
tensor([1., 0., 1., 1.])
"""
if pred.ndim == target.ndim + 1:
pred = to_categorical(pred, argmax_dim=argmax_dim)
Expand Down Expand Up @@ -247,7 +248,7 @@ def accuracy(
pred: predicted labels
target: ground truth labels
num_classes: number of classes
reduction: a method for reducing accuracies over labels (default: takes the mean)
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
Expand Down Expand Up @@ -327,7 +328,7 @@ def precision_recall(
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing precision-recall values (default: takes the mean)
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
Expand Down Expand Up @@ -376,7 +377,7 @@ def precision(
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing precision values (default: takes the mean)
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
Expand Down Expand Up @@ -411,7 +412,7 @@ def recall(
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing recall values (default: takes the mean)
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
Expand Down Expand Up @@ -452,7 +453,7 @@ def fbeta_score(
beta = 0: only precision
beta -> inf: only recall
num_classes: number of classes
reduction: method for reducing F-score (default: takes the mean)
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
Expand Down Expand Up @@ -497,7 +498,7 @@ def f1_score(
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing F1-score (default: takes the mean)
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
Expand Down Expand Up @@ -920,7 +921,7 @@ def dice_score(
bg: whether to also compute dice for the background
nan_score: score to return, if a NaN occurs during computation
no_fg_score: score to return, if no foreground pixel was found in target
reduction: a method for reducing accuracies over labels (default: takes the mean)
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
Expand Down Expand Up @@ -977,7 +978,7 @@ def iou(
within input parameters. If true, will remove background class. If
false, return IoU over all classes
Assumes that background is '0' class in input tensor
reduction: a method for reducing IoU over labels (default: takes the mean)
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
Expand Down
23 changes: 17 additions & 6 deletions pytorch_lightning/metrics/functional/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from collections import Counter
from typing import Sequence, List
from typing import List, Sequence

import torch


def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
"""Counting how many times each word appears in a given text with ngram
"""
Counting how many times each word appears in a given text with ngram
Args:
ngram_input_list: A list of translated text or reference texts
Expand All @@ -24,16 +25,20 @@ def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:

for i in range(1, n_gram + 1):
for j in range(len(ngram_input_list) - i + 1):
ngram_key = tuple(ngram_input_list[j : i + j])
ngram_key = tuple(ngram_input_list[j:(i + j)])
ngram_counter[ngram_key] += 1

return ngram_counter


def bleu_score(
translate_corpus: Sequence[str], reference_corpus: Sequence[str], n_gram: int = 4, smooth: bool = False
translate_corpus: Sequence[str],
reference_corpus: Sequence[str],
n_gram: int = 4,
smooth: bool = False
) -> torch.Tensor:
"""Calculate BLEU score of machine translated text with one or more references.
"""
Calculate BLEU score of machine translated text with one or more references
Args:
translate_corpus: An iterable of machine translated corpus
Expand All @@ -42,14 +47,15 @@ def bleu_score(
smooth: Whether or not to apply smoothing – Lin et al. 2004
Return:
A Tensor with BLEU Score
Tensor with BLEU Score
Example:
>>> translate_corpus = ['the cat is on the mat'.split()]
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
>>> bleu_score(translate_corpus, reference_corpus)
tensor(0.7598)
"""

assert len(translate_corpus) == len(reference_corpus)
Expand All @@ -58,17 +64,20 @@ def bleu_score(
precision_scores = torch.zeros(n_gram)
c = 0.0
r = 0.0

for (translation, references) in zip(translate_corpus, reference_corpus):
c += len(translation)
ref_len_list = [len(ref) for ref in references]
ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
translation_counter = _count_ngram(translation, n_gram)
reference_counter = Counter()

for ref in references:
reference_counter |= _count_ngram(ref, n_gram)

ngram_counter_clip = translation_counter & reference_counter

for counter_clip in ngram_counter_clip:
numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]

Expand All @@ -77,13 +86,15 @@ def bleu_score(

trans_len = torch.tensor(c)
ref_len = torch.tensor(r)

if min(numerator) == 0.0:
return torch.tensor(0.0)

if smooth:
precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram))
else:
precision_scores = numerator / denominator

log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
geometric_mean = torch.exp(torch.sum(log_precision_scores))
brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
Expand Down
Loading

0 comments on commit 983c030

Please sign in to comment.