Skip to content

Commit

Permalink
Migrate model_util_test.py from estimator related APIs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 649210604
  • Loading branch information
zhouhao138 authored and tfx-copybara committed Jul 3, 2024
1 parent a70abb8 commit 9160649
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 10 deletions.
19 changes: 10 additions & 9 deletions tensorflow_model_analysis/utils/model_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
import tensorflow as tf
from tensorflow_model_analysis import constants
from tensorflow_model_analysis.api import types
from tensorflow_model_analysis.eval_saved_model import testutil
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.utils import model_util
from tensorflow_model_analysis.utils import test_util
from tensorflow_model_analysis.utils import util as tfma_util
from tensorflow_model_analysis.utils.keras_lib import tf_keras
from tfx_bsl.tfxio import tf_example_record
Expand All @@ -47,8 +47,9 @@ def _record_batch_to_extracts(record_batch):
}


class ModelUtilTest(testutil.TensorflowModelAnalysisTest,
parameterized.TestCase):
class ModelUtilTest(
test_util.TensorflowModelAnalysisTest, parameterized.TestCase
):

def createDenseInputsSchema(self):
return text_format.Parse(
Expand Down Expand Up @@ -784,9 +785,10 @@ def testModelSignaturesDoFn(
for model_name in sigs:
if model_name not in eval_shared_models:
eval_shared_models[model_name] = self.createTestEvalSharedModel(
eval_saved_model_path=export_path,
model_path=export_path,
model_name=model_name,
tags=[tf.saved_model.SERVING])
tags=[tf.saved_model.SERVING],
)
model_specs.append(config_pb2.ModelSpec(name=model_name))
schema = self.createDenseInputsSchema() if use_schema else None
tfx_io = tf_example_record.TFExampleBeamRecord(
Expand Down Expand Up @@ -845,10 +847,9 @@ def testModelSignaturesDoFnError(self):
output_keypath = [constants.PREDICTIONS_KEY]
signature_names = {'': [None]}
eval_shared_models = {
'':
self.createTestEvalSharedModel(
eval_saved_model_path=export_path,
tags=[tf.saved_model.SERVING])
'': self.createTestEvalSharedModel(
model_path=export_path, tags=[tf.saved_model.SERVING]
)
}
model_specs = [config_pb2.ModelSpec()]
schema = self.createDenseInputsSchema()
Expand Down
52 changes: 51 additions & 1 deletion tensorflow_model_analysis/utils/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@

import math
import tempfile
from typing import Dict, Iterable, Sequence, Tuple, Union
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union

import tensorflow as tf
from tensorflow_model_analysis import constants
from tensorflow_model_analysis.api import model_eval_lib
from tensorflow_model_analysis.api import types
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.utils import model_util

from tensorflow.core.example import example_pb2

Expand Down Expand Up @@ -231,3 +233,51 @@ def assertSparseTensorValueEqual( # pylint: disable=invalid-name
got_sparse_tensor_value.values.dtype)
self.assertEqual(expected_sparse_tensor_value.dense_shape.dtype,
got_sparse_tensor_value.dense_shape.dtype)

def createTestEvalSharedModel( # pylint: disable=invalid-name
self,
model_path: Optional[str] = None,
add_metrics_callbacks: Optional[
List[types.AddMetricsCallbackType]
] = None,
include_default_metrics: Optional[bool] = True,
example_weight_key: Optional[Union[str, Dict[str, str]]] = None,
additional_fetches: Optional[List[str]] = None,
tags: Optional[str] = None,
model_type: Optional[str] = None,
model_name: str = '',
rubber_stamp: Optional[bool] = False,
is_baseline: Optional[bool] = False,
) -> types.EvalSharedModel:
"""Create a test EvalSharedModel."""

if not model_type:
model_type = model_util.get_model_type(None, model_path, tags)
if model_type == constants.TFMA_EVAL:
raise ValueError(
f'Models of type {model_type} are deprecated. Please do not use it'
'for testing.'
)
if not tags:
tags = [tf.saved_model.SERVING]

return types.EvalSharedModel(
model_name=model_name,
model_type=model_type,
model_path=model_path,
add_metrics_callbacks=add_metrics_callbacks,
example_weight_key=example_weight_key,
rubber_stamp=rubber_stamp,
is_baseline=is_baseline,
model_loader=types.ModelLoader(
tags=tags,
construct_fn=model_util.model_construct_fn(
eval_saved_model_path=model_path,
model_type=model_type,
add_metrics_callbacks=add_metrics_callbacks,
include_default_metrics=include_default_metrics,
additional_fetches=additional_fetches,
tags=tags,
),
),
)

0 comments on commit 9160649

Please sign in to comment.