Skip to content

Commit

Permalink
Add a method to create a test Keras EvalSharedModel.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 648770845
  • Loading branch information
zhouhao138 authored and tfx-copybara committed Jul 2, 2024
1 parent 0d3646e commit 1501957
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tensorflow_model_analysis/utils/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from typing import Dict, Iterable, Sequence, Tuple, Union

import tensorflow as tf
from tensorflow_model_analysis.api import model_eval_lib
from tensorflow_model_analysis.api import types
from tensorflow_model_analysis.proto import config_pb2

from tensorflow.core.example import example_pb2

Expand Down Expand Up @@ -175,6 +177,16 @@ def assertDictMatrixRowsAlmostEqual( # pylint: disable=invalid-name
places=places,
msg_prefix='for key %s, row %d: ' % (key, row))

def createKerasTestEvalSharedModel( # pylint: disable=invalid-name
self,
eval_saved_model_path: str,
eval_config: config_pb2.EvalConfig,
) -> types.EvalSharedModel:
"""Create a test Keras EvalSharedModel."""
return model_eval_lib.default_eval_shared_model(
eval_saved_model_path=eval_saved_model_path, eval_config=eval_config
)

def assertSequenceAlmostEqual( # pylint: disable=invalid-name
self,
got_seq: Iterable[Union[float, int]],
Expand Down

0 comments on commit 1501957

Please sign in to comment.