Skip to content

Commit

Permalink
Add a simple act component that behaves like the legacy basic_agent d…
Browse files Browse the repository at this point in the history
…oes (concatenating context from components).

PiperOrigin-RevId: 648661502
Change-Id: Id4b0dddaa698468e4f3ec1c110312777e2693db1
  • Loading branch information
duenez authored and Copybara-Service committed Jul 2, 2024
1 parent 05f79ab commit bf553b7
Showing 1 changed file with 112 additions and 0 deletions.
112 changes: 112 additions & 0 deletions concordia/components/agent/v2/simple_act_component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A simple acting component that aggregates contexts from components."""


from collections.abc import Sequence

from concordia.language_model import language_model
from concordia.typing import component_v2
from concordia.typing import entity as entity_lib
from typing_extensions import override


class SimpleActComponent(component_v2.ActingComponent):
"""A simple acting component that aggregates contexts from components.
This component will receive the contexts from `pre_act` from all the
components, and assemble them in the order specified to `__init__`. If the
component order is not specified, then components will be assembled in the
iteration order of the `ComponentsContext` passed to `get_action_attempt`.
"""

def __init__(
self,
model: language_model.LanguageModel,
component_order: Sequence[str] | None = None,
):
"""Initializes the agent.
Args:
model: The language model to use for generating the action attempt.
component_order: The order in which the component contexts will be
assembled when calling the act component. If None, the contexts will be
assembled in the iteration order of the `ComponentsContext` passed to
`get_action_attempt`. If the component order is specified, but does not
contain all the components passed to `get_action_attempt`, the missing
components will be appended at the end in the iteration order of the
`ComponentsContext` passed to `get_action_attempt`. The same component
cannot appear twice in the component order. All components in the
component order must be in the `ComponentsContext` passed to
`get_action_attempt`.
Raises:
ValueError: If the component order is not None and contains duplicate
components.
"""
self._model = model
if component_order is None:
self._component_order = None
else:
self._component_order = tuple(component_order)
if self._component_order is not None:
if len(set(self._component_order)) != len(self._component_order):
raise ValueError(
"The component order contains duplicate components: "
+ ", ".join(self._component_order)
)

def _context_for_action(
self,
contexts: component_v2.ComponentsContext,
) -> str:
if self._component_order is None:
return "\n".join(
f"{name}: {context}" for name, context in contexts.items()
)
else:
order = self._component_order + tuple(
set(contexts.keys()) - set(self._component_order)
)
return "\n".join(
f"{name}: {contexts[name]}" for name in order
)

@override
def get_action_attempt(
self,
contexts: component_v2.ComponentsContext,
action_spec: entity_lib.ActionSpec,
) -> str:
context = self._context_for_action(contexts)
if action_spec.output_type == entity_lib.OutputType.CHOICE:
_, response, _ = self._model.sample_choice(
f"{context}\n\n{action_spec.call_to_action}\n",
action_spec.options)
return response
sampled_text = self._model.sample_text(
f"{context}\n\n{action_spec.call_to_action}\n",
)
if action_spec.output_type == entity_lib.OutputType.FREE:
return sampled_text
elif action_spec.output_type == entity_lib.OutputType.FLOAT:
try:
return str(float(sampled_text))
except ValueError:
return "0.0"
raise NotImplementedError(
f"Unsupported output type: {action_spec.output_type}. "
"Supported output types are: FREE, CHOICE, and FLOAT."
)

0 comments on commit bf553b7

Please sign in to comment.