diff --git a/src/google/adk/sessions/state.py b/src/google/adk/sessions/state.py index 1cb3c5820..7167c2b62 100644 --- a/src/google/adk/sessions/state.py +++ b/src/google/adk/sessions/state.py @@ -14,6 +14,8 @@ from typing import Any +_sentinel = object() + class State: """A state dict that maintain the current value and the pending-commit delta.""" @@ -37,6 +39,15 @@ def __getitem__(self, key: str) -> Any: return self._delta[key] return self._value[key] + def __delitem__(self, key: str): + """Deletes the value of the state dict for the given key""" + if key not in self: + raise KeyError(key) + if key in self._delta: + del self._delta[key] + if key in self._value: + del self._value[key] + def __setitem__(self, key: str, value: Any): """Sets the value of the state dict for the given key.""" # TODO: make new change only store in delta, so that self._value is only @@ -58,6 +69,18 @@ def get(self, key: str, default: Any = None) -> Any: return default return self[key] + def pop(self, key: str, default: Any = _sentinel) -> Any: + """Deletes the value of the state dict for the given key""" + if key in self: + value_to_return = self[key] + del self[key] + return value_to_return + else: + if default is not _sentinel: + return default + else: + raise KeyError(key) + def update(self, delta: dict[str, Any]): """Updates the state dict with the given delta.""" self._value.update(delta)