-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PiperOrigin-RevId: 650013164 Change-Id: Ie896345ec66f37807aa132193cfd7197c2b911cf
- Loading branch information
Showing
1 changed file
with
89 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Copyright 2023 DeepMind Technologies Limited. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""The abstract class for a memory.""" | ||
|
||
import abc | ||
from collections.abc import Mapping, Sequence | ||
from typing import Any, Protocol | ||
|
||
|
||
class MemoryScorer(Protocol): | ||
"""Typing definition for a memory scorer function.""" | ||
|
||
def __call__(self, query: str, text: str, **metadata: Any) -> float: | ||
"""Returns a score for a memory (text and metadata) given the query. | ||
Args: | ||
query: The query to use for retrieval. | ||
text: The text of the memory. | ||
**metadata: The metadata of the memory. | ||
""" | ||
|
||
|
||
class MemoryBank(metaclass=abc.ABCMeta): | ||
"""Base class for memory banks.""" | ||
|
||
@abc.abstractmethod | ||
def add(self, text: str, metadata: Mapping[str, Any]) -> None: | ||
"""Adds a memory (in the form of text) to the memory bank. | ||
The memory bank might add extra metadata to the memory. | ||
Args: | ||
text: The text to add to the memory bank. | ||
metadata: The metadata associated with the memory. | ||
""" | ||
raise NotImplementedError() | ||
|
||
def extend(self, texts: Sequence[str], metadata: Mapping[str, Any]) -> None: | ||
"""Adds a sequence of memories (in the form of text) to the memory bank. | ||
All memories will be added with the same metadata. The memory bank might add | ||
extra metadata to the memories. | ||
Args: | ||
texts: The texts to add to the memory bank. | ||
metadata: The metadata associated with all the memories. | ||
""" | ||
for text in texts: | ||
self.add(text, metadata) | ||
|
||
@abc.abstractmethod | ||
def retrieve( | ||
self, | ||
query: str, | ||
scoring_fn: MemoryScorer, | ||
limit: int, | ||
) -> Sequence[tuple[str, float]]: | ||
"""Retrieves memories from the memory bank using the given scoring function. | ||
This function retrieves the memories from the memory bank that are most | ||
relevant to the given query, according to the scoring function. The scoring | ||
function is a function that takes the query, a memory (in the form of text), | ||
and a dictionary of metadata and returns a score for the memory. The higher | ||
the score, the more relevant the memory is to the query. | ||
Args: | ||
query: The query to use for retrieval. | ||
scoring_fn: The scoring function to use. | ||
limit: The maximum number of memories to retrieve. If negative, all | ||
memories will be retrieved. | ||
Returns: | ||
A list of memories (in the form of text) and their scores that are most | ||
relevant to the `query`. This list will be of at most `limit` elements, | ||
unless `limit` is negative, in which case all memories will be returned. | ||
""" | ||
raise NotImplementedError() |