Skip to content

Commit

Permalink
Improve API for components set_entity and get_entity so it return…
Browse files Browse the repository at this point in the history
…s a `ComponentEntity` that has `get_phase` and `get_component`.

PiperOrigin-RevId: 650015543
Change-Id: I9102cb1e25b82dbbffc1e654ee5632365e64487c
  • Loading branch information
duenez authored and copybara-github committed Jul 7, 2024
1 parent 67063fa commit 26fb084
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 48 deletions.
59 changes: 15 additions & 44 deletions concordia/agents/entity_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""A modular entity agent using the new component system."""

from collections.abc import Mapping
import enum
import functools
import types

Expand All @@ -24,41 +23,13 @@
from concordia.typing import entity
from concordia.utils import concurrency

from typing_extensions import override
import overrides


_EMPTY_MAPPING = types.MappingProxyType({})


class Phase(enum.Enum):
"""Phases of the agent lifecycle.
Attributes:
INIT: The agent has just been created. No action has been requested nor
observation has been received. This can be followed by a call to `pre_act`
or `pre_observe`.
PRE_ACT: The agent has received a request to act. Components are being
requested for their action context. This will be followed by `POST_ACT`.
POST_ACT: The agent has just submitted an action attempt. Components are
being informed of the action attempt. This will be followed by
`UPDATE`.
PRE_OBSERVE: The agent has received an observation. Components are being
informed of the observation. This will be followed by `POST_OBSERVE`.
POST_OBSERVE: The agent has just observed. Components are given a chance to
provide context after processing the observation. This will be followed by
`UPDATE`.
UPDATE: The agent is about to update its internal state. This will be
followed by `PRE_ACT` or `PRE_OBSERVE`.
"""
INIT = enum.auto()
PRE_ACT = enum.auto()
POST_ACT = enum.auto()
PRE_OBSERVE = enum.auto()
POST_OBSERVE = enum.auto()
UPDATE = enum.auto()


class EntityAgent(entity.Entity):
class EntityAgent(component_v2.ComponentEntity):
"""An agent that has its functionality defined by components.
The agent has a set of components that define its functionality. The agent
Expand Down Expand Up @@ -88,6 +59,7 @@ def __init__(
components: The components that will be used by the agent.
"""
self._agent_name = agent_name
self._phase = component_v2.Phase.INIT

self._act_component = act_component
self._act_component.set_entity(self)
Expand All @@ -102,14 +74,12 @@ def __init__(
for component in self._components.values():
component.set_entity(self)

self._phase = Phase.INIT

@functools.cached_property
@override
@overrides.overrides
def name(self) -> str:
return self._agent_name

def get_phase(self) -> Phase:
def get_phase(self) -> component_v2.Phase:
"""Returns the current phase of the agent."""
return self._phase

Expand Down Expand Up @@ -145,37 +115,38 @@ def _parallel_call_(
name: future.result() for name, future in context_futures.items()
}

@override
def act(self, action_spec=entity.DEFAULT_ACTION_SPEC) -> str:
self._phase = Phase.PRE_ACT
@overrides.overrides
def act(self,
action_spec: entity.ActionSpec = entity.DEFAULT_ACTION_SPEC) -> str:
self._phase = component_v2.Phase.PRE_ACT
contexts = self._parallel_call_('pre_act', action_spec)

action_attempt = self._act_component.get_action_attempt(
contexts, action_spec)

self._phase = Phase.POST_ACT
self._phase = component_v2.Phase.POST_ACT
contexts = self._parallel_call_('post_act', action_spec)
self._context_processor.process(contexts)

self._phase = Phase.UPDATE
self._phase = component_v2.Phase.UPDATE
self._parallel_call_('update')

return action_attempt

@override
@overrides.overrides
def observe(
self,
observation: str,
) -> None:
self._phase = Phase.PRE_OBSERVE
self._phase = component_v2.Phase.PRE_OBSERVE
contexts = self._parallel_call_('pre_observe', observation)
self._context_processor.process(contexts)

self._phase = Phase.POST_OBSERVE
self._phase = component_v2.Phase.POST_OBSERVE
contexts = self._parallel_call_('post_observe')
self._context_processor.process(contexts)

self._phase = Phase.UPDATE
self._phase = component_v2.Phase.UPDATE
self._parallel_call_('update')

def get_last_log(self):
Expand Down
55 changes: 51 additions & 4 deletions concordia/typing/component_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import abc
from collections.abc import Mapping
import enum

from concordia.typing import entity as entity_lib

Expand All @@ -24,17 +25,63 @@
ComponentContextMapping = Mapping[ComponentName, ComponentContext]


class Phase(enum.Enum):
"""Phases of a component entity lifecycle.
Attributes:
INIT: The agent has just been created. No action has been requested nor
observation has been received. This can be followed by a call to `pre_act`
or `pre_observe`.
PRE_ACT: The agent has received a request to act. Components are being
requested for their action context. This will be followed by `POST_ACT`.
POST_ACT: The agent has just submitted an action attempt. Components are
being informed of the action attempt. This will be followed by
`UPDATE`.
PRE_OBSERVE: The agent has received an observation. Components are being
informed of the observation. This will be followed by `POST_OBSERVE`.
POST_OBSERVE: The agent has just observed. Components are given a chance to
provide context after processing the observation. This will be followed by
`UPDATE`.
UPDATE: The agent is about to update its internal state. This will be
followed by `PRE_ACT` or `PRE_OBSERVE`.
"""
INIT = enum.auto()
PRE_ACT = enum.auto()
POST_ACT = enum.auto()
PRE_OBSERVE = enum.auto()
POST_OBSERVE = enum.auto()
UPDATE = enum.auto()


class ComponentEntity(entity_lib.Entity):
"""An entity that contains components."""

@abc.abstractmethod
def get_phase(self) -> Phase:
"""Returns the current phase of the component entity."""
raise NotImplementedError()

@abc.abstractmethod
def get_component(self, component_name: str) -> "BaseComponent":
"""Returns the component with the given name.
Args:
component_name: The name of the component to return.
"""
raise NotImplementedError()


class BaseComponent:
"""A base class for components."""

def __init__(self):
self._entity = None
self._entity: ComponentEntity | None = None

def set_entity(self, entity: entity_lib.Entity) -> None:
def set_entity(self, entity: ComponentEntity) -> None:
"""Sets the entity that this component belongs to."""
self._entity = entity

def get_entity(self) -> entity_lib.Entity:
def get_entity(self) -> ComponentEntity:
"""Returns the entity that this component belongs to.
Raises:
Expand All @@ -46,7 +93,7 @@ def get_entity(self) -> entity_lib.Entity:


class EntityComponent(BaseComponent):
"""A building block of an entity.
"""A building block of a ComponentEntity.
Components are stand-alone pieces of functionality insterted into a GameObject
that have hooks for processing events for acting and observing.
Expand Down

0 comments on commit 26fb084

Please sign in to comment.