Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 176 additions & 123 deletions hindsight-integrations/litellm/hindsight_litellm/__init__.py

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions hindsight-integrations/litellm/hindsight_litellm/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,15 @@ class HindsightCallback(CustomLogger):
- Configurable memory injection modes
- Support for entity observations in recall

NOTE: HindsightCallback and enable() are mutually exclusive injection paths.
Use one or the other — not both. Registering HindsightCallback in
litellm.callbacks while enable() is active causes double memory injection.
Prefer enable() for most use cases; use HindsightCallback directly only if
you need LiteLLM's native callback lifecycle (e.g., failure hooks).

Usage:
>>> from hindsight_litellm import configure, enable
>>> configure(bank_id="my-agent", hindsight_api_url="http://localhost:8888")
>>> configure(bank_id="my-agent")
>>> enable()
>>>
>>> # Now all LiteLLM calls will have memory integration
Expand Down Expand Up @@ -697,7 +703,7 @@ def log_pre_api_call(
# Get effective settings (kwargs override defaults)
settings = self._get_effective_settings(kwargs)
if not settings.bank_id:
raise ValueError(
raise HindsightError(
"No bank_id configured. Either call set_defaults(bank_id=...) "
"or pass hindsight_bank_id=... to the completion call."
)
Expand Down Expand Up @@ -756,7 +762,7 @@ async def async_log_pre_api_call(
# Get effective settings (kwargs override defaults)
settings = self._get_effective_settings(kwargs)
if not settings.bank_id:
raise ValueError(
raise HindsightError(
"No bank_id configured. Either call set_defaults(bank_id=...) "
"or pass hindsight_bank_id=... to the completion call."
)
Expand Down
40 changes: 39 additions & 1 deletion hindsight-integrations/litellm/hindsight_litellm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""

import os
import warnings
from dataclasses import asdict, dataclass, field, fields
from enum import Enum
from importlib import metadata
Expand All @@ -31,6 +32,9 @@
DEFAULT_BANK_ID = "default"
HINDSIGHT_API_KEY_ENV = "HINDSIGHT_API_KEY"

VALID_BUDGETS = frozenset({"low", "mid", "high"})
VALID_TAGS_MATCH = frozenset({"any", "all", "any_strict", "all_strict"})


class MemoryInjectionMode(str, Enum):
"""How memories should be injected into the prompt.
Expand Down Expand Up @@ -317,10 +321,22 @@ def configure(
"""
global _global_config

# Validate per-call defaults
if budget not in VALID_BUDGETS:
raise ValueError(f"budget must be one of {sorted(VALID_BUDGETS)!r}, got {budget!r}")
if recall_tags_match not in VALID_TAGS_MATCH:
raise ValueError(f"recall_tags_match must be one of {sorted(VALID_TAGS_MATCH)!r}, got {recall_tags_match!r}")
if document_id is not None:
warnings.warn(
"document_id is deprecated; use session_id instead.",
DeprecationWarning,
stacklevel=2,
)

# Apply connection-level defaults
resolved_api_url = hindsight_api_url or DEFAULT_HINDSIGHT_API_URL
resolved_api_key = api_key or os.environ.get(HINDSIGHT_API_KEY_ENV)
resolved_bank_id = bank_id or DEFAULT_BANK_ID
resolved_bank_id = bank_id

# Build default settings
default_settings = HindsightCallSettings(
Expand Down Expand Up @@ -429,6 +445,18 @@ def set_defaults(
"""
global _global_config

# Validate values when explicitly provided
if budget is not None and budget not in VALID_BUDGETS:
raise ValueError(f"budget must be one of {sorted(VALID_BUDGETS)!r}, got {budget!r}")
if recall_tags_match is not None and recall_tags_match not in VALID_TAGS_MATCH:
raise ValueError(f"recall_tags_match must be one of {sorted(VALID_TAGS_MATCH)!r}, got {recall_tags_match!r}")
if document_id is not None:
warnings.warn(
"document_id is deprecated; use session_id instead.",
DeprecationWarning,
stacklevel=2,
)

# Ensure configure() was called
if _global_config is None:
# Auto-configure with defaults if not configured
Expand Down Expand Up @@ -556,3 +584,13 @@ def reset_config() -> None:
"""Reset all global configuration to None."""
global _global_config
_global_config = None


def _restore_config(saved_config: Optional[HindsightConfig]) -> None:
    """Reinstate a previously captured snapshot as the module-level configuration.

    The snapshot is assigned straight to the global, so nothing that
    configure() would normally do (warnings, bank creation, or any other
    side effect) is triggered. Intended for the hindsight_memory() context
    manager, which calls this on exit to put the prior state back in one step.

    Args:
        saved_config: The snapshot to reinstate. Passing ``None`` stores
            ``None`` as the global configuration.
    """
    global _global_config
    # Plain rebind: deliberately no validation and no copy of the snapshot.
    _global_config = saved_config
54 changes: 40 additions & 14 deletions hindsight-integrations/litellm/hindsight_litellm/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def recall(
budget: Optional[str] = None,
max_tokens: Optional[int] = None,
hindsight_api_url: Optional[str] = None,
include_entities: Optional[bool] = None,
trace: Optional[bool] = None,
recall_tags: Optional[List[str]] = None,
recall_tags_match: Optional[str] = None,
) -> RecallResponse:
"""Recall memories from Hindsight.

Expand All @@ -118,6 +122,10 @@ def recall(
budget: Recall budget level (low, mid, high) - controls how many memories are returned
max_tokens: Maximum tokens for memory context
hindsight_api_url: Override the configured API URL
include_entities: Include entity observations in results (default: from config)
trace: Enable trace info for debugging (default: from config)
recall_tags: Tags to filter by when recalling memories
recall_tags_match: Tag matching mode - any/all/any_strict/all_strict (default: from config)

Returns:
RecallResponse containing matched memories (iterable like a list).
Expand All @@ -128,19 +136,16 @@ def recall(

Example:
>>> from hindsight_litellm import configure, recall
>>> configure(bank_id="my-agent", hindsight_api_url="http://localhost:8888")
>>> configure(bank_id="my-agent")
>>>
>>> # Query memories
>>> memories = recall("what projects am I working on?")
>>> for m in memories:
... print(f"- [{m.fact_type}] {m.text}")
- [world] User is building a FastAPI project
>>>
>>> # With verbose mode, access debug info
>>> configure(bank_id="my-agent", verbose=True)
>>> memories = recall("what projects am I working on?")
>>> if memories.debug:
... print(f"Queried bank: {memories.debug.bank_id}")
>>> # Filter by tags
>>> memories = recall("preferences", recall_tags=["user:alice"], recall_tags_match="any_strict")
"""
# Get config and defaults, or use overrides
config = get_config()
Expand All @@ -151,6 +156,12 @@ def recall(
target_fact_types = fact_types or (defaults.fact_types if defaults else None)
target_budget = budget or (defaults.budget if defaults else "mid")
target_max_tokens = max_tokens or (defaults.max_memory_tokens if defaults else 4096)
target_include_entities = (
include_entities if include_entities is not None else (defaults.include_entities if defaults else True)
)
target_trace = trace if trace is not None else (defaults.trace if defaults else False)
target_recall_tags = recall_tags or (defaults.recall_tags if defaults else None)
target_recall_tags_match = recall_tags_match or (defaults.recall_tags_match if defaults else "any")

if not api_url or not target_bank_id:
raise RuntimeError("Hindsight not configured. Call configure() or provide bank_id and hindsight_api_url.")
Expand All @@ -161,13 +172,19 @@ def recall(
client = _get_client(api_url, config.api_key if config else None)

# Call recall API
results = client.recall(
bank_id=target_bank_id,
query=query,
types=target_fact_types,
budget=target_budget,
max_tokens=target_max_tokens,
)
recall_kwargs: dict = {
"bank_id": target_bank_id,
"query": query,
"types": target_fact_types,
"budget": target_budget,
"max_tokens": target_max_tokens,
"trace": target_trace,
"include_entities": target_include_entities,
}
if target_recall_tags:
recall_kwargs["tags"] = target_recall_tags
recall_kwargs["tags_match"] = target_recall_tags_match
results = client.recall(**recall_kwargs)

# Convert to RecallResult objects
recall_results = []
Expand Down Expand Up @@ -283,6 +300,8 @@ def reflect(
context: Optional[str] = None,
response_schema: Optional[dict] = None,
hindsight_api_url: Optional[str] = None,
recall_tags: Optional[List[str]] = None,
recall_tags_match: Optional[str] = None,
) -> ReflectResult:
"""Generate a contextual answer based on memories.

Expand Down Expand Up @@ -319,6 +338,8 @@ def reflect(
api_url = hindsight_api_url or (config.hindsight_api_url if config else None)
target_bank_id = bank_id or (defaults.bank_id if defaults else None)
target_budget = budget or (defaults.budget if defaults else "mid")
target_recall_tags = recall_tags or (defaults.recall_tags if defaults else None)
target_recall_tags_match = recall_tags_match or (defaults.recall_tags_match if defaults else "any")

if not api_url or not target_bank_id:
raise RuntimeError("Hindsight not configured. Call configure() or provide bank_id and hindsight_api_url.")
Expand All @@ -329,7 +350,7 @@ def reflect(
client = _get_client(api_url, config.api_key if config else None)

# Call reflect API
reflect_kwargs = {
reflect_kwargs: dict = {
"bank_id": target_bank_id,
"query": query,
"budget": target_budget,
Expand All @@ -338,6 +359,9 @@ def reflect(
reflect_kwargs["context"] = context
if response_schema is not None:
reflect_kwargs["response_schema"] = response_schema
if target_recall_tags:
reflect_kwargs["tags"] = target_recall_tags
reflect_kwargs["tags_match"] = target_recall_tags_match
result = client.reflect(**reflect_kwargs)

# Convert to ReflectResult
Expand Down Expand Up @@ -651,6 +675,7 @@ async def aretain(
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, str]] = None,
hindsight_api_url: Optional[str] = None,
sync: bool = False,
) -> RetainResult:
"""Async version of retain().

Expand All @@ -669,6 +694,7 @@ async def aretain(
tags=tags,
metadata=metadata,
hindsight_api_url=hindsight_api_url,
sync=sync,
),
)

Expand Down
10 changes: 5 additions & 5 deletions hindsight-integrations/litellm/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ def teardown_method(self):
reset_config()

def test_configure_with_no_arguments(self):
"""Test configure() with no arguments uses defaults."""
"""Test configure() with no arguments uses production URL and leaves bank_id unset."""
config = configure()

assert config.hindsight_api_url == DEFAULT_HINDSIGHT_API_URL
assert config.bank_id == DEFAULT_BANK_ID
assert config.bank_id is None

def test_configure_reads_api_key_from_env(self):
"""Test configure() reads API key from environment variable."""
Expand Down Expand Up @@ -81,10 +81,10 @@ def test_configure_explicit_values_override_defaults(self):
assert config.bank_id == "custom-bank"
assert config.api_key == "custom-key"

def test_is_configured_true_with_defaults(self):
"""Test is_configured() returns True with default config."""
def test_is_configured_false_without_explicit_bank_id(self):
"""Test is_configured() returns False when configure() is called without bank_id."""
configure()
assert is_configured() is True
assert is_configured() is False

def test_is_configured_false_when_not_configured(self):
"""Test is_configured() returns False when not configured."""
Expand Down
Loading