Skip to content

Commit

Permalink
Add a new typing for memory banks.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 650013164
Change-Id: Ie896345ec66f37807aa132193cfd7197c2b911cf
  • Loading branch information
duenez authored and Copybara-Service committed Jul 7, 2024
1 parent 9154bc6 commit be58493
Showing 1 changed file with 89 additions and 0 deletions.
89 changes: 89 additions & 0 deletions concordia/typing/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The abstract class for a memory."""

import abc
from collections.abc import Mapping, Sequence
from typing import Any, Protocol


class MemoryScorer(Protocol):
"""Typing definition for a memory scorer function."""

def __call__(self, query: str, text: str, **metadata: Any) -> float:
"""Returns a score for a memory (text and metadata) given the query.
Args:
query: The query to use for retrieval.
text: The text of the memory.
**metadata: The metadata of the memory.
"""


class MemoryBank(metaclass=abc.ABCMeta):
"""Base class for memory banks."""

@abc.abstractmethod
def add(self, text: str, metadata: Mapping[str, Any]) -> None:
"""Adds a memory (in the form of text) to the memory bank.
The memory bank might add extra metadata to the memory.
Args:
text: The text to add to the memory bank.
metadata: The metadata associated with the memory.
"""
raise NotImplementedError()

def extend(self, texts: Sequence[str], metadata: Mapping[str, Any]) -> None:
"""Adds a sequence of memories (in the form of text) to the memory bank.
All memories will be added with the same metadata. The memory bank might add
extra metadata to the memories.
Args:
texts: The texts to add to the memory bank.
metadata: The metadata associated with all the memories.
"""
for text in texts:
self.add(text, metadata)

@abc.abstractmethod
def retrieve(
self,
query: str,
scoring_fn: MemoryScorer,
limit: int,
) -> Sequence[tuple[str, float]]:
"""Retrieves memories from the memory bank using the given scoring function.
This function retrieves the memories from the memory bank that are most
relevant to the given query, according to the scoring function. The scoring
function is a function that takes the query, a memory (in the form of text),
and a dictionary of metadata and returns a score for the memory. The higher
the score, the more relevant the memory is to the query.
Args:
query: The query to use for retrieval.
scoring_fn: The scoring function to use.
limit: The maximum number of memories to retrieve. If negative, all
memories will be retrieved.
Returns:
A list of memories (in the form of text) and their scores that are most
relevant to the `query`. This list will be of at most `limit` elements,
unless `limit` is negative, in which case all memories will be returned.
"""
raise NotImplementedError()

0 comments on commit be58493

Please sign in to comment.