Add sparse and ragged options to TextVectorization output

Sparse applies only when output_mode is "one_hot", "multi_hot",
"count", or "tf_idf", where the last dimension of the output contains
a bin for every element in the vocab. Sparse output is more efficient
when the vocab size is large.

Ragged applies only when output_mode is "int", and outputs a ragged
tensor after string splitting, where the final dimension contains
variable-length sequences of vocab indices.

PiperOrigin-RevId: 391415472
mattdangerw authored and tensorflower-gardener committed Aug 18, 2021
1 parent 3df49f8 commit 05048b7
Showing 6 changed files with 125 additions and 32 deletions.
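As a quick illustration of the new sparse option described above (not part of the commit), here is a minimal usage sketch. It assumes a TF build that already includes this change and that the layer is exported as tf.keras.layers.TextVectorization (in some versions it lives under tf.keras.layers.experimental.preprocessing); the vocabulary and inputs are illustrative only. A second sketch for the ragged option follows the text_vectorization.py changes further down.

import tensorflow as tf

# Hypothetical example: multi_hot output with the new sparse=True option.
# With this vocabulary, index 0 of the output is the OOV bin and indices
# 1-4 correspond to "earth", "wind", "and", "fire".
layer = tf.keras.layers.TextVectorization(
    output_mode="multi_hot",
    sparse=True,  # new argument from this commit
    vocabulary=["earth", "wind", "and", "fire"])

out = layer(tf.constant([["earth wind and fire"], ["fire and fire"]]))
print(type(out))                # a SparseTensor rather than a dense Tensor
print(tf.sparse.to_dense(out))  # densify only for inspection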
@@ -134,7 +134,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'max_tokens\', \'standardize\', \'split\', \'ngrams\', \'output_mode\', \'output_sequence_length\', \'pad_to_max_tokens\', \'vocabulary\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'lower_and_strip_punctuation\', \'whitespace\', \'None\', \'int\', \'None\', \'False\', \'None\'], "
argspec: "args=[\'self\', \'max_tokens\', \'standardize\', \'split\', \'ngrams\', \'output_mode\', \'output_sequence_length\', \'pad_to_max_tokens\', \'vocabulary\', \'sparse\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'lower_and_strip_punctuation\', \'whitespace\', \'None\', \'int\', \'None\', \'False\', \'None\', \'False\', \'False\'], "
}
member_method {
name: "adapt"
@@ -134,7 +134,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'max_tokens\', \'standardize\', \'split\', \'ngrams\', \'output_mode\', \'output_sequence_length\', \'pad_to_max_tokens\', \'vocabulary\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'lower_and_strip_punctuation\', \'whitespace\', \'None\', \'int\', \'None\', \'False\', \'None\'], "
argspec: "args=[\'self\', \'max_tokens\', \'standardize\', \'split\', \'ngrams\', \'output_mode\', \'output_sequence_length\', \'pad_to_max_tokens\', \'vocabulary\', \'sparse\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'lower_and_strip_punctuation\', \'whitespace\', \'None\', \'int\', \'None\', \'False\', \'None\', \'False\', \'False\'], "
}
member_method {
name: "adapt"
33 changes: 20 additions & 13 deletions keras/layers/preprocessing/index_lookup.py
@@ -145,9 +145,9 @@ class IndexLookup(base_preprocessing_layer.PreprocessingLayer):
padded to `max_tokens` even if the number of unique tokens in the
vocabulary is less than max_tokens, resulting in a tensor of shape
[batch_size, max_tokens] regardless of vocabulary size. Defaults to False.
sparse: Boolean. Only applicable to `"multi_hot"` and `"count"` output
modes. If True, returns a `SparseTensor` instead of a dense `Tensor`.
Defaults to False.
sparse: Boolean. Only applicable to `"one_hot"`, `"multi_hot"`, `"count"`
and `"tf_idf"` output modes. If True, returns a `SparseTensor` instead of
a dense `Tensor`. Defaults to False.
"""

def __init__(self,
@@ -164,16 +164,16 @@ def __init__(self,
# If max_tokens is set, the value must be greater than 1 - otherwise we
# are creating a 0-element vocab, which doesn't make sense.
if max_tokens is not None and max_tokens <= 1:
raise ValueError("If set, `max_tokens` must be greater than 1. "
"You passed `max_tokens={}`".format(max_tokens))
raise ValueError(f"If set, `max_tokens` must be greater than 1. "
f"Received: max_tokens={max_tokens}")

if pad_to_max_tokens and max_tokens is None:
raise ValueError("If pad_to_max_tokens is True, must set `max_tokens`. "
"You passed `max_tokens={}`".format(max_tokens))
raise ValueError(f"If pad_to_max_tokens is True, must set `max_tokens`. "
f"Received: max_tokens={max_tokens}")

if num_oov_indices < 0:
raise ValueError("`num_oov_indices` must be greater than or equal to 0. "
"You passed {}".format(num_oov_indices))
raise ValueError(f"`num_oov_indices` must be greater than or equal to 0. "
f"Received: num_oov_indices={num_oov_indices}")

# Support deprecated names for output_modes.
if output_mode == "binary":
@@ -188,8 +188,14 @@ def __init__(self,
arg_name="output_mode")

if invert and output_mode != INT:
raise ValueError("`output_mode` must be {} when `invert` is true. You "
"passed {}".format(INT, output_mode))
raise ValueError(f"`output_mode` must be `'int'` when `invert` is true. "
f"Received: output_mode={output_mode}")

if sparse and output_mode == INT:
raise ValueError(f"`sparse` must not be true if `output_mode` is "
f"`'int'`. "
f"Received: sparse={sparse} and "
f"output_mode={output_mode}")

self.invert = invert
self.max_tokens = max_tokens
@@ -369,8 +375,9 @@ def set_vocabulary(self, vocabulary, idf_weights=None):
RuntimeError: If a tensor vocabulary is passed outside of eager execution.
"""
if self.output_mode != TF_IDF and idf_weights is not None:
raise ValueError("`idf_weights` should only be set if output_mode is "
"TF_IDF. output_mode is {}.".format(self.output_mode))
raise ValueError(f"`idf_weights` should only be set if output_mode is "
f"`'tf_idf'`. Received: output_mode={self.output_mode} "
f"and idf_weights={idf_weights}")

if isinstance(vocabulary, str):
if not tf.io.gfile.exists(vocabulary):
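The sparse validation added above also surfaces through the public TextVectorization layer, which forwards sparse to its internal lookup layer; the tests further down exercise exactly this. A hedged sketch, again assuming the tf.keras.layers.TextVectorization export:

import tensorflow as tf

# Requesting sparse output together with output_mode="int" should now be
# rejected at construction time with the error message introduced above.
try:
    tf.keras.layers.TextVectorization(output_mode="int", sparse=True)
except ValueError as e:
    print(e)  # "`sparse` must not be true if `output_mode` is ..."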
2 changes: 1 addition & 1 deletion keras/layers/preprocessing/index_lookup_test.py
@@ -1899,7 +1899,7 @@ def test_non_unique_vocab_fails(self):
invert=True)

def test_non_int_output_fails(self):
with self.assertRaisesRegex(ValueError, "`output_mode` must be int"):
with self.assertRaisesRegex(ValueError, "`output_mode` must be `'int'`"):
_ = index_lookup.IndexLookup(
max_tokens=None,
num_oov_indices=1,
35 changes: 30 additions & 5 deletions keras/layers/preprocessing/text_vectorization.py
@@ -160,6 +160,12 @@ class TextVectorization(base_preprocessing_layer.PreprocessingLayer):
tensor containing the string vocabulary terms. If passing a file path, the
file should contain one line per term in the vocabulary. If this argument
is set, there is no need to `adapt` the layer.
ragged: Boolean. Only applicable to `"int"` output mode. If True, returns a
`RaggedTensor` instead of a dense `Tensor`, where each sequence may have a
different length after string splitting. Defaults to False.
sparse: Boolean. Only applicable to `"multi_hot"`, `"count"`, and
`"tf_idf"` output modes. If True, returns a `SparseTensor` instead of a
dense `Tensor`. Defaults to False.
Example:
@@ -236,6 +242,8 @@ def __init__(self,
output_sequence_length=None,
pad_to_max_tokens=False,
vocabulary=None,
sparse=False,
ragged=False,
**kwargs):

# This layer only applies to string processing, and so should only have
@@ -283,22 +291,32 @@ def __init__(self,
isinstance(ngrams, int) or
isinstance(ngrams, tuple) and
all(isinstance(item, int) for item in ngrams)):
raise ValueError(("`ngrams` must be None, an integer, or a tuple of "
"integers. Got %s") % (ngrams,))
raise ValueError(f"`ngrams` must be None, an integer, or a tuple of "
f"integers. Received: ngrams={ngrams}")

# 'output_sequence_length' must be one of (None, int) and is only
# set if output_mode is INT.
if (output_mode == INT and not (isinstance(output_sequence_length, int) or
(output_sequence_length is None))):
raise ValueError("`output_sequence_length` must be either None or an "
"integer when `output_mode` is 'int'. "
"Got %s" % output_sequence_length)
raise ValueError(f"`output_sequence_length` must be either None or an "
f"integer when `output_mode` is 'int'. Received: "
f"output_sequence_length={output_sequence_length}")

if output_mode != INT and output_sequence_length is not None:
raise ValueError(
f"`output_sequence_length` must not be set if `output_mode` is not "
f"'int'. Received output_sequence_length={output_sequence_length}.")

if ragged and output_mode != INT:
raise ValueError(f"`ragged` must not be true if `output_mode` is not "
f"`'int'`. Received: ragged={ragged} and "
f"output_mode={output_mode}")

if ragged and output_sequence_length is not None:
raise ValueError(f"`output_sequence_length` must not be set if ragged "
f"is True. Received: ragged={ragged} and "
f"output_sequence_length={output_sequence_length}")

self._max_tokens = max_tokens
self._standardize = standardize
self._split = split
@@ -307,6 +325,7 @@ def __init__(self,
self._ngrams = tuple(range(1, ngrams + 1))
else:
self._ngrams = ngrams
self._ragged = ragged

self._output_mode = output_mode
self._output_sequence_length = output_sequence_length
@@ -330,6 +349,7 @@ def __init__(self,
pad_to_max_tokens=pad_to_max_tokens,
mask_token="",
output_mode=output_mode if output_mode is not None else INT,
sparse=sparse,
has_input_vocabulary=self._has_input_vocabulary)

def compute_output_shape(self, input_shape):
@@ -387,6 +407,8 @@ def get_config(self):
"output_mode": self._output_mode,
"output_sequence_length": self._output_sequence_length,
"pad_to_max_tokens": self._lookup_layer.pad_to_max_tokens,
"sparse": self._lookup_layer.sparse,
"ragged": self._ragged,
"vocabulary": utils.listify_tensors(vocab),
}
base_config = super(TextVectorization, self).get_config()
@@ -500,6 +522,9 @@ def call(self, inputs):
if self._output_mode is not INT:
return lookup_data

if self._ragged:
return lookup_data

# If we have a ragged tensor, we can pad during the conversion to dense.
if tf_utils.is_ragged(lookup_data):
shape = lookup_data.shape.as_list()
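For the new ragged path in call() above, here is a minimal usage sketch (again not part of the commit, with an illustrative vocabulary, and assuming the tf.keras.layers.TextVectorization export):

import tensorflow as tf

# With output_mode="int" and ragged=True, the layer should return a
# tf.RaggedTensor of vocab indices with one variable-length row per input
# string, instead of a dense tensor padded to a fixed sequence length.
layer = tf.keras.layers.TextVectorization(
    output_mode="int",
    ragged=True,  # new argument from this commit
    vocabulary=["earth", "wind", "and", "fire"])

tokens = layer(tf.constant(["earth wind and also fire", "fire and earth"]))
print(tokens)  # e.g. <tf.RaggedTensor [[2, 3, 4, 1, 5], [5, 4, 2]]>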
83 changes: 72 additions & 11 deletions keras/layers/preprocessing/text_vectorization_test.py
@@ -880,6 +880,31 @@ def test_int_output_densifies_with_zeros(self):
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)

def test_int_output_ragged(self):
vocab_data = ["earth", "wind", "and", "fire"]
# Create an input array that has 5 elements in the first example and 4 in
# the second.
input_array = np.array([["earth wind and also fire"],
["fire and earth michigan"]])
expected_output = tf.ragged.constant([[2, 3, 4, 1, 5], [5, 4, 2, 1]])
expected_output_shape = [None, None]

# The input shape here is explicitly 1 because we're tokenizing.
input_data = keras.Input(shape=(1,), dtype=tf.string)
layer = text_vectorization.TextVectorization(
max_tokens=None,
standardize=None,
split=text_vectorization.SPLIT_ON_WHITESPACE,
output_mode=text_vectorization.INT,
ragged=True)
layer.set_vocabulary(vocab_data)
int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())

model = keras.Model(inputs=input_data, outputs=int_data)
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)

def test_int_output_densifies_with_zeros_and_pads(self):
vocab_data = ["earth", "wind", "and", "fire"]
# Create an input array that has 5 elements in the first example and 4 in
@@ -970,7 +995,11 @@ def test_int_output_dynamically_strips_and_pads(self):
output_dataset = model.predict(input_array_2)
self.assertAllEqual(expected_output_2, output_dataset)

def test_binary_output_hard_maximum(self):
@parameterized.parameters(
{"sparse": True},
{"sparse": False},
)
def test_multi_hot_output_hard_maximum(self, sparse):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "earth"],
["ohio", "and", "earth", "michigan"]])
@@ -988,16 +1017,26 @@ def test_binary_output_hard_maximum(self):
standardize=None,
split=None,
output_mode=text_vectorization.MULTI_HOT,
pad_to_max_tokens=True)
pad_to_max_tokens=True,
sparse=sparse)
layer.set_vocabulary(vocab_data)
int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())

model = keras.Model(inputs=input_data, outputs=int_data)
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
if sparse:
expected_output = tf.sparse.from_dense(tf.constant(expected_output))
self.assertAllEqual(expected_output.indices, output_dataset.indices)
self.assertAllEqual(expected_output.values, output_dataset.values)
else:
self.assertAllEqual(expected_output, output_dataset)

def test_binary_output_soft_maximum(self):
@parameterized.parameters(
{"sparse": True},
{"sparse": False},
)
def test_multi_hot_output_soft_maximum(self, sparse):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "earth"],
["ohio", "and", "earth", "michigan"]])
@@ -1015,16 +1054,22 @@ def test_binary_output_soft_maximum(self):
standardize=None,
split=None,
output_mode=text_vectorization.MULTI_HOT,
pad_to_max_tokens=False)
pad_to_max_tokens=False,
sparse=sparse)
layer.set_vocabulary(vocab_data)
int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())

model = keras.Model(inputs=input_data, outputs=int_data)
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
if sparse:
expected_output = tf.sparse.from_dense(tf.constant(expected_output))
self.assertAllEqual(expected_output.indices, output_dataset.indices)
self.assertAllEqual(expected_output.values, output_dataset.values)
else:
self.assertAllEqual(expected_output, output_dataset)

def test_bag_output_hard_maximum_set_vocabulary_after_build(self):
def test_multi_hot_output_hard_maximum_set_vocabulary_after_build(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "earth"],
["ohio", "and", "earth", "michigan"]])
@@ -1051,7 +1096,7 @@ def test_bag_output_hard_maximum_set_vocabulary_after_build(self):
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)

def test_bag_output_hard_maximum_adapt_after_build(self):
def test_multi_hot_output_hard_maximum_adapt_after_build(self):
vocab_data = np.array([
"earth", "earth", "earth", "earth", "wind", "wind", "wind", "and",
"and", "fire"
@@ -1081,7 +1126,7 @@ def test_bag_output_hard_maximum_adapt_after_build(self):
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)

def test_bag_output_hard_maximum_multiple_adapts(self):
def test_multi_hot_output_hard_maximum_multiple_adapts(self):
input_array = np.array([["earth", "wind", "and", "earth"],
["ohio", "and", "earth", "michigan"]])
adapt_data = ["earth", "earth", "earth", "earth", "wind", "wind", "wind"]
@@ -1119,7 +1164,7 @@ def test_bag_output_hard_maximum_multiple_adapts(self):
self.assertAllEqual(first_expected_output, first_output)
self.assertAllEqual(second_expected_output, second_output)

def test_bag_output_soft_maximum_set_state_after_build(self):
def test_multi_hot_output_soft_maximum_set_state_after_build(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "earth"],
["ohio", "and", "earth", "michigan"]])
@@ -1525,12 +1570,28 @@ def test_non_int_output_sequence_length_dtype_fails(self):
_ = text_vectorization.TextVectorization(
output_mode="int", output_sequence_length=2.0)

def test_non_none_output_sequence_length_fails_if_output_type_not_int(self):
def test_non_none_output_sequence_length_fails_if_output_mode_not_int(self):
with self.assertRaisesRegex(ValueError,
"`output_sequence_length` must not be set"):
_ = text_vectorization.TextVectorization(
output_mode="count", output_sequence_length=2)

def test_non_none_output_sequence_length_fails_if_ragged_true(self):
with self.assertRaisesRegex(ValueError,
"`output_sequence_length` must not be set"):
_ = text_vectorization.TextVectorization(
ragged=True, output_sequence_length=2)

def test_ragged_true_fails_if_output_mode_not_int(self):
with self.assertRaisesRegex(ValueError, "`ragged` must not be true if"):
_ = text_vectorization.TextVectorization(
ragged=True, output_mode=text_vectorization.MULTI_HOT)

def test_sparse_true_fails_if_output_mode_is_int(self):
with self.assertRaisesRegex(ValueError, "`sparse` must not be true if"):
_ = text_vectorization.TextVectorization(
sparse=True, output_mode=text_vectorization.INT)


# Custom functions for the custom callable serialization test. Declared here
# to avoid multiple registrations from run_all_keras_modes().
