From 9154bc666682c25d4bfc9702661f8ead323dc887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Du=C3=A9=C3=B1ez-Guzm=C3=A1n?= Date: Sun, 7 Jul 2024 07:52:59 -0700 Subject: [PATCH] Add a component class that allows access to its `pre_act` context during the `PRE_ACT` phase. This is achieved by making the context ignore the action spec, and computing its context only once during the first time it is requested. Components that want their `pre_act` context available to other components should derive from this class. PiperOrigin-RevId: 650012324 Change-Id: Ic63aaa84a997fa6ddd42ce01567e1bb8d2a697ab --- concordia/components/agent/v2/__init__.py | 1 + .../agent/v2/action_spec_ignored.py | 63 +++++++++++++++++++ concordia/components/agent/v2/observation.py | 16 +++-- 3 files changed, 71 insertions(+), 9 deletions(-) create mode 100644 concordia/components/agent/v2/action_spec_ignored.py diff --git a/concordia/components/agent/v2/__init__.py b/concordia/components/agent/v2/__init__.py index da838ca..96a58df 100644 --- a/concordia/components/agent/v2/__init__.py +++ b/concordia/components/agent/v2/__init__.py @@ -14,6 +14,7 @@ """Library of components specifically for generative agents.""" +from concordia.components.agent.v2 import action_spec_ignored from concordia.components.agent.v2 import constant from concordia.components.agent.v2 import legacy_act_component from concordia.components.agent.v2 import no_op_context_processor diff --git a/concordia/components/agent/v2/action_spec_ignored.py b/concordia/components/agent/v2/action_spec_ignored.py new file mode 100644 index 0000000..2743eb6 --- /dev/null +++ b/concordia/components/agent/v2/action_spec_ignored.py @@ -0,0 +1,63 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A component that ignores the action spec in the `pre_act` method.""" + +import abc + +from concordia.typing import component_v2 +from concordia.typing import entity as entity_lib + + +class ActionSpecIgnored(component_v2.EntityComponent, metaclass=abc.ABCMeta): + """A component that ignores the action spec in the `pre_act` method. + + As a consequence, its `pre_act` state can be accessed safely by other + components. This is useful for components that need to condition their + `pre_act` state on the state of other components. Derived classes should + implement `make_pre_act_context` instead of `pre_act`. The pre_act context + will be cached and returned by `get_pre_act_context` and `pre_act`, and + cleaned up by `update`. + """ + + def __init__(self): + """Initializes the component.""" + self._pre_act_context: str | None = None + + @abc.abstractmethod + def make_pre_act_context(self) -> str: + """Creates the pre-act context.""" + raise NotImplementedError() + + def set_pre_act_context(self, pre_act_context: str) -> None: + """Creates the pre-act context.""" + if self._pre_act_context is not None: + raise ValueError('pre_act_context is already set.') + self._pre_act_context = pre_act_context + + def get_pre_act_context(self) -> str: + """Creates the pre-act context.""" + if self._pre_act_context is None: + self._pre_act_context = self.make_pre_act_context() + return self._pre_act_context + + def pre_act( + self, + action_spec: entity_lib.ActionSpec, + ) -> str: + del action_spec + return self.get_pre_act_context() + + def update(self) -> None: + self._pre_act_context = None diff --git a/concordia/components/agent/v2/observation.py b/concordia/components/agent/v2/observation.py index 65a718b..0b582cb 100644 --- a/concordia/components/agent/v2/observation.py +++ b/concordia/components/agent/v2/observation.py @@ -17,13 +17,13 @@ from collections.abc import Callable import datetime from concordia.associative_memory import associative_memory -from concordia.typing import component_v2 -from concordia.typing import entity as entity_lib +from concordia.components.agent.v2 import action_spec_ignored +import overrides -class Observation(component_v2.EntityComponent): - """A simple component to receive observations. - """ + +class Observation(action_spec_ignored.ActionSpecIgnored): + """A simple component to receive observations.""" def __init__( self, @@ -45,10 +45,8 @@ def pre_observe( ) return '' - def pre_act( - self, - unused_action_spec: entity_lib.ActionSpec, - ) -> str: + @overrides.overrides + def make_pre_act_context(self) -> str: mems = self._memory.retrieve_time_interval( self._clock_now() - self._timeframe, self._clock_now(), add_time=True )