Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: Add test for evaluate step #460

Merged
merged 12 commits into from
Aug 19, 2024
4 changes: 3 additions & 1 deletion agents-api/agents_api/activities/demo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable

from temporalio import activity

from ..env import testing
Expand All @@ -12,6 +14,6 @@ async def mock_demo_activity(a: int, b: int) -> int:
return a + b


demo_activity = activity.defn(name="demo_activity")(
demo_activity: Callable[[int, int], int] = activity.defn(name="demo_activity")(
demo_activity if not testing else mock_demo_activity
)
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/embed_docs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from beartype import beartype
from temporalio import activity

from ..clients import cozo
from ..clients import embed as embedder
from ..clients.cozo import get_cozo_client
from ..env import testing
from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query
from .types import EmbedDocsPayload
Expand All @@ -28,7 +28,7 @@ async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
doc_id=payload.doc_id,
snippet_indices=indices,
embeddings=embeddings,
client=cozo_client or get_cozo_client(),
client=cozo_client or cozo.get_cozo_client(),
)


Expand Down
7 changes: 4 additions & 3 deletions agents-api/agents_api/activities/logger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
from typing import TextIO

# Module-level logger for activity code; messages are emitted through a
# StreamHandler using a "[timestamp/LEVEL] - message" format.
logger: logging.Logger = logging.getLogger(__name__)

# NOTE: `logging.StreamHandler[...]` is only subscriptable at runtime on
# Python >= 3.11; module-level variable annotations ARE evaluated at import
# time, so quote the annotation to avoid a TypeError on older interpreters.
h: "logging.StreamHandler[TextIO]" = logging.StreamHandler()
fmt: logging.Formatter = logging.Formatter("[%(asctime)s/%(levelname)s] - %(message)s")
h.setFormatter(fmt)
logger.addHandler(h)
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@


# TODO: remove stubs
def entries_summarization_query(*args, **kwargs) -> pd.DataFrame:
    """Stub: accept any arguments and return an empty DataFrame."""
    empty_frame = pd.DataFrame()
    return empty_frame


def get_toplevel_entries_query(*args, **kwargs) -> pd.DataFrame:
    """Stub: accept any arguments and return an empty DataFrame."""
    placeholder = pd.DataFrame()
    return placeholder


Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@

from .evaluate_step import evaluate_step
from .if_else_step import if_else_step
from .log_step import log_step
from .prompt_step import prompt_step
from .raise_complete_async import raise_complete_async
from .return_step import return_step
from .switch_step import switch_step
from .tool_call_step import tool_call_step
from .transition_step import transition_step
from .wait_for_input_step import wait_for_input_step
from .yield_step import yield_step
27 changes: 16 additions & 11 deletions agents-api/agents_api/activities/task_steps/evaluate_step.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
from typing import Any
import logging

from beartype import beartype
from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
from ...autogen.openapi_model import EvaluateStep
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing


@beartype
async def evaluate_step(
context: StepContext[EvaluateStep],
) -> StepOutcome[dict[str, Any]]:
exprs = context.definition.arguments
output = simple_eval_dict(exprs, values=context.model_dump())
async def evaluate_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for returning immediately, so we just evaluate the expression
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The evaluate_step function signature should include the specific step type EvaluateStep for the context parameter to ensure type safety and clarity in the function's usage.

Suggested change
async def evaluate_step(context: StepContext) -> StepOutcome:
async def evaluate_step(context: StepContext[EvaluateStep]) -> StepOutcome:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The evaluate_step function is missing type annotations for its parameters. Consider adding these to enhance code clarity and error prevention.

Suggested change
async def evaluate_step(context: StepContext) -> StepOutcome:
async def evaluate_step(context: StepContext[EvaluateStep]) -> StepOutcome:

# Hence, it's a local activity and SHOULD NOT fail
try:
assert isinstance(context.current_step, EvaluateStep)

exprs = context.current_step.evaluate
output = simple_eval_dict(exprs, values=context.model_dump())

result = StepOutcome(output=output)
return result

return StepOutcome(output=output)
except BaseException as e:
logging.error(f"Error in evaluate_step: {e}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Catching BaseException is generally not recommended as it can catch system-exiting exceptions such as SystemExit and KeyboardInterrupt, which should normally be allowed to propagate. Consider catching Exception instead, which does not include these system-exiting exceptions.

return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported evaluate_step directly
Expand Down
42 changes: 29 additions & 13 deletions agents-api/agents_api/activities/task_steps/if_else_step.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,40 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import IfElseWorkflowStep
from ...common.protocol.tasks import (
    StepContext,
    StepOutcome,
)
from ...env import testing


@beartype
async def if_else_step(context: StepContext) -> StepOutcome:
    """Evaluate an if/else step's condition and return it as a boolean outcome.

    NOTE: This activity only evaluates the step's `if_` expression (the
    original comment said "only for logging" — a copy-paste slip from
    log_step). It is a local activity and SHOULD NOT fail: any error is
    captured in the returned StepOutcome instead of propagating.
    """
    try:
        assert isinstance(context.current_step, IfElseWorkflowStep)

        expr: str = context.current_step.if_
        # Coerce the evaluated expression to a strict bool; avoids rebinding
        # `output` under a second, conflicting annotation.
        output: bool = bool(simple_eval(expr, names=context.model_dump()))

        return StepOutcome(output=output)

    except BaseException as e:
        # Deliberately broad: the outcome must carry the error rather than
        # failing the (local) activity.
        logging.error(f"Error in if_else_step: {e}")
        return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported if_else_step directly
# They do the same thing, so we dont need to mock the if_else_step function
mock_if_else_step = if_else_step

if_else_step = activity.defn(name="if_else_step")(
    if_else_step if not testing else mock_if_else_step
)
37 changes: 37 additions & 0 deletions agents-api/agents_api/activities/task_steps/log_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import LogStep
from ...common.protocol.tasks import (
    StepContext,
    StepOutcome,
)
from ...env import testing


@beartype
async def log_step(context: StepContext) -> StepOutcome:
    """Evaluate a log step's expression and return the result as the outcome.

    This is a local activity and SHOULD NOT fail: any exception is folded
    into the returned StepOutcome's error field instead of propagating.
    """
    try:
        assert isinstance(context.current_step, LogStep)

        log_expr: str = context.current_step.log
        evaluated = simple_eval(log_expr, names=context.model_dump())

        return StepOutcome(output=evaluated)

    except BaseException as e:
        logging.error(f"Error in log_step: {e}")
        return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported log_step directly
# They do the same thing, so we dont need to mock the log_step function
mock_log_step = log_step

log_step = activity.defn(name="log_step")(log_step if not testing else mock_log_step)
20 changes: 7 additions & 13 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,25 @@
from beartype import beartype
from temporalio import activity

from ...autogen.openapi_model import (
InputChatMLMessage,
PromptStep,
)
from ...autogen.openapi_model import InputChatMLMessage
from ...clients import (
litellm, # We dont directly import `acompletion` so we can mock it
)
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.utils.template import render_template


@activity.defn
@beartype
async def prompt_step(context: StepContext[PromptStep]) -> StepOutcome:
async def prompt_step(context: StepContext) -> StepOutcome:
# Get context data
context_data: dict = context.model_dump()

# Render template messages
prompt = (
[InputChatMLMessage(content=context.definition.prompt)]
if isinstance(context.definition.prompt, str)
else context.definition.prompt
[InputChatMLMessage(content=context.current_step.prompt)]
if isinstance(context.current_step.prompt, str)
else context.current_step.prompt
)

template_messages: list[InputChatMLMessage] = prompt
Expand All @@ -47,7 +41,7 @@ async def prompt_step(context: StepContext[PromptStep]) -> StepOutcome:
for m in messages
]

settings: dict = context.definition.settings.model_dump()
settings: dict = context.current_step.settings.model_dump()
# Get settings and run llm
response = await litellm.acompletion(
messages=messages,
Expand Down
37 changes: 37 additions & 0 deletions agents-api/agents_api/activities/task_steps/return_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import logging

from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
from ...autogen.openapi_model import ReturnStep
from ...common.protocol.tasks import (
    StepContext,
    StepOutcome,
)
from ...env import testing


async def return_step(context: StepContext) -> StepOutcome:
    """Evaluate a return step's expressions and return them as the outcome.

    NOTE: This activity is only for returning immediately, so we just
    evaluate the expressions. Hence, it's a local activity and SHOULD NOT
    fail — errors are captured in the outcome instead of propagating.
    """
    try:
        assert isinstance(context.current_step, ReturnStep)

        exprs: dict[str, str] = context.current_step.return_
        output = simple_eval_dict(exprs, values=context.model_dump())

        return StepOutcome(output=output)

    except BaseException as e:
        # Fixed: previously logged "Error in log_step" (copy-paste slip).
        logging.error(f"Error in return_step: {e}")
        return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported return_step directly
# They do the same thing, so we dont need to mock the return_step function
mock_return_step = return_step

return_step = activity.defn(name="return_step")(
    return_step if not testing else mock_return_step
)
48 changes: 48 additions & 0 deletions agents-api/agents_api/activities/task_steps/switch_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import SwitchStep
from ...common.protocol.tasks import (
    StepContext,
    StepOutcome,
)
from ...env import testing


@beartype
async def switch_step(context: StepContext) -> StepOutcome:
    """Return the index of the first switch case whose condition is truthy.

    Output is -1 when no case matches. This is a local activity and SHOULD
    NOT fail: errors are logged and captured in the returned StepOutcome.
    """
    try:
        assert isinstance(context.current_step, SwitchStep)

        # Assume that none of the cases evaluate to truthy
        output: int = -1

        cases: list[str] = [c.case for c in context.current_step.switch]

        for i, case in enumerate(cases):
            # NOTE(review): `simple_eval` on step-supplied expressions is
            # untrusted-input evaluation; expression sources must be
            # controlled upstream.
            case_matches = simple_eval(case, names=context.model_dump())

            if case_matches:
                output = i
                break

        return StepOutcome(output=output)

    except BaseException as e:
        logging.error(f"Error in switch_step: {e}")
        return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported switch_step directly
# They do the same thing, so we dont need to mock the switch_step function
mock_switch_step = switch_step

switch_step = activity.defn(name="switch_step")(
    switch_step if not testing else mock_switch_step
)
6 changes: 3 additions & 3 deletions agents-api/agents_api/activities/task_steps/tool_call_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
@beartype
async def tool_call_step(context: StepContext) -> dict:
raise NotImplementedError()
# assert isinstance(context.definition, ToolCallStep)
# assert isinstance(context.current_step, ToolCallStep)

# context.definition.tool_id
# context.definition.arguments
# context.current_step.tool_id
# context.current_step.arguments
# # get tool by id
# # call tool

Expand Down
25 changes: 25 additions & 0 deletions agents-api/agents_api/activities/task_steps/wait_for_input_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import logging

from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
from ...autogen.openapi_model import WaitForInputStep
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing


async def wait_for_input_step(context: StepContext) -> StepOutcome:
    """Evaluate the wait-for-input expressions and return them as the outcome.

    Made consistent with the sibling step activities (log_step, return_step,
    switch_step): it is a local activity and SHOULD NOT fail, so errors are
    logged and captured in the outcome instead of failing the activity.
    """
    try:
        assert isinstance(context.current_step, WaitForInputStep)

        exprs = context.current_step.wait_for_input
        output = simple_eval_dict(exprs, values=context.model_dump())

        return StepOutcome(output=output)

    except BaseException as e:
        logging.error(f"Error in wait_for_input_step: {e}")
        return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported wait_for_input_step directly
# They do the same thing, so we dont need to mock the wait_for_input_step function
mock_wait_for_input_step = wait_for_input_step

wait_for_input_step = activity.defn(name="wait_for_input_step")(
    wait_for_input_step if not testing else mock_wait_for_input_step
)
Loading
Loading