From b035f7807367b5396f089849e750884be6861c18 Mon Sep 17 00:00:00 2001
From: Massimiliano Pippi
Date: Fri, 16 Aug 2024 21:36:32 +0200
Subject: [PATCH] Add get/set API to the Context and make it coroutine-safe (#15152)

* make the Context coroutine-safe

* remove parent property

* change API

* docs

* use context manager and add unit tests
---
 .../llama_index/core/workflow/context.py      | 78 ++++++++++++++++++-
 .../tests/workflow/test_context.py            | 43 ++++++++++
 2 files changed, 117 insertions(+), 4 deletions(-)

diff --git a/llama-index-core/llama_index/core/workflow/context.py b/llama-index-core/llama_index/core/workflow/context.py
index 865572d568301..9fdbf4a2516dd 100644
--- a/llama-index-core/llama_index/core/workflow/context.py
+++ b/llama-index-core/llama_index/core/workflow/context.py
@@ -1,21 +1,91 @@
 from collections import defaultdict
+import asyncio
 from typing import Dict, Any, Optional, List, Type
 
 from .events import Event
 
 
 class Context:
+    """A global object representing a context for a given workflow run.
+
+    The Context object can be used to store data that needs to be available across iterations during a workflow
+    execution, and across multiple workflow runs.
+    Every context instance offers two types of data storage: a global one, that's shared among all the steps within a
+    workflow, and a private one, that's only accessible from a single step.
+
+    Both `set` and `get` operations on global data are governed by a lock, and considered coroutine-safe.
+    """
+
     def __init__(self, parent: Optional["Context"] = None) -> None:
-        # Global state
+        # Global data storage
         if parent:
-            self.data = parent.data
+            self._globals = parent._globals
         else:
-            self.data: Dict[str, Any] = {}
+            self._globals: Dict[str, Any] = {}
+            self._lock = asyncio.Lock()
+
+        # Local data storage
+        self._locals: Dict[str, Any] = {}
 
         # Step-specific instance
-        self.parent = parent
+        self._parent: Optional[Context] = parent
         self._events_buffer: Dict[Type[Event], List[Event]] = defaultdict(list)
 
+    async def set(self, key: str, value: Any, make_private: bool = False) -> None:
+        """Store `value` into the Context under `key`.
+
+        Args:
+            key: A unique string to identify the value stored.
+            value: The data to be stored.
+            make_private: Make the value only accessible from the step that stored it.
+
+        Raises:
+            ValueError: When make_private is True but a key already exists in the global storage.
+        """
+        if make_private:
+            if key in self._globals:
+                msg = f"A key named '{key}' already exists in the Context storage."
+                raise ValueError(msg)
+            self._locals[key] = value
+            return
+
+        async with self.lock:
+            self._globals[key] = value
+
+    async def get(self, key: str, default: Optional[Any] = None) -> Any:
+        """Get the value corresponding to `key` from the Context.
+
+        Args:
+            key: A unique string to identify the value stored.
+            default: The value to return when `key` is missing instead of raising; `None` means no default.
+
+        Raises:
+            ValueError: When there's no value accessible corresponding to `key` and no default was given.
+        """
+        if key in self._locals:
+            return self._locals[key]
+        elif key in self._globals:
+            async with self.lock:
+                return self._globals[key]
+        elif default is not None:
+            return default
+
+        msg = f"Key '{key}' not found in Context"
+        raise ValueError(msg)
+
+    @property
+    def data(self) -> Dict[str, Any]:
+        """This property is provided for backward compatibility.
+
+        Use `get` and `set` instead.
+        """
+        return self._globals
+
+    @property
+    def lock(self) -> asyncio.Lock:
+        """Returns the mutex guarding the global storage, resolved through the chain of parents."""
+        return self._parent.lock if self._parent else self._lock
+
     def collect_events(
         self, ev: Event, expected: List[Type[Event]]
     ) -> Optional[List[Event]]:
diff --git a/llama-index-core/tests/workflow/test_context.py b/llama-index-core/tests/workflow/test_context.py
index 764c4538f3140..33387b8d22eb3 100644
--- a/llama-index-core/tests/workflow/test_context.py
+++ b/llama-index-core/tests/workflow/test_context.py
@@ -37,3 +37,46 @@ async def step3(
     workflow = TestWorkflow()
     result = await workflow.run()
     assert result == [ev1, ev2]
+
+
+@pytest.mark.asyncio()
+async def test_set_global():
+    c1 = Context()
+    await c1.set(key="test_key", value=42)
+
+    c2 = Context(parent=c1)
+    assert await c2.get(key="test_key") == 42
+
+
+@pytest.mark.asyncio()
+async def test_set_private():
+    c1 = Context()
+    await c1.set(key="test_key", value=42, make_private=True)
+    assert await c1.get(key="test_key") == 42
+
+    c2 = Context(parent=c1)
+    with pytest.raises(ValueError):
+        await c2.get(key="test_key")
+
+
+@pytest.mark.asyncio()
+async def test_set_private_duplicate():
+    c1 = Context()
+    await c1.set(key="test_key", value=42)
+
+    c2 = Context(parent=c1)
+    with pytest.raises(ValueError):
+        await c2.set(key="test_key", value=99, make_private=True)
+
+
+@pytest.mark.asyncio()
+async def test_get_default():
+    c1 = Context()
+    assert await c1.get(key="test_key", default=42) == 42
+
+
+@pytest.mark.asyncio()
+async def test_legacy_data():
+    c1 = Context()
+    await c1.set(key="test_key", value=42)
+    assert c1.data["test_key"] == 42