Introduce LM cache.
- Add interface `lf.LMCache`
- Implement `lf.llms.cache.InMemory`.
- Support custom `key` function and `ttl`.

NOTE: By default, the cache key is determined by (model_id, sampling_options, prompt).

Usage:
```python

lm = lf.llms.Gpt35(cache=lf.llms.cache.InMemory())
print(lf.LangFunc('Intro to the U.S.A', lm=lm))
```
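
For example (a minimal sketch, not part of this commit's tests; it assumes `key` and `ttl` are accepted as constructor arguments on `InMemory`, since both are fields declared on `LMCacheBase`), the key function and expiration could be customized like this:
```python
# Hypothetical usage: cache by prompt text only and expire entries after an
# hour. `key` receives the language model and the prompt message and must
# return a hashable value; `ttl` is in seconds.
lm = lf.llms.Gpt35(
    cache=lf.llms.cache.InMemory(
        key=lambda model, prompt: (model.model_id, prompt.text),
        ttl=3600,
    )
)
print(lf.LangFunc('Intro to the U.S.A', lm=lm))
```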
PiperOrigin-RevId: 567398279
daiyip authored and langfun authors committed Sep 21, 2023
1 parent 0fc31cb commit 32b6c5e
Showing 11 changed files with 369 additions and 9 deletions.
1 change: 1 addition & 0 deletions langfun/core/__init__.py
@@ -98,6 +98,7 @@
from langfun.core.language_model import LMSample
from langfun.core.language_model import LMSamplingOptions
from langfun.core.language_model import LMSamplingResult
from langfun.core.language_model import LMCache

# Components for building agents.
from langfun.core.memory import Memory
4 changes: 2 additions & 2 deletions langfun/core/langfunc_test.py
@@ -94,8 +94,8 @@ def test_call(self):
"LangFunc(template_str='Hello', clean=True, returns=None, "
'lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0, '
'max_tokens=1024, n=1, top_k=40, top_p=None, random_seed=None), '
'timeout=120.0, max_attempts=5, debug=False), input_transform=None, '
'output_transform=None)',
'cache=None, timeout=120.0, max_attempts=5, debug=False), '
'input_transform=None, output_transform=None)',
)

l = LangFunc('Hello')
81 changes: 76 additions & 5 deletions langfun/core/language_model.py
@@ -15,7 +15,7 @@

import abc
import time
from typing import Annotated
from typing import Annotated, Any
from langfun.core import component
from langfun.core import console
from langfun.core import message as message_lib
@@ -78,6 +78,34 @@ class LMSamplingOptions(component.Component):
int | None, 'A fixed random seed used during model inference.'
] = None

def cache_key(self) -> tuple[Any, ...]:
"""Returns a tuple of current values as cache key."""
return (
self.temperature,
self.max_tokens,
self.n,
self.top_k,
self.top_p,
self.random_seed
)


class LMCache(pg.Object):
"""Interface for LM cache."""

@abc.abstractmethod
def get(self,
lm: 'LanguageModel',
prompt: message_lib.Message) -> LMSamplingResult | None:
"""Gets the cached result of a prompt generated by a language model."""

@abc.abstractmethod
def put(self,
lm: 'LanguageModel',
prompt: message_lib.Message,
result: LMSamplingResult) -> None:
"""Puts the result of a prompt generated by a language model in cache."""


class LanguageModel(component.Component):
"""Interface of a language model.
@@ -91,6 +119,13 @@ class LanguageModel(component.Component):

sampling_options: LMSamplingOptions = LMSamplingOptions()

cache: Annotated[
LMCache | None,
(
'Sampling cache. If None, no cache will be used.'
)
] = None

timeout: Annotated[
float | None, 'Timeout in seconds. If None, there is no timeout.'
] = 120.0
@@ -130,15 +165,51 @@ def _on_bound(self):
super()._on_bound()
self._call_counter = 0

@property
def model_id(self) -> str:
"""Returns a string to identify the model."""
return self.__class__.__name__

def sample(self,
prompts: list[str | message_lib.Message],
**kwargs) -> list[LMSamplingResult]:
"""Samples one or multiple prompts."""
prompts = [message_lib.UserMessage.from_value(p) for p in prompts]

with component.context(override_attrs=True, **kwargs):
return self._sample([
message_lib.UserMessage.from_value(p)
for p in prompts
])
if self.cache is None:
return self._sample(prompts)
else:
return self._sample_with_cache_lookup(prompts)

def _sample_with_cache_lookup(
self, prompts: list[str | message_lib.Message]) -> list[LMSamplingResult]:
"""Sample with cache lookup."""
assert self.cache is not None

results = [None] * len(prompts)
requests, request_to_result_index = [], {}

# Perform cache lookup and figure out sampling requests to make.
for i, prompt in enumerate(prompts):
r = self.cache.get(self, prompt)
if r is None:
request_to_result_index[len(requests)] = i
requests.append(prompt)
else:
results[i] = r.clone()

# Sample non-cache-hit prompts.
requested_results = self._sample(requests)
assert len(requested_results) == len(requests), (
requests, requested_results)

# Combine cached results and newly requested results.
for i, (prompt, result) in enumerate(zip(requests, requested_results)):
results[request_to_result_index[i]] = result
self.cache.put(self, prompt, result)

return results # pytype: disable=bad-return-type

@abc.abstractmethod
def _sample(
63 changes: 63 additions & 0 deletions langfun/core/language_model_test.py
@@ -52,11 +52,49 @@ def fake_sample(prompts):
)(prompts)


class SimpleCache(lm_lib.LMCache):

def _on_bound(self):
super()._on_bound()
self._cache = {}
self.cache_hit = 0

def get(self, lm, prompt):
del lm
r = self._cache.get(prompt.text)
if r is not None:
self.cache_hit += 1
return r

def put(self, lm, prompt, result):
self._cache[prompt.text] = result

@property
def num_records(self):
return len(self._cache)


class LMSamplingOptionsTest(unittest.TestCase):
"""Tests for LMSamplingOptions."""

def test_cache_key(self):
options = lm_lib.LMSamplingOptions()
key1 = options.cache_key()
self.assertEqual(key1, (0.0, 1024, 1, 40, None, None))
with options.override(temperature=1.0, max_tokens=256):
key2 = options.cache_key()
self.assertEqual(key2, (1.0, 256, 1, 40, None, None))

# Make sure key1 does not change upon override.
self.assertEqual(key1, (0.0, 1024, 1, 40, None, None))


class LanguageModelTest(unittest.TestCase):
"""Tests for LanguageModel."""

def test_init(self):
lm = MockModel(1, temperature=0.5, top_k=2, max_attempts=2)
self.assertEqual(lm.model_id, 'MockModel')
self.assertEqual(lm.failures_before_attempt, 1)
self.assertEqual(lm.sampling_options.temperature, 0.5)
self.assertEqual(lm.sampling_options.top_k, 2)
@@ -117,6 +155,31 @@ def test_call(self):
# Test override individual flags within sampling_options.
self.assertEqual(lm('foo', top_k=2), 'foo' * 2)

def test_using_cache(self):
cache = SimpleCache()
lm = MockModel(cache=cache, top_k=1)
self.assertEqual(
lm.sample(prompts=['foo', 'bar']),
[
lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=0.0)]),
lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=0.0)]),
])

self.assertEqual(cache.cache_hit, 0)
self.assertEqual(cache.num_records, 2)
self.assertEqual(
lm.sample(prompts=['foo', 'baz'], temperature=1.0),
[
lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=0.0)]),
lm_lib.LMSamplingResult([lm_lib.LMSample('baz', score=1.0)]),
])
self.assertEqual(cache.cache_hit, 1)
self.assertEqual(cache.num_records, 3)

self.assertEqual(lm('baz', temperature=1.0), 'baz')
self.assertEqual(cache.cache_hit, 2)
self.assertEqual(cache.num_records, 3)

def test_retry(self):
lm = MockModel(
failures_before_attempt=1, top_k=1,
3 changes: 3 additions & 0 deletions langfun/core/llms/__init__.py
@@ -36,5 +36,8 @@

# Placeholder for Google-internal imports.

# Include cache as sub-module.
from langfun.core.llms import cache

# pylint: enable=g-bad-import-order
# pylint: enable=g-importing-member
26 changes: 26 additions & 0 deletions langfun/core/llms/cache/__init__.py
@@ -0,0 +1,26 @@
# Copyright 2023 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""langfun LLM cache implementations."""

# pylint: disable=g-importing-member
# pylint: disable=g-bad-import-order

from langfun.core.llms.cache.base import LMCacheBase
from langfun.core.llms.cache.base import LMCacheEntry

from langfun.core.llms.cache.in_memory import InMemory


# pylint: enable=g-bad-import-order
# pylint: enable=g-importing-member
86 changes: 86 additions & 0 deletions langfun/core/llms/cache/base.py
@@ -0,0 +1,86 @@
# Copyright 2023 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LM cache base."""

import abc
import dataclasses
import datetime
from typing import Annotated, Any, Callable
import langfun.core as lf


@dataclasses.dataclass(frozen=True)
class LMCacheEntry:
result: lf.LMSamplingResult
expire: datetime.datetime | None


class LMCacheBase(lf.LMCache):
"""The common LMCache base."""

key: Annotated[
Callable[[lf.LanguageModel, lf.Message], Any] | None,
(
'A callable object used for computing the key (hashable structure) '
'from the language model used and input prompt. If None, a default '
'key will be used, which is sensitive to the model id, sampling '
'options and the input prompt.'
)
] = None

ttl: Annotated[
int | None,
(
'Time-to-live in seconds.'
)
] = None

def _on_bound(self):
super()._on_bound()
self._key = self.key or default_key

def get(self,
lm: lf.LanguageModel,
prompt: lf.Message) -> lf.LMSamplingResult | None:
"""Gets the cached result of a prompt generated by a language model."""
entry = self._get(self._key(lm, prompt))
if entry is None:
return None
if entry.expire is not None and entry.expire < datetime.datetime.now():
return None
return entry.result

def put(self,
lm: lf.LanguageModel,
prompt: lf.Message,
result: lf.LMSamplingResult) -> None:
"""Puts the result of a prompt generated by a language model in cache."""
expire = None
if self.ttl:
expire = datetime.datetime.now() + datetime.timedelta(seconds=self.ttl)
entry = LMCacheEntry(result, expire)
self._put(self._key(lm, prompt), entry)

@abc.abstractmethod
def _get(self, key: Any) -> LMCacheEntry | None:
"""Returns a LM cache entry associated with the key."""

@abc.abstractmethod
def _put(self, key: Any, entry: LMCacheEntry) -> None:
"""Puts a LM cache entry associated with the key."""


def default_key(lm: lf.LanguageModel, prompt: lf.Message) -> Any:
"""Default key for LM cache."""
return (lm.model_id, lm.sampling_options.cache_key(), prompt.text)
40 changes: 40 additions & 0 deletions langfun/core/llms/cache/in_memory.py
@@ -0,0 +1,40 @@
# Copyright 2023 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""In-memory LM cache."""

from typing import Any
from langfun.core.llms.cache import base


class InMemory(base.LMCacheBase):
"""In memory cache."""

def _get(self, key: Any) -> base.LMCacheEntry | None:
"""Returns a LM cache entry associated with the key."""
return _CACHE_MEMORY.get(key, None)

def _put(self, key: Any, entry: base.LMCacheEntry) -> None:
"""Puts a LM cache entry associated with the key."""
_CACHE_MEMORY[key] = entry

def reset(self) -> None:
"""Resets the cache."""
_CACHE_MEMORY.clear()


# NOTE(daiyip): We install a process-level cache store, so different InMemory()
# objects can access the same memory. This is not a problem across different
# language models, since the `model_id` of the language model is included as a
# part of the cache key.
_CACHE_MEMORY = {}