Skip to content

Commit

Permalink
feat(agents-api): Add if-else step
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <diwank@julep.ai>
  • Loading branch information
Diwank Tomer committed Aug 19, 2024
1 parent 2c6dada commit 9fdafaf
Show file tree
Hide file tree
Showing 19 changed files with 370 additions and 964 deletions.
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/task_steps/evaluate_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ async def evaluate_step(context: StepContext) -> StepOutcome:
result = StepOutcome(output=output)
return result

except Exception as e:
except BaseException as e:
logging.error(f"Error in evaluate_step: {e}")
return StepOutcome(output=None)
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported evaluate_step directly
Expand Down
5 changes: 3 additions & 2 deletions agents-api/agents_api/activities/task_steps/if_else_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ async def if_else_step(context: StepContext) -> StepOutcome:

expr: str = context.current_step.if_
output = simple_eval(expr, names=context.model_dump())
output: bool = bool(output)

result = StepOutcome(output=output)
return result

except Exception as e:
except BaseException as e:
logging.error(f"Error in if_else_step: {e}")
return StepOutcome(output=None)
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported if_else_step directly
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/task_steps/log_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ async def log_step(context: StepContext) -> StepOutcome:
result = StepOutcome(output=output)
return result

except Exception as e:
except BaseException as e:
logging.error(f"Error in log_step: {e}")
return StepOutcome(output=None)
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported log_step directly
Expand Down
12 changes: 3 additions & 9 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,17 @@
from beartype import beartype
from temporalio import activity

from ...autogen.openapi_model import (
InputChatMLMessage,
PromptStep,
)
from ...autogen.openapi_model import InputChatMLMessage
from ...clients import (
litellm, # We dont directly import `acompletion` so we can mock it
)
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.utils.template import render_template


@activity.defn
@beartype
async def prompt_step(context: StepContext[PromptStep]) -> StepOutcome:
async def prompt_step(context: StepContext) -> StepOutcome:
# Get context data
context_data: dict = context.model_dump()

Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/task_steps/return_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ async def return_step(context: StepContext) -> StepOutcome:
result = StepOutcome(output=output)
return result

except Exception as e:
except BaseException as e:
logging.error(f"Error in log_step: {e}")
return StepOutcome(output=None)
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported return_step directly
Expand Down
19 changes: 5 additions & 14 deletions agents-api/agents_api/activities/task_steps/wait_for_input_step.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import logging

from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
Expand All @@ -9,20 +7,13 @@


async def wait_for_input_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for returning immediately, so we just evaluate the expression
# Hence, it's a local activity and SHOULD NOT fail
try:
assert isinstance(context.current_step, WaitForInputStep)

exprs = context.current_step.wait_for_input
output = simple_eval_dict(exprs, values=context.model_dump())
assert isinstance(context.current_step, WaitForInputStep)

result = StepOutcome(output=output)
return result
exprs = context.current_step.wait_for_input
output = simple_eval_dict(exprs, values=context.model_dump())

except Exception as e:
logging.error(f"Error in log_step: {e}")
return StepOutcome(output=None)
result = StepOutcome(output=output)
return result


# Note: This is here just for clarity. We could have just imported wait_for_input_step directly
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/task_steps/yield_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ async def yield_step(context: StepContext) -> StepOutcome:

return StepOutcome(output=arguments, transition_to=("step", transition_target))

except Exception as e:
except BaseException as e:
logging.error(f"Error in log_step: {e}")
return StepOutcome(output=None)
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported yield_step directly
Expand Down
11 changes: 9 additions & 2 deletions agents-api/agents_api/autogen/Executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,16 @@ class TransitionTarget(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
workflow: Annotated[str, Field(pattern="^[^\\W0-9]\\w*$")]
workflow: Annotated[
str,
Field(
pattern="^[\\p{L}\\p{Nl}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]+[\\p{ID_Start}\\p{Mn}\\p{Mc}\\p{Nd}\\p{Pc}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]*$"
),
]
"""
Valid python identifier names
For Unicode character safety
See: https://unicode.org/reports/tr31/
See: https://www.unicode.org/reports/tr39/#Identifier_Characters
"""
step: int

Expand Down
25 changes: 14 additions & 11 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Any, Generic, TypeVar
from typing import Annotated, Any, Type
from uuid import UUID

from pydantic import BaseModel, computed_field
from pydantic import BaseModel, Field, computed_field
from pydantic_partial import create_partial_model

from ...autogen.openapi_model import (
Agent,
CreateTaskRequest,
CreateTransitionRequest,
Execution,
PartialTaskSpecDef,
PatchTaskRequest,
Expand Down Expand Up @@ -53,6 +55,9 @@
}


PendingTransition: Type[BaseModel] = create_partial_model(CreateTransitionRequest)


class ExecutionInput(BaseModel):
developer_id: UUID
execution: Execution
Expand All @@ -66,39 +71,36 @@ class ExecutionInput(BaseModel):
session: Session | None = None


WorkflowStepType = TypeVar("WorkflowStepType", bound=WorkflowStep)


class StepContext(BaseModel, Generic[WorkflowStepType]):
class StepContext(BaseModel):
execution_input: ExecutionInput
inputs: list[dict[str, Any]]
cursor: TransitionTarget

@computed_field
@property
def outputs(self) -> list[dict[str, Any]]:
def outputs(self) -> Annotated[list[dict[str, Any]], Field(exclude=True)]:
return self.inputs[1:]

@computed_field
@property
def current_input(self) -> dict[str, Any]:
def current_input(self) -> Annotated[dict[str, Any], Field(exclude=True)]:
return self.inputs[-1]

@computed_field
@property
def current_workflow(self) -> Workflow:
def current_workflow(self) -> Annotated[Workflow, Field(exclude=True)]:
workflows: list[Workflow] = self.execution_input.task.workflows
return next(wf for wf in workflows if wf.name == self.cursor.workflow)

@computed_field
@property
def current_step(self) -> WorkflowStepType:
def current_step(self) -> Annotated[WorkflowStep, Field(exclude=True)]:
step = self.current_workflow.steps[self.cursor.step]
return step

@computed_field
@property
def is_last_step(self) -> bool:
def is_last_step(self) -> Annotated[bool, Field(exclude=True)]:
return (self.cursor.step + 1) == len(self.current_workflow.steps)

def model_dump(self, *args, **kwargs) -> dict[str, Any]:
Expand All @@ -109,6 +111,7 @@ def model_dump(self, *args, **kwargs) -> dict[str, Any]:


class StepOutcome(BaseModel):
error: str | None = None
output: Any
transition_to: tuple[TransitionType, TransitionTarget] | None = None

Expand Down
Loading

0 comments on commit 9fdafaf

Please sign in to comment.