Skip to content

Commit

Permalink
Separate the component that writes observations to memory and represe…
Browse files Browse the repository at this point in the history
…nts the memories of observations in pre_act.

PiperOrigin-RevId: 722687945
Change-Id: Iaac59acdc04376cf353628ef9c73010b1b316372
  • Loading branch information
vezhnick authored and copybara-github committed Feb 3, 2025
1 parent 1b21ed8 commit 01a8c7d
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 12 deletions.
45 changes: 35 additions & 10 deletions concordia/components/agent/unstable/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,25 @@
DEFAULT_OBSERVATION_PRE_ACT_KEY = 'Observation'


class Observation(action_spec_ignored.ActionSpecIgnored):
"""A simple component to receive observations."""
class ObservationToMemory(action_spec_ignored.ActionSpecIgnored):
"""A component that adds observations to the memory."""

def __init__(
self,
history_length: int,
memory_component_name: str = (
memory_component.DEFAULT_MEMORY_COMPONENT_NAME
),
pre_act_key: str = DEFAULT_OBSERVATION_PRE_ACT_KEY,
logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
):
"""Initializes the observation component.
Args:
history_length: The maximum number of observations to retrieve in
`pre_act`.
memory_component_name: Name of the memory component to add observations to
in `pre_observe` and to retrieve observations from in `pre_act`.
pre_act_key: Prefix to add to the output of the component when called in
`pre_act`.
logging_channel: The channel to use for debug logging.
"""
super().__init__(pre_act_key)
super().__init__('')
self._memory_component_name = memory_component_name
self._history_length = history_length
self._logging_channel = logging_channel

def pre_observe(
Expand All @@ -60,6 +53,38 @@ def pre_observe(
memory.add(f'[observation] {observation}')
return ''

def _make_pre_act_value(self) -> str:
return ''


class LastNObservations(action_spec_ignored.ActionSpecIgnored):
"""A simple component to receive observations."""

def __init__(
self,
history_length: int,
memory_component_name: str = (
memory_component.DEFAULT_MEMORY_COMPONENT_NAME
),
pre_act_key: str = DEFAULT_OBSERVATION_PRE_ACT_KEY,
logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
):
"""Initializes the observation component.
Args:
history_length: The maximum number of observations to retrieve in
`pre_act`.
memory_component_name: Name of the memory component to add observations to
in `pre_observe` and to retrieve observations from in `pre_act`.
pre_act_key: Prefix to add to the output of the component when called in
`pre_act`.
logging_channel: The channel to use for debug logging.
"""
super().__init__(pre_act_key)
self._memory_component_name = memory_component_name
self._history_length = history_length
self._logging_channel = logging_channel

def _make_pre_act_value(self) -> str:
"""Returns the latest observations to preact."""
memory = self.get_entity().get_component(
Expand Down
9 changes: 8 additions & 1 deletion concordia/factory/agent/unstable/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,14 @@ def build_agent(
logging_channel=measurements.get_channel('TimeDisplay').on_next,
)

observation_to_memory = agent_components_v2.observation.ObservationToMemory(
logging_channel=measurements.get_channel(
'ObservationsSinceLastUpdate'
).on_next,
)

observation_label = '\nObservation'
observation = agent_components_v2.observation.Observation(
observation = agent_components_v2.observation.LastNObservations(
history_length=100,
pre_act_key=observation_label,
logging_channel=measurements.get_channel(
Expand Down Expand Up @@ -154,6 +160,7 @@ def build_agent(
relevant_memories,
self_perception,
situation_representation,
observation_to_memory,
observation,
person_by_situation,
)
Expand Down
7 changes: 6 additions & 1 deletion concordia/factory/agent/unstable/minimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,12 @@ def build_agent(
else:
time_display = None

observation_to_memory = agent_components_v2.observation.ObservationToMemory(
logging_channel=measurements.get_channel('Observation').on_next,
)

observation_label = '\nObservation'
observation = agent_components_v2.observation.Observation(
observation = agent_components_v2.observation.LastNObservations(
history_length=100,
pre_act_key=observation_label,
logging_channel=measurements.get_channel('Observation').on_next,
Expand All @@ -96,6 +100,7 @@ def build_agent(
entity_components = (
# Components that provide pre_act context.
instructions,
observation_to_memory,
observation,
)
components_of_agent = {
Expand Down

0 comments on commit 01a8c7d

Please sign in to comment.