Skip to content

Commit

Permalink
Add get/set API to the Context and make it coroutine-safe (#15152)
Browse files Browse the repository at this point in the history
* make the Context coroutine-safe

* remove parent property

* change API

* docs

* use context manager and add unit tests
  • Loading branch information
masci authored and nerdai committed Aug 17, 2024
1 parent a233304 commit 6c9ccdf
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 4 deletions.
78 changes: 74 additions & 4 deletions llama-index-core/llama_index/core/workflow/context.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,91 @@
from collections import defaultdict
import asyncio
from typing import Dict, Any, Optional, List, Type

from .events import Event


class Context:
"""A global object representing a context for a given workflow run.
The Context object can be used to store data that needs to be available across iterations during a workflow
execution, and across multiple workflow runs.
Every context instance offers two types of data storage: a global one, which is shared among all the steps within a
workflow, and a private one, which is only accessible from a single step.
Both `set` and `get` operations on global data are governed by a lock, and considered coroutine-safe.
"""

def __init__(self, parent: Optional["Context"] = None) -> None:
# Global state
# Global data storage
if parent:
self.data = parent.data
self._globals = parent._globals
else:
self.data: Dict[str, Any] = {}
self._globals: Dict[str, Any] = {}
self._lock = asyncio.Lock()

# Local data storage
self._locals: Dict[str, Any] = {}

# Step-specific instance
self.parent = parent
self._parent: Optional[Context] = parent
self._events_buffer: Dict[Type[Event], List[Event]] = defaultdict(list)

async def set(self, key: str, value: Any, make_private: bool = False) -> None:
"""Store `value` into the Context under `key`.
Args:
key: A unique string to identify the value stored.
value: The data to be stored.
make_private: Make the value only accessible from the step that stored it.
Raises:
ValueError: When make_private is True but a key already exists in the global storage.
"""
if make_private:
if key in self._globals:
msg = f"A key named '{key}' already exists in the Context storage."
raise ValueError(msg)
self._locals[key] = value
return

async with self.lock:
self._globals[key] = value

async def get(self, key: str, default: Optional[Any] = None) -> Any:
"""Get the value corresponding to `key` from the Context.
Args:
key: A unique string to identify the value stored.
default: The value to return when `key` is missing instead of raising an exception.
Raises:
ValueError: When there's not value accessible corresponding to `key`.
"""
if key in self._locals:
return self._locals[key]
elif key in self._globals:
async with self.lock:
return self._globals[key]
elif default is not None:
return default

msg = f"Key '{key}' not found in Context"
raise ValueError(msg)

@property
def data(self):
"""This property is provided for backward compatibility.
Use `get` and `set` instead.
"""
return self._globals

@property
def lock(self) -> asyncio.Lock:
"""Returns a mutex to lock the Context."""
return self._parent._lock if self._parent else self._lock

def collect_events(
self, ev: Event, expected: List[Type[Event]]
) -> Optional[List[Event]]:
Expand Down
43 changes: 43 additions & 0 deletions llama-index-core/tests/workflow/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,46 @@ async def step3(
workflow = TestWorkflow()
result = await workflow.run()
assert result == [ev1, ev2]


@pytest.mark.asyncio()
async def test_set_global():
    """A globally stored value is readable through a child Context."""
    parent_ctx = Context()
    await parent_ctx.set(key="test_key", value=42)

    child_ctx = Context(parent=parent_ctx)
    fetched = await child_ctx.get(key="test_key")
    assert fetched == 42


@pytest.mark.asyncio()
async def test_set_private():
    """A private value is visible to its owner but not to a child Context."""
    owner_ctx = Context()
    await owner_ctx.set(key="test_key", value=42, make_private=True)
    own_view = await owner_ctx.get(key="test_key")
    assert own_view == 42

    child_ctx = Context(parent=owner_ctx)
    # The child has no access to the owner's private storage.
    with pytest.raises(ValueError):
        await child_ctx.get(key="test_key")


@pytest.mark.asyncio()
async def test_set_private_duplicate():
    """Storing a private value under an existing global key is rejected."""
    parent_ctx = Context()
    await parent_ctx.set(key="test_key", value=42)

    child_ctx = Context(parent=parent_ctx)
    # "test_key" already lives in the shared global storage.
    with pytest.raises(ValueError):
        await child_ctx.set(key="test_key", value=99, make_private=True)


@pytest.mark.asyncio()
async def test_get_default():
    """`get` returns the supplied default when the key is missing."""
    empty_ctx = Context()
    fallback = await empty_ctx.get(key="test_key", default=42)
    assert fallback == 42


@pytest.mark.asyncio()
async def test_legacy_data():
    """Values stored with `set` are visible through the legacy `data` property."""
    ctx = Context()
    await ctx.set(key="test_key", value=42)
    legacy_store = ctx.data
    assert legacy_store["test_key"] == 42

0 comments on commit 6c9ccdf

Please sign in to comment.