Add False{Negative,Positive}FeatureSampler metrics.
PiperOrigin-RevId: 645188718
embr authored and tfx-copybara committed Jun 20, 2024
1 parent 52a5cc4 commit 84a9c2b
Showing 3 changed files with 310 additions and 2 deletions.
2 changes: 2 additions & 0 deletions RELEASE.md
@@ -4,6 +4,8 @@

## Major Features and Improvements

* Adds `False{Negative,Positive}FeatureSampler` metrics.

## Bug fixes and other Changes

## Breaking Changes
240 changes: 240 additions & 0 deletions tensorflow_model_analysis/metrics/confusion_matrix_metrics.py
@@ -68,6 +68,8 @@
DIAGNOSTIC_ODDS_RATIO_NAME = 'diagnostic_odds_ratio'
PREDICTED_POSITIVE_RATE_NAME = 'predicted_positive_rate'
CONFUSION_MATRIX_AT_THRESHOLDS_NAME = 'confusion_matrix_at_thresholds'
FALSE_POSITIVE_FEATURE_SAMPLER_NAME = 'false_positive_feature_sampler'
FALSE_NEGATIVE_FEATURE_SAMPLER_NAME = 'false_negative_feature_sampler'
AVERAGE_PRECISION_NAME = 'average_precision'
MAX_RECALL_NAME = 'max_recall'
THRESHOLD_AT_RECALL_NAME = 'threshold_at_recall'
@@ -2248,6 +2250,244 @@ def result(self, tp: float, tn: float, fp: float, fn: float) -> float:
metric_types.register_metric(PredictedPositiveRate)


class ConfusionMatrixFeatureSamplerBase(
    metric_types.Metric, metaclass=abc.ABCMeta
):
  """Base class for metrics that sample features per confusion matrix case."""

  def __init__(
      self,
      feature_key: str,
      sample_size: int,
      threshold: float,
      name: Optional[str] = None,
      top_k: Optional[int] = None,
      class_id: Optional[int] = None,
  ):
    """Initializes a confusion matrix feature sampler.

    Args:
      feature_key: Feature key to sample.
      sample_size: Number of samples to collect per confusion matrix case.
      threshold: A float value in [0, 1] that is compared with prediction
        values to determine the truth value of predictions (i.e., above the
        threshold is `true`, below is `false`). Unlike most confusion matrix
        metrics, only a single threshold is supported.
      name: (Optional) Metric name.
      top_k: (Optional) Used with a multi-class model to specify that the
        top-k values should be used to compute the confusion matrix. The net
        effect is that the non-top-k values are set to -inf and the matrix is
        then constructed from the average TP, FP, TN, FN across the classes.
        When top_k is used, metrics_specs.binarize settings must not be
        present. Only one of class_id or top_k should be configured.
      class_id: (Optional) Used with a multi-class model to specify which
        class to compute the confusion matrix for. When class_id is used,
        metrics_specs.binarize settings must not be present. Only one of
        class_id or top_k should be configured.
    """
    super().__init__(
        metric_util.merge_per_key_computations(self._metric_computations),
        threshold=threshold,
        feature_key=feature_key,
        sample_size=sample_size,
        name=name,
        top_k=top_k,
        class_id=class_id,
    )

  @abc.abstractmethod
  def _get_samples(
      self, examples: binary_confusion_matrices.Examples
  ) -> np.ndarray:
    """Returns the samples for the given examples.

    Note that the storage format for examples supports multiple thresholds,
    but this base class only supports a single threshold. A typical
    _get_samples implementation should therefore index into the first element
    for each confusion matrix case, as in examples.tp_examples[0].

    Args:
      examples: The binary_confusion_matrices.Examples NamedTuple object from
        which to get the appropriate samples.
    """

  def _metric_computations(
      self,
      feature_key: str,
      sample_size: int,
      threshold: Optional[float] = None,
      top_k: Optional[int] = None,
      class_id: Optional[int] = None,
      name: Optional[str] = None,
      eval_config: Optional[config_pb2.EvalConfig] = None,
      model_name: str = '',
      output_name: str = '',
      sub_key: Optional[metric_types.SubKey] = None,
      aggregation_type: Optional[metric_types.AggregationType] = None,
      class_weights: Optional[Dict[int, float]] = None,
      example_weighted: bool = False,
  ) -> metric_types.MetricComputations:
    """Returns metric computations for confusion matrix feature samples."""
    sub_key = _validate_and_update_sub_key(
        name, model_name, output_name, sub_key, top_k, class_id
    )

    # Make sure the matrices are calculated along with sampled examples.
    matrices_computations = binary_confusion_matrices.binary_confusion_matrices(
        thresholds=[threshold],
        example_id_key=feature_key,
        example_ids_count=sample_size,
        use_histogram=False,
        preprocessors=[
            metric_types.FeaturePreprocessor(feature_keys=[feature_key])
        ],
        eval_config=eval_config,
        model_name=model_name,
        output_name=output_name,
        sub_key=sub_key,
        aggregation_type=aggregation_type,
        class_weights=class_weights,
        example_weighted=example_weighted,
    )
    examples_key = matrices_computations[-1].keys[0]

    output_key = metric_types.MetricKey(
        name=name or self._default_name(),
        model_name=model_name,
        output_name=output_name,
        sub_key=sub_key,
        example_weighted=example_weighted,
        aggregation_type=aggregation_type,
    )

    def result(metrics):
      metrics[output_key] = self._get_samples(metrics[examples_key])
      return metrics

    derived_computation = metric_types.DerivedMetricComputation(
        keys=[], result=result
    )
    computations = matrices_computations
    computations.append(derived_computation)
    return computations


class FalsePositiveFeatureSampler(ConfusionMatrixFeatureSamplerBase):
  """False positive feature samples."""

  def __init__(
      self,
      feature_key: str,
      sample_size: int,
      threshold: float = 0.5,
      name: Optional[str] = None,
      top_k: Optional[int] = None,
      class_id: Optional[int] = None,
  ):
    """Initializes FalsePositiveFeatureSampler metric.

    Args:
      feature_key: Feature key to sample.
      sample_size: Number of samples to collect per confusion matrix case.
      threshold: (Optional) Defaults to 0.5. A float value in [0, 1] that is
        compared with prediction values to determine the truth value of
        predictions (i.e., above the threshold is `true`, below is `false`).
        Only a single threshold is supported.
      name: (Optional) Metric name.
      top_k: (Optional) Used with a multi-class model to specify that the
        top-k values should be used to compute the confusion matrix. The net
        effect is that the non-top-k values are set to -inf and the matrix is
        then constructed from the average TP, FP, TN, FN across the classes.
        When top_k is used, metrics_specs.binarize settings must not be
        present. Only one of class_id or top_k should be configured.
      class_id: (Optional) Used with a multi-class model to specify which
        class to compute the confusion matrix for. When class_id is used,
        metrics_specs.binarize settings must not be present. Only one of
        class_id or top_k should be configured.
    """
    super().__init__(
        feature_key=feature_key,
        sample_size=sample_size,
        threshold=threshold,
        name=name,
        top_k=top_k,
        class_id=class_id,
    )

  def _get_samples(
      self, examples: binary_confusion_matrices.Examples
  ) -> np.ndarray:
    assert len(examples.fp_examples) == 1, 'Expected exactly one threshold'
    return np.concatenate(examples.fp_examples[0])

  def _default_name(self) -> str:
    return FALSE_POSITIVE_FEATURE_SAMPLER_NAME


metric_types.register_metric(FalsePositiveFeatureSampler)


class FalseNegativeFeatureSampler(ConfusionMatrixFeatureSamplerBase):
  """False negative feature samples."""

  def __init__(
      self,
      feature_key: str,
      sample_size: int,
      threshold: float = 0.5,
      name: Optional[str] = None,
      top_k: Optional[int] = None,
      class_id: Optional[int] = None,
  ):
    """Initializes FalseNegativeFeatureSampler metric.

    Args:
      feature_key: Feature key to sample.
      sample_size: Number of samples to collect per confusion matrix case.
      threshold: (Optional) Defaults to 0.5. A float value in [0, 1] that is
        compared with prediction values to determine the truth value of
        predictions (i.e., above the threshold is `true`, below is `false`).
        Only a single threshold is supported.
      name: (Optional) Metric name.
      top_k: (Optional) Used with a multi-class model to specify that the
        top-k values should be used to compute the confusion matrix. The net
        effect is that the non-top-k values are set to -inf and the matrix is
        then constructed from the average TP, FP, TN, FN across the classes.
        When top_k is used, metrics_specs.binarize settings must not be
        present. Only one of class_id or top_k should be configured.
      class_id: (Optional) Used with a multi-class model to specify which
        class to compute the confusion matrix for. When class_id is used,
        metrics_specs.binarize settings must not be present. Only one of
        class_id or top_k should be configured.
    """
    super().__init__(
        feature_key=feature_key,
        sample_size=sample_size,
        threshold=threshold,
        name=name,
        top_k=top_k,
        class_id=class_id,
    )

  def _get_samples(
      self, examples: binary_confusion_matrices.Examples
  ) -> np.ndarray:
    assert len(examples.fn_examples) == 1, 'Expected exactly one threshold'
    return np.concatenate(examples.fn_examples[0])

  def _default_name(self) -> str:
    return FALSE_NEGATIVE_FEATURE_SAMPLER_NAME


metric_types.register_metric(FalseNegativeFeatureSampler)
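

# Usage sketch (illustrative only, not part of this change). The samplers are
# configured like any other TFMA metric; the `example_id` feature key and the
# sample sizes below are assumptions for illustration, and
# `tfma.metrics.specs_from_metrics` is the standard helper for turning metric
# instances into metrics specs:
#
#   import tensorflow_model_analysis as tfma
#   from tensorflow_model_analysis.metrics import confusion_matrix_metrics
#
#   fp_sampler = confusion_matrix_metrics.FalsePositiveFeatureSampler(
#       feature_key='example_id', sample_size=100)
#   fn_sampler = confusion_matrix_metrics.FalseNegativeFeatureSampler(
#       feature_key='example_id', sample_size=100)
#   metrics_specs = tfma.metrics.specs_from_metrics([fp_sampler, fn_sampler])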


class ConfusionMatrixAtThresholds(metric_types.Metric):
"""Confusion matrix at thresholds."""

70 changes: 68 additions & 2 deletions tensorflow_model_analysis/metrics/confusion_matrix_metrics_test.py
@@ -25,14 +25,18 @@
from tensorflow_model_analysis.metrics import confusion_matrix_metrics
from tensorflow_model_analysis.metrics import metric_types
from tensorflow_model_analysis.metrics import metric_util
from tensorflow_model_analysis.metrics import test_util

_TF_MAJOR_VERSION = int(tf.version.VERSION.split('.')[0])
_TRUE_POSITIVE = (1, 1)
_TRUE_NEGATIVE = (0, 0)


class ConfusionMatrixMetricsTest(testutil.TensorflowModelAnalysisTest,
                                 parameterized.TestCase):
class ConfusionMatrixMetricsTest(
    testutil.TensorflowModelAnalysisTest,
    test_util.TestCase,
    parameterized.TestCase,
):

  @parameterized.named_parameters(
      (
@@ -953,6 +957,68 @@ def check_result(got):

    util.assert_that(result, check_result, label='result')

  @parameterized.named_parameters(
      (
          'false_positives',
          confusion_matrix_metrics.FalsePositiveFeatureSampler(
              threshold=0.5, feature_key='example_id', sample_size=2
          ),
          'false_positive_feature_sampler',
          np.array(['example1', 'example2'], dtype=str),
      ),
      (
          'false_negatives',
          confusion_matrix_metrics.FalseNegativeFeatureSampler(
              threshold=0.5, feature_key='example_id', sample_size=2
          ),
          'false_negative_feature_sampler',
          np.array(['example3', 'example4'], dtype=str),
      ),
  )
  def testConfusionMatrixFeatureSamplers(
      self, metric, expected_metric_name, expected_value
  ):
    # false positive
    example1 = {
        'labels': np.array([0.0]),
        'predictions': np.array([1.0]),
        'example_weights': np.array([1.0]),
        'features': {'example_id': np.array(['example1'])},
    }
    # false positive
    example2 = {
        'labels': np.array([0.0]),
        'predictions': np.array([1.0]),
        'example_weights': np.array([1.0]),
        'features': {'example_id': np.array(['example2'])},
    }
    # false negative
    example3 = {
        'labels': np.array([1.0]),
        'predictions': np.array([0.0]),
        'example_weights': np.array([1.0]),
        'features': {'example_id': np.array(['example3'])},
    }
    # false negative
    example4 = {
        'labels': np.array([1.0]),
        'predictions': np.array([0.0]),
        'example_weights': np.array([1.0]),
        'features': {'example_id': np.array(['example4'])},
    }

    expected_metrics = {
        metric_types.MetricKey(
            name=expected_metric_name, example_weighted=True
        ): expected_value,
    }
    self.assertDerivedMetricsEqual(
        expected_metrics=expected_metrics,
        extracts=[example1, example2, example3, example4],
        metric=metric,
        enable_debug_print=True,
    )
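
  # For reference (illustrative; the output path below is hypothetical): after
  # a real evaluation run, the sampled feature values can be read back from
  # the result, where each sampler yields an np.ndarray of feature values:
  #
  #   eval_result = tfma.load_eval_result('/path/to/eval_output')
  #   for slice_key, metrics in eval_result.slicing_metrics:
  #     ...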


if __name__ == '__main__':
  tf.compat.v1.enable_v2_behavior()
