Skip to content

Commit

Permalink
Migrate min_label_position.py to not depend on estimator related libr…
Browse files Browse the repository at this point in the history
…aries.

PiperOrigin-RevId: 649238064
  • Loading branch information
zhouhao138 authored and tfx-copybara committed Jul 4, 2024
1 parent 9160649 commit dda6c6d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
from typing import Any, Dict, Iterable, NamedTuple

import apache_beam as beam
from tensorflow_model_analysis.eval_saved_model import constants as eval_saved_model_constants
from tensorflow_model_analysis.eval_saved_model import util as eval_saved_model_util
from tensorflow_model_analysis import constants
from tensorflow_model_analysis.evaluators.query_metrics import query_types
from tensorflow_model_analysis.post_export_metrics import metric_keys
from tensorflow_model_analysis.utils import util

_State = NamedTuple('_State', [('min_pos_sum', float), ('weight_sum', float)])

Expand Down Expand Up @@ -70,8 +70,7 @@ def __init__(self, label_key: str, weight_key: str):
# If label_key is set to the empty string, the user is telling us
# that their Estimator returns a labels Tensor rather than a
# dictionary. Set the key to the magic key we use in that case.
self._label_key = eval_saved_model_util.default_dict_key(
eval_saved_model_constants.LABELS_NAME)
self._label_key = util.default_dict_key(constants.LABELS_KEY)
else:
self._label_key = label_key
self._weight_key = weight_key
Expand Down
15 changes: 12 additions & 3 deletions tensorflow_model_analysis/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
VALUES_SUFFIX = 'values'


def default_dict_key(prefix: str) -> str:
"""Returns the default key to use with a dict associated with given prefix."""
return KEY_SEPARATOR + prefix


def is_sparse_or_ragged_tensor_value(tensor: Any) -> bool:
"""Returns true if sparse or ragged tensor."""
return (isinstance(tensor, types.SparseTensorValue) or
Expand Down Expand Up @@ -838,10 +843,14 @@ def merge_lists(target: types.Extracts) -> types.Extracts:
) from e
return {k: merge_lists(v) for k, v in target.items()}
elif (
target and
np.any([isinstance(t, tf.compat.v1.SparseTensorValue) for t in target])
target
and np.any(
[isinstance(t, tf.compat.v1.SparseTensorValue) for t in target]
)
or np.any(
[isinstance(target[0], types.SparseTensorValue) for t in target])):
[isinstance(target[0], types.SparseTensorValue) for _ in target]
)
):
t = tf.compat.v1.sparse_concat(
0,
[tf.sparse.expand_dims(to_tensorflow_tensor(t), 0) for t in target],
Expand Down

0 comments on commit dda6c6d

Please sign in to comment.