Update CachingModelWrapper.predict() to use the cache instead of bypassing it.

Simplifies CachingModelWrapper by removing the dataset_name from the CacheKey and using the model name instead.

Updates CachingModelWrapper.predict_with_metadata() to act as a passthrough to predict(), now that predict() uses the cache.

This will simplify and stabilize LIT's behavior as we move toward the removal of the Model.predict_with_metadata() API from models.

PiperOrigin-RevId: 549716363
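
As a minimal sketch of the new behavior (not part of this commit's diff; adapted from the updated caching_test.py below, which provides the IdentityRegressionModelForTesting helper used here):

    from lit_nlp.lib import caching
    from lit_nlp.lib import testing_utils

    # Toy model that counts how many examples it has actually been run on.
    model = testing_utils.IdentityRegressionModelForTesting()
    wrapper = caching.CachingModelWrapper(model, "test")

    # predict() inputs are plain dicts; the "_id" field is what key_fn() uses
    # to build the CacheKey ("test", "id_to_cache").
    examples = [{"val": 1, "_id": "id_to_cache"}]

    wrapper.predict(examples)   # cache miss: the wrapped model runs once
    assert model.count == 1

    wrapper.predict(examples)   # cache hit: the wrapped model is not called again
    assert model.count == 1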
nadah09 authored and LIT team committed Jul 20, 2023
1 parent e020faa commit e9ce692
Showing 2 changed files with 52 additions and 62 deletions.
75 changes: 28 additions & 47 deletions lit_nlp/lib/caching.py
@@ -14,7 +14,6 @@
# ==============================================================================
"""Miscellaneous helper functions."""

import functools
import os
import pickle
import threading
@@ -206,6 +205,7 @@ def __init__(self,
cache_dir: if given, will load/save data to disk
"""
super().__init__(model)
self._name = name
self._log_prefix = f"CachingModelWrapper '{name:s}'"
self._cache = PredsCache(
name, model.supports_concurrent_predictions, cache_dir)
@@ -217,23 +217,21 @@ def load_cache(self):
def save_cache(self):
self._cache.save_to_disk()

def key_fn(self, d, group_name) -> CacheKey:
if d["id"] == "": # pylint: disable=g-explicit-bool-comparison
logging.warning("Found empty example ID - using empty cache ID.")
def key_fn(self, d) -> CacheKey:
if not (d_id := d.get("_id")):
logging.warning(
"Found empty or missing example ID - using empty cache ID.")
return None
return (group_name, d["id"])
return (self._name, d_id)

##
# For internal use
def fit_transform_with_metadata(self,
indexed_inputs: list[JsonDict],
dataset_name: str = ""):
def fit_transform_with_metadata(self, indexed_inputs: list[JsonDict], **kw):
"""For use with UMAP and other preprocessing transforms."""
outputs = list(self.wrapped.fit_transform_with_metadata(indexed_inputs))
key_fn = functools.partial(self.key_fn, group_name=dataset_name)
with self._cache.lock:
for i, output in enumerate(outputs):
self._cache.put(output, key_fn(indexed_inputs[i]))
self._cache.put(output, self.key_fn(indexed_inputs[i]))
return outputs

def predict_minibatch(self, *args, **kw):
@@ -243,40 +241,14 @@ def predict_minibatch(self, *args, **kw):
"to access cache via example IDs.")
return self.wrapped.predict_minibatch(*args, **kw)

def predict(self, *args, **kw):
logging.warning(
"CachingModelWrapper.predict() bypasses the cache - "
"if this is not intended, use predict_with_metadata() instead "
"to access cache via example IDs.")
return self.wrapped.predict(*args, **kw)

def predict_with_metadata(self, *args, **kw):
"""As predict(), but inputs are IndexedInput."""
results = self._predict_with_metadata(*args, **kw)
return results

def _get_results_from_cache(self, input_keys: list[str]):
with self._cache.lock:
return [self._cache.get(input_key) for input_key in input_keys] # pytype: disable=wrong-arg-types # always-use-return-annotations

def _predict_with_metadata(
self,
indexed_inputs: list[JsonDict],
dataset_name: Optional[str] = None,
progress_indicator: Optional[ProgressIndicator] = lambda x: x,
**kw) -> list[JsonDict]:
"""As predict(), but inputs are IndexedInput."""
# TODO(lit-dev): consider moving this to example level
# (null keys skip cache), and removing this codepath.
if dataset_name is None:
logging.info("\n\nCache disabled for current call.\n\n")
results = list(self.wrapped.predict_with_metadata(indexed_inputs))
return results

key_fn = functools.partial(self.key_fn, group_name=dataset_name)
def predict(self,
inputs: Iterable[JsonDict],
progress_indicator: Optional[ProgressIndicator] = lambda x: x,
**kw) -> list[JsonDict]:

inputs_as_list = list(inputs)
# Try to get results from the cache.
input_keys = [key_fn(d) for d in indexed_inputs]
input_keys = [self.key_fn(d) for d in inputs_as_list]
if self._cache.pred_lock_key(input_keys):
with self._cache.get_pred_lock(input_keys):
results = self._get_results_from_cache(input_keys)
@@ -286,10 +258,8 @@ def _predict_with_metadata(
# Make a single list of everything that wasn't found in the cache,
# to actually run the model on these inputs.
miss_idxs = [i for i, v in enumerate(results) if v is None]
misses = [indexed_inputs[i] for i in miss_idxs]
misses = [inputs_as_list[i] for i in miss_idxs]
if misses:
logging.info("%s: misses (dataset=%s): %s", self._log_prefix,
dataset_name, str([miss["id"] for miss in misses]))
logging.info("%s: %d misses out of %d inputs", self._log_prefix,
len(miss_idxs), len(results))
else:
@@ -298,7 +268,7 @@ def _predict_with_metadata(

with self._cache.get_pred_lock(input_keys):
model_preds = list(
self.wrapped.predict_with_metadata(progress_indicator(misses)))
self.wrapped.predict(progress_indicator(misses)))
logging.info("Received %d predictions from model", len(model_preds))

if len(model_preds) != len(misses):
@@ -308,10 +278,21 @@ def _predict_with_metadata(
# Merge results back into the output list.
with self._cache.lock:
for i, orig_idx in enumerate(miss_idxs):
self._cache.put(model_preds[i], key_fn(indexed_inputs[orig_idx]))
self._cache.put(model_preds[i], self.key_fn(inputs_as_list[orig_idx]))
results[orig_idx] = model_preds[i]

# Remove the prediction lock from the cache as the request is complete
self._cache.delete_pred_lock(input_keys)

return results

# TODO(b/171513556): remove this method once we no longer need to override
# ModelWrapper.predict_with_metadata()
def predict_with_metadata(self, indexed_inputs: Iterable[JsonDict], **kw):
"""As predict(), but inputs are IndexedInput."""
results = self.predict((ex["data"] for ex in indexed_inputs), **kw)
return results

def _get_results_from_cache(self, input_keys: list[CacheKey]):
with self._cache.lock:
return [self._cache.get(input_key) for input_key in input_keys]
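
A hedged illustration of the CacheKey change above (the example names are made up; this snippet is not part of the diff). key_fn() now scopes entries by the model name given at construction time rather than by a per-call dataset name, which is what lets predict() build keys without a dataset_name argument:

    from lit_nlp.lib import caching
    from lit_nlp.lib import testing_utils

    model = testing_utils.IdentityRegressionModelForTesting()
    wrapper = caching.CachingModelWrapper(model, "my_model")  # "my_model" is hypothetical

    # New scheme: CacheKey = (model name, example "_id").
    wrapper.key_fn({"val": 1, "_id": "ex0"})   # -> ("my_model", "ex0")

    # Examples without an "_id" get no key (a warning is logged), so they
    # cannot be read from or written to the cache.
    wrapper.key_fn({"val": 2})                 # -> None

    # Old scheme (removed above): CacheKey = (dataset_name, example "id"),
    # which required callers to thread dataset_name through predict_with_metadata().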
39 changes: 24 additions & 15 deletions lit_nlp/lib/caching_test.py
@@ -47,11 +47,11 @@ def test_caching_model_wrapper_no_dataset_skip_cache(self):
def test_caching_model_wrapper_use_cache(self):
model = testing_utils.IdentityRegressionModelForTesting()
wrapper = caching.CachingModelWrapper(model, "test")
examples = [{"data": {"val": 1}, "id": "id_to_cache"}]
results = wrapper.predict_with_metadata(examples, "dataset")
examples = [{"data": {"val": 1, "_id": "id_to_cache"}, "id": "id_to_cache"}]
results = wrapper.predict_with_metadata(examples)
self.assertEqual(1, model.count)
self.assertEqual({"score": 1}, results[0])
results = wrapper.predict_with_metadata(examples, "dataset")
results = wrapper.predict_with_metadata(examples)
self.assertEqual(1, model.count)
self.assertEqual({"score": 1}, results[0])
self.assertEmpty(wrapper._cache._pred_locks)
@@ -60,7 +60,7 @@ def test_caching_model_wrapper_not_cached(self):
model = testing_utils.IdentityRegressionModelForTesting()
wrapper = caching.CachingModelWrapper(model, "test")
examples = [{"data": {"val": 1}, "id": "my_id"}]
results = wrapper.predict_with_metadata(examples, "dataset")
results = wrapper.predict_with_metadata(examples)
self.assertEqual(1, model.count)
self.assertEqual({"score": 1}, results[0])
examples = [{"data": {"val": 2}, "id": "other_id"}]
@@ -71,32 +71,41 @@ def test_caching_model_wrapper_mixed_list(self):
def test_caching_model_wrapper_mixed_list(self):
model = testing_utils.IdentityRegressionModelForTesting()
wrapper = caching.CachingModelWrapper(model, "test")
examples = [{"data": {"val": 1}, "id": "my_id"}]
results = wrapper.predict_with_metadata(examples, "dataset")
self.assertEqual(1, model.count)
self.assertEqual({"score": 1}, results[0])

examples = [
{
"data": {
"val": 0
"val": 0,
"_id": "zeroth_id"
},
"id": "first_id"
"id": "zeroth_id"
},
{
"data": {
"val": 1
"val": 1,
"_id": "first_id"
},
"id": "my_id"
"id": "first_id"
},
{
"data": {
"val": 2
"val": 2,
"_id": "second_id"
},
"id": "last_id"
"id": "second_id"
},
]
results = wrapper.predict_with_metadata(examples, "dataset")
subset = examples[:1]

# Run the CachingModelWrapper over a subset of examples
results = wrapper.predict_with_metadata(subset)
self.assertEqual(1, model.count)
self.assertEqual({"score": 0}, results[0])

# Now, run the CachingModelWrapper over all of the examples. This should
# only pass the examples that were not in subset to the wrapped model, and
# the total number of inputs processed by the wrapped model should be 3
results = wrapper.predict_with_metadata(examples)
self.assertEqual(3, model.count)
self.assertEqual({"score": 0}, results[0])
self.assertEqual({"score": 1}, results[1])
