Skip to content

Commit

Permalink
Replace eval_saved_model/testutil with utils/test_util in several tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 648493608
  • Loading branch information
zhouhao138 authored and tfx-copybara committed Jul 1, 2024
1 parent 3f316e0 commit c6cf100
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from apache_beam.testing import util
import tensorflow as tf
from tensorflow_model_analysis import constants
from tensorflow_model_analysis.eval_saved_model import testutil
from tensorflow_model_analysis.evaluators import analysis_table_evaluator
from tensorflow_model_analysis.utils import test_util


class AnalysisTableEvaulatorTest(testutil.TensorflowModelAnalysisTest):
class AnalysisTableEvaulatorTest(test_util.TensorflowModelAnalysisTest):

def testIncludeFilter(self):
with beam.Pipeline() as pipeline:
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_model_analysis/evaluators/evaluator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
"""Test for evaluator."""

import tensorflow as tf
from tensorflow_model_analysis.eval_saved_model import testutil
from tensorflow_model_analysis.evaluators import evaluator
from tensorflow_model_analysis.extractors import extractor
from tensorflow_model_analysis.utils import test_util


class EvaluatorTest(testutil.TensorflowModelAnalysisTest):
class EvaluatorTest(test_util.TensorflowModelAnalysisTest):

def testVerifyEvaluatorRaisesValueError(self):
extractors = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@

from absl.testing import parameterized
import tensorflow as tf

from tensorflow_model_analysis.api import types
from tensorflow_model_analysis.eval_saved_model import testutil
from tensorflow_model_analysis.evaluators import metrics_validator
from tensorflow_model_analysis.metrics import metric_types
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.proto import validation_result_pb2
from tensorflow_model_analysis.slicer import slicer_lib as slicer
from tensorflow_model_analysis.utils import test_util
from google.protobuf import text_format

# Tests involiving slices: (<test_name>, <slice_config> , <slice_key>)
Expand Down Expand Up @@ -95,8 +94,9 @@
], ((('feature1', 'value1'),), (('feature3', 'value3'),)))


class MetricsValidatorTest(testutil.TensorflowModelAnalysisTest,
parameterized.TestCase):
class MetricsValidatorTest(
test_util.TensorflowModelAnalysisTest, parameterized.TestCase
):

def testValidateMetricsInvalidThreshold(self):
eval_config = config_pb2.EvalConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,18 @@
import tensorflow as tf
from tensorflow_model_analysis import constants
from tensorflow_model_analysis.api import model_eval_lib
from tensorflow_model_analysis.eval_saved_model import testutil
from tensorflow_model_analysis.extractors import features_extractor
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.utils import test_util
from tfx_bsl.tfxio import tf_example_record

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2


class FeaturesExtractorTest(testutil.TensorflowModelAnalysisTest,
parameterized.TestCase):
class FeaturesExtractorTest(
test_util.TensorflowModelAnalysisTest, parameterized.TestCase
):

def test_features_extractor_no_features(self):
model_spec = config_pb2.ModelSpec()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
from apache_beam.testing import util
import numpy as np
import tensorflow as tf
from tensorflow_model_analysis.eval_saved_model import testutil
from tensorflow_model_analysis.metrics import metric_types
from tensorflow_model_analysis.metrics import metric_util
from tensorflow_model_analysis.metrics import multi_class_confusion_matrix_metrics
from tensorflow_model_analysis.utils import test_util


class MultiClassConfusionMatrixMetricsTest(testutil.TensorflowModelAnalysisTest,
parameterized.TestCase):
class MultiClassConfusionMatrixMetricsTest(
test_util.TensorflowModelAnalysisTest, parameterized.TestCase
):

@parameterized.named_parameters(
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
from apache_beam.testing import util
import numpy as np
import tensorflow as tf
from tensorflow_model_analysis.eval_saved_model import testutil
from tensorflow_model_analysis.metrics import metric_types
from tensorflow_model_analysis.metrics import metric_util
from tensorflow_model_analysis.metrics import multi_class_confusion_matrix_plot
from tensorflow_model_analysis.utils import test_util


class MultiClassConfusionMatrixPlotTest(testutil.TensorflowModelAnalysisTest,
parameterized.TestCase):
class MultiClassConfusionMatrixPlotTest(
test_util.TensorflowModelAnalysisTest, parameterized.TestCase
):

def testMultiClassConfusionMatrixPlot(self):
computations = (
Expand Down
7 changes: 4 additions & 3 deletions tensorflow_model_analysis/metrics/tjur_discrimination_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
from apache_beam.testing import util
import numpy as np
import tensorflow as tf
from tensorflow_model_analysis.eval_saved_model import testutil
from tensorflow_model_analysis.metrics import metric_util
from tensorflow_model_analysis.metrics import tjur_discrimination
from tensorflow_model_analysis.utils import test_util


class TjurDisriminationTest(testutil.TensorflowModelAnalysisTest,
parameterized.TestCase):
class TjurDisriminationTest(
test_util.TensorflowModelAnalysisTest, parameterized.TestCase
):

@parameterized.named_parameters(
('coefficient_of_discrimination',
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_model_analysis/slicer/slicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
import tensorflow as tf
from tensorflow_model_analysis import constants
from tensorflow_model_analysis.api import types
from tensorflow_model_analysis.eval_saved_model import testutil
from tensorflow_model_analysis.extractors import slice_key_extractor
from tensorflow_model_analysis.post_export_metrics import metric_keys
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.proto import metrics_for_slice_pb2
from tensorflow_model_analysis.slicer import slicer_lib as slicer
from tensorflow_model_analysis.utils import test_util
from tensorflow_model_analysis.utils import util as tfma_util

from google.protobuf import text_format
Expand Down Expand Up @@ -72,7 +72,7 @@ def wrap_fpl(fpl):
}


class SlicerTest(testutil.TensorflowModelAnalysisTest, parameterized.TestCase):
class SlicerTest(test_util.TensorflowModelAnalysisTest, parameterized.TestCase):

def setUp(self):
super().setUp()
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_model_analysis/writers/writer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@

import apache_beam as beam
import tensorflow as tf
from tensorflow_model_analysis.eval_saved_model import testutil
from tensorflow_model_analysis.utils import test_util
from tensorflow_model_analysis.writers import writer


class WriterTest(testutil.TensorflowModelAnalysisTest):
class WriterTest(test_util.TensorflowModelAnalysisTest):

def testWriteIgnoresMissingKeys(self):
with beam.Pipeline() as pipeline:
Expand Down

0 comments on commit c6cf100

Please sign in to comment.