Skip to content

Commit

Permalink
feat(agents-api): Add wait_for_input step
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <diwank@julep.ai>
  • Loading branch information
Diwank Tomer committed Aug 18, 2024
1 parent d954077 commit 6cd98ae
Show file tree
Hide file tree
Showing 20 changed files with 174 additions and 70 deletions.
1 change: 1 addition & 0 deletions agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
from .return_step import return_step
from .tool_call_step import tool_call_step
from .transition_step import transition_step
from .wait_for_input_step import wait_for_input_step
from .yield_step import yield_step
10 changes: 3 additions & 7 deletions agents-api/agents_api/activities/task_steps/evaluate_step.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import logging

from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
from ...autogen.openapi_model import EvaluateStep
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing


async def evaluate_step(
context: StepContext[EvaluateStep],
) -> StepOutcome:
async def evaluate_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for returning immediately, so we just evaluate the expression
# Hence, it's a local activity and SHOULD NOT fail
try:
Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/activities/task_steps/return_step.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging

from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
Expand Down
34 changes: 34 additions & 0 deletions agents-api/agents_api/activities/task_steps/wait_for_input_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import logging

from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
from ...autogen.openapi_model import WaitForInputStep
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing


async def wait_for_input_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for returning immediately, so we just evaluate the expression
# Hence, it's a local activity and SHOULD NOT fail
try:
assert isinstance(context.current_step, WaitForInputStep)

exprs = context.current_step.wait_for_input
output = simple_eval_dict(exprs, values=context.model_dump())

result = StepOutcome(output=output)
return result

except Exception as e:
logging.error(f"Error in log_step: {e}")
return StepOutcome(output=None)


# Note: This is here just for clarity. We could have just imported wait_for_input_step directly
# They do the same thing, so we dont need to mock the wait_for_input_step function
mock_wait_for_input_step = wait_for_input_step

wait_for_input_step = activity.defn(name="wait_for_input_step")(
wait_for_input_step if not testing else mock_wait_for_input_step
)
40 changes: 25 additions & 15 deletions agents-api/agents_api/activities/task_steps/yield_step.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
from typing import Callable

from beartype import beartype
from temporalio import activity

from agents_api.autogen.Executions import TransitionTarget
from agents_api.autogen.openapi_model import TransitionTarget, YieldStep

from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing
Expand All @@ -12,24 +13,33 @@

@beartype
async def yield_step(context: StepContext) -> StepOutcome:
all_workflows = context.execution_input.task.workflows
workflow = context.current_step.workflow
# NOTE: This activity is only for returning immediately, so we just evaluate the expression
# Hence, it's a local activity and SHOULD NOT fail
try:
assert isinstance(context.current_step, YieldStep)

assert workflow in [
wf.name for wf in all_workflows
], f"Workflow {workflow} not found in task"
all_workflows = context.execution_input.task.workflows
workflow = context.current_step.workflow

# Evaluate the expressions in the arguments
exprs = context.current_step.arguments
arguments = simple_eval_dict(exprs, values=context.model_dump())
assert workflow in [
wf.name for wf in all_workflows
], f"Workflow {workflow} not found in task"

# Transition to the first step of that workflow
transition_target = TransitionTarget(
workflow=workflow,
step=0,
)
# Evaluate the expressions in the arguments
exprs = context.current_step.arguments
arguments = simple_eval_dict(exprs, values=context.model_dump())

return StepOutcome(output=arguments, transition_to=("step", transition_target))
# Transition to the first step of that workflow
transition_target = TransitionTarget(
workflow=workflow,
step=0,
)

return StepOutcome(output=arguments, transition_to=("step", transition_target))

except Exception as e:
logging.error(f"Error in log_step: {e}")
return StepOutcome(output=None)


# Note: This is here just for clarity. We could have just imported yield_step directly
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/autogen/Tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ class WaitForInputStep(BaseWorkflowStep):
populate_by_name=True,
)
kind_: Literal["wait_for_input"] = "wait_for_input"
info: str | dict[str, Any]
wait_for_input: dict[str, str]
"""
Any additional info or data
"""
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def create_worker(client: Client) -> Any:
return_step,
tool_call_step,
transition_step,
wait_for_input_step,
yield_step,
)
from ..activities.truncation import truncation
Expand All @@ -46,6 +47,7 @@ def create_worker(client: Client) -> Any:
return_step,
tool_call_step,
transition_step,
wait_for_input_step,
yield_step,
]

Expand Down
37 changes: 26 additions & 11 deletions agents-api/agents_api/workflows/task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,30 @@
with workflow.unsafe.imports_passed_through():
from ..activities.task_steps import (
evaluate_step,
if_else_step,
# if_else_step,
log_step,
prompt_step,
# prompt_step,
raise_complete_async,
return_step,
tool_call_step,
# tool_call_step,
transition_step,
wait_for_input_step,
yield_step,
)
from ..autogen.openapi_model import (
CreateTransitionRequest,
ErrorWorkflowStep,
EvaluateStep,
IfElseWorkflowStep,
# IfElseWorkflowStep,
LogStep,
PromptStep,
# PromptStep,
ReturnStep,
SleepFor,
SleepStep,
ToolCallStep,
# ToolCallStep,
TransitionTarget,
TransitionType,
# WaitForInputStep,
WaitForInputStep,
# WorkflowStep,
YieldStep,
)
Expand All @@ -45,16 +47,17 @@


STEP_TO_ACTIVITY = {
PromptStep: prompt_step,
ToolCallStep: tool_call_step,
YieldStep: yield_step,
# PromptStep: prompt_step,
# ToolCallStep: tool_call_step,
WaitForInputStep: wait_for_input_step,
}

STEP_TO_LOCAL_ACTIVITY = {
# NOTE: local activities are directly called in the workflow executor
# They MUST NOT FAIL, otherwise they will crash the workflow
EvaluateStep: evaluate_step,
IfElseWorkflowStep: if_else_step,
# IfElseWorkflowStep: if_else_step,
YieldStep: yield_step,
LogStep: log_step,
ReturnStep: return_step,
}
Expand Down Expand Up @@ -179,6 +182,9 @@ async def transition(**kwargs):
case YieldStep(), StepOutcome(
output=output, transition_to=(yield_transition_type, yield_next_target)
):
if output is None:
raise ApplicationError("yield step threw an error")

await transition(
output=output, type=yield_transition_type, next=yield_next_target
)
Expand All @@ -190,6 +196,15 @@ async def transition(**kwargs):

final_output = yield_outcome

case WaitForInputStep(), StepOutcome(output=output):
await transition(output=output, type="wait", next=None)

transition_type = "resume"
final_output = await execute_activity(
raise_complete_async,
schedule_to_close_timeout=timedelta(days=31),
)

case _:
raise NotImplementedError()

Expand Down
68 changes: 67 additions & 1 deletion agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Tests for task queries

from ward import test, raises
import asyncio

from google.protobuf.json_format import MessageToDict
from ward import raises, test

from agents_api.autogen.openapi_model import CreateExecutionRequest, CreateTaskRequest
from agents_api.models.task.create_task import create_task
Expand Down Expand Up @@ -382,3 +385,66 @@ async def _(

result = await handle.result()
assert result["hello"] == data.input["test"]


@test("workflow: wait for input step start")
async def _(
client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
data = CreateExecutionRequest(input={"test": "input"})

task = create_task(
developer_id=developer_id,
agent_id=agent.id,
data=CreateTaskRequest(
**{
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"main": [
{"wait_for_input": {"hi": '"bye"'}},
],
}
),
client=client,
)

async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
execution, handle = await start_execution(
developer_id=developer_id,
task_id=task.id,
data=data,
client=client,
)

assert handle is not None
assert execution.task_id == task.id
assert execution.input == data.input
mock_run_task_execution_workflow.assert_called_once()

# Let it run for a bit
await asyncio.sleep(1)

# Get the history
history = await handle.fetch_history()
events = [MessageToDict(e) for e in history.events]
assert len(events) > 0

activities_scheduled = [
event.get("activityTaskScheduledEventAttributes", {})
.get("activityType", {})
.get("name")
for event in events
if "ACTIVITY_TASK_SCHEDULED" in event["eventType"]
]
activities_scheduled = [
activity for activity in activities_scheduled if activity
]

assert activities_scheduled == [
"wait_for_input_step",
"transition_step",
"raise_complete_async",
]
2 changes: 0 additions & 2 deletions sdks/python/julep/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@
TasksUpdateTaskRequestMainItem_WaitForInput,
TasksUpdateTaskRequestMainItem_Yield,
TasksWaitForInputStep,
TasksWaitForInputStepInfo,
TasksYieldStep,
ToolsChosenFunctionCall,
ToolsChosenToolCall,
Expand Down Expand Up @@ -506,7 +505,6 @@
"TasksUpdateTaskRequestMainItem_WaitForInput",
"TasksUpdateTaskRequestMainItem_Yield",
"TasksWaitForInputStep",
"TasksWaitForInputStepInfo",
"TasksYieldStep",
"ToolsChosenFunctionCall",
"ToolsChosenToolCall",
Expand Down
2 changes: 0 additions & 2 deletions sdks/python/julep/api/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@
TasksUpdateTaskRequestMainItem_Yield,
)
from .tasks_wait_for_input_step import TasksWaitForInputStep
from .tasks_wait_for_input_step_info import TasksWaitForInputStepInfo
from .tasks_yield_step import TasksYieldStep
from .tools_chosen_function_call import ToolsChosenFunctionCall
from .tools_chosen_tool_call import ToolsChosenToolCall, ToolsChosenToolCall_Function
Expand Down Expand Up @@ -545,7 +544,6 @@
"TasksUpdateTaskRequestMainItem_WaitForInput",
"TasksUpdateTaskRequestMainItem_Yield",
"TasksWaitForInputStep",
"TasksWaitForInputStepInfo",
"TasksYieldStep",
"ToolsChosenFunctionCall",
"ToolsChosenToolCall",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from .tasks_search_step_search import TasksSearchStepSearch
from .tasks_set_step_set import TasksSetStepSet
from .tasks_sleep_for import TasksSleepFor
from .tasks_wait_for_input_step_info import TasksWaitForInputStepInfo


class TasksCreateTaskRequestMainItem_Evaluate(pydantic_v1.BaseModel):
Expand Down Expand Up @@ -488,7 +487,7 @@ class Config:


class TasksCreateTaskRequestMainItem_WaitForInput(pydantic_v1.BaseModel):
info: TasksWaitForInputStepInfo
wait_for_input: typing.Dict[str, CommonPyExpression]
kind: typing.Literal["wait_for_input"] = pydantic_v1.Field(
alias="kind_", default="wait_for_input"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from .tasks_search_step_search import TasksSearchStepSearch
from .tasks_set_step_set import TasksSetStepSet
from .tasks_sleep_for import TasksSleepFor
from .tasks_wait_for_input_step_info import TasksWaitForInputStepInfo


class TasksPatchTaskRequestMainItem_Evaluate(pydantic_v1.BaseModel):
Expand Down Expand Up @@ -488,7 +487,7 @@ class Config:


class TasksPatchTaskRequestMainItem_WaitForInput(pydantic_v1.BaseModel):
info: TasksWaitForInputStepInfo
wait_for_input: typing.Dict[str, CommonPyExpression]
kind: typing.Literal["wait_for_input"] = pydantic_v1.Field(
alias="kind_", default="wait_for_input"
)
Expand Down
Loading

0 comments on commit 6cd98ae

Please sign in to comment.