From 2def2f004a90a5a4cdec10626fb252f4f18f8e54 Mon Sep 17 00:00:00 2001 From: DK09876 Date: Thu, 21 May 2026 15:22:22 -0700 Subject: [PATCH 1/4] feat(litellm): expand recall/reflect/hindsight_memory APIs and fix default URL - recall(): add include_entities, trace, recall_tags, recall_tags_match params (previously only supported via the callback/enable() path, not the manual API) - reflect(): add recall_tags, recall_tags_match params (same gap) - hindsight_memory(): default URL now matches configure()/wrap_openai()/wrap_anthropic() instead of hardcoding localhost; add session_id, use_reflect, reflect_context, tags, recall_tags, recall_tags_match params - Document that enable() and HindsightCallback are mutually exclusive injection paths to prevent accidental double injection - Add 17 tests covering all new behaviour Co-Authored-By: Claude Sonnet 4.6 --- .../litellm/hindsight_litellm/__init__.py | 34 ++- .../litellm/hindsight_litellm/callbacks.py | 8 +- .../litellm/hindsight_litellm/wrappers.py | 50 +++-- .../litellm/tests/test_integration.py | 198 ++++++++++++++++++ hindsight-integrations/litellm/uv.lock | 2 +- 5 files changed, 273 insertions(+), 19 deletions(-) diff --git a/hindsight-integrations/litellm/hindsight_litellm/__init__.py b/hindsight-integrations/litellm/hindsight_litellm/__init__.py index 0b75365cb..b777e1590 100644 --- a/hindsight-integrations/litellm/hindsight_litellm/__init__.py +++ b/hindsight-integrations/litellm/hindsight_litellm/__init__.py @@ -629,13 +629,18 @@ def enable() -> None: memory injection fails (when inject_memories=True) or storage fails (when store_conversations=True), the error will propagate to your code. + NOTE: enable() and HindsightCallback are mutually exclusive injection paths. + Do not register HindsightCallback in litellm.callbacks while enable() is + active — memories will be injected twice (once by the monkeypatch, once by + the callback running inside the original litellm.completion). + Must be called after configure() and set_defaults(bank_id=...). Example: >>> from hindsight_litellm import configure, set_defaults, enable, HindsightError >>> import litellm >>> - >>> configure(hindsight_api_url="http://localhost:8888") + >>> configure() >>> set_defaults(bank_id="my-agent") >>> enable() >>> @@ -1364,7 +1369,7 @@ async def acompletion(*args, **kwargs): @contextmanager def hindsight_memory( - hindsight_api_url: str = "http://localhost:8888", + hindsight_api_url: Optional[str] = None, bank_id: Optional[str] = None, api_key: Optional[str] = None, store_conversations: bool = True, @@ -1375,10 +1380,16 @@ def hindsight_memory( budget: str = "mid", fact_types: Optional[List[str]] = None, document_id: Optional[str] = None, + session_id: Optional[str] = None, excluded_models: Optional[List[str]] = None, verbose: bool = False, include_entities: bool = True, trace: bool = False, + use_reflect: bool = False, + reflect_context: Optional[str] = None, + tags: Optional[List[str]] = None, + recall_tags: Optional[List[str]] = None, + recall_tags_match: str = "any", ): """Context manager for temporary Hindsight memory integration. @@ -1387,6 +1398,7 @@ def hindsight_memory( Args: hindsight_api_url: URL of the Hindsight API server + (default: https://api.hindsight.vectorize.io) bank_id: Memory bank ID for memory operations (required). For multi-user support, use different bank_ids per user (e.g., f"user-{user_id}") api_key: Optional API key for Hindsight authentication @@ -1397,11 +1409,17 @@ def hindsight_memory( max_memory_tokens: Maximum tokens for memory context budget: Budget for memory recall (low, mid, high) fact_types: List of fact types to filter (world, experience, opinion, observation) - document_id: Optional document ID for grouping conversations + document_id: Document ID for grouping conversations (deprecated, use session_id) + session_id: Session ID for grouping conversations (upsert behavior) excluded_models: List of model patterns to exclude verbose: Enable verbose logging include_entities: Include entity observations in recall (default True) trace: Enable trace info for debugging (default False) + use_reflect: Use reflect API instead of recall (default False) + reflect_context: Context for reflect reasoning + tags: Tags to apply when storing conversations + recall_tags: Tags to filter by when recalling memories + recall_tags_match: Tag matching mode - any/all/any_strict/all_strict (default "any") Example: >>> from hindsight_litellm import hindsight_memory @@ -1410,6 +1428,10 @@ def hindsight_memory( >>> with hindsight_memory(bank_id="user-123"): ... response = litellm.completion(model="gpt-4", messages=[...]) >>> # Memory integration automatically disabled after context + >>> + >>> # With tag scoping + >>> with hindsight_memory(bank_id="user-123", tags=["session:abc"], recall_tags=["session:abc"]): + ... response = litellm.completion(model="gpt-4", messages=[...]) """ # Save previous state was_enabled = is_enabled() @@ -1430,12 +1452,18 @@ def hindsight_memory( set_defaults( bank_id=bank_id, document_id=document_id, + session_id=session_id, budget=budget, fact_types=fact_types, max_memories=max_memories, max_memory_tokens=max_memory_tokens, include_entities=include_entities, trace=trace, + use_reflect=use_reflect, + reflect_context=reflect_context, + tags=tags, + recall_tags=recall_tags, + recall_tags_match=recall_tags_match, ) enable() yield diff --git a/hindsight-integrations/litellm/hindsight_litellm/callbacks.py b/hindsight-integrations/litellm/hindsight_litellm/callbacks.py index d3d844673..506a0f9d9 100644 --- a/hindsight-integrations/litellm/hindsight_litellm/callbacks.py +++ b/hindsight-integrations/litellm/hindsight_litellm/callbacks.py @@ -74,9 +74,15 @@ class HindsightCallback(CustomLogger): - Configurable memory injection modes - Support for entity observations in recall + NOTE: HindsightCallback and enable() are mutually exclusive injection paths. + Use one or the other — not both. Registering HindsightCallback in + litellm.callbacks while enable() is active causes double memory injection. + Prefer enable() for most use cases; use HindsightCallback directly only if + you need LiteLLM's native callback lifecycle (e.g., failure hooks). + Usage: >>> from hindsight_litellm import configure, enable - >>> configure(bank_id="my-agent", hindsight_api_url="http://localhost:8888") + >>> configure(bank_id="my-agent") >>> enable() >>> >>> # Now all LiteLLM calls will have memory integration diff --git a/hindsight-integrations/litellm/hindsight_litellm/wrappers.py b/hindsight-integrations/litellm/hindsight_litellm/wrappers.py index 5fac40a5a..268b4f5c1 100644 --- a/hindsight-integrations/litellm/hindsight_litellm/wrappers.py +++ b/hindsight-integrations/litellm/hindsight_litellm/wrappers.py @@ -104,6 +104,10 @@ def recall( budget: Optional[str] = None, max_tokens: Optional[int] = None, hindsight_api_url: Optional[str] = None, + include_entities: Optional[bool] = None, + trace: Optional[bool] = None, + recall_tags: Optional[List[str]] = None, + recall_tags_match: Optional[str] = None, ) -> RecallResponse: """Recall memories from Hindsight. @@ -118,6 +122,10 @@ def recall( budget: Recall budget level (low, mid, high) - controls how many memories are returned max_tokens: Maximum tokens for memory context hindsight_api_url: Override the configured API URL + include_entities: Include entity observations in results (default: from config) + trace: Enable trace info for debugging (default: from config) + recall_tags: Tags to filter by when recalling memories + recall_tags_match: Tag matching mode - any/all/any_strict/all_strict (default: from config) Returns: RecallResponse containing matched memories (iterable like a list). @@ -128,7 +136,7 @@ def recall( Example: >>> from hindsight_litellm import configure, recall - >>> configure(bank_id="my-agent", hindsight_api_url="http://localhost:8888") + >>> configure(bank_id="my-agent") >>> >>> # Query memories >>> memories = recall("what projects am I working on?") @@ -136,11 +144,8 @@ def recall( ... print(f"- [{m.fact_type}] {m.text}") - [world] User is building a FastAPI project >>> - >>> # With verbose mode, access debug info - >>> configure(bank_id="my-agent", verbose=True) - >>> memories = recall("what projects am I working on?") - >>> if memories.debug: - ... print(f"Queried bank: {memories.debug.bank_id}") + >>> # Filter by tags + >>> memories = recall("preferences", recall_tags=["user:alice"], recall_tags_match="any_strict") """ # Get config and defaults, or use overrides config = get_config() @@ -151,6 +156,10 @@ def recall( target_fact_types = fact_types or (defaults.fact_types if defaults else None) target_budget = budget or (defaults.budget if defaults else "mid") target_max_tokens = max_tokens or (defaults.max_memory_tokens if defaults else 4096) + target_include_entities = include_entities if include_entities is not None else (defaults.include_entities if defaults else True) + target_trace = trace if trace is not None else (defaults.trace if defaults else False) + target_recall_tags = recall_tags or (defaults.recall_tags if defaults else None) + target_recall_tags_match = recall_tags_match or (defaults.recall_tags_match if defaults else "any") if not api_url or not target_bank_id: raise RuntimeError("Hindsight not configured. Call configure() or provide bank_id and hindsight_api_url.") @@ -161,13 +170,19 @@ def recall( client = _get_client(api_url, config.api_key if config else None) # Call recall API - results = client.recall( - bank_id=target_bank_id, - query=query, - types=target_fact_types, - budget=target_budget, - max_tokens=target_max_tokens, - ) + recall_kwargs: dict = { + "bank_id": target_bank_id, + "query": query, + "types": target_fact_types, + "budget": target_budget, + "max_tokens": target_max_tokens, + "trace": target_trace, + "include_entities": target_include_entities, + } + if target_recall_tags: + recall_kwargs["tags"] = target_recall_tags + recall_kwargs["tags_match"] = target_recall_tags_match + results = client.recall(**recall_kwargs) # Convert to RecallResult objects recall_results = [] @@ -283,6 +298,8 @@ def reflect( context: Optional[str] = None, response_schema: Optional[dict] = None, hindsight_api_url: Optional[str] = None, + recall_tags: Optional[List[str]] = None, + recall_tags_match: Optional[str] = None, ) -> ReflectResult: """Generate a contextual answer based on memories. @@ -319,6 +336,8 @@ def reflect( api_url = hindsight_api_url or (config.hindsight_api_url if config else None) target_bank_id = bank_id or (defaults.bank_id if defaults else None) target_budget = budget or (defaults.budget if defaults else "mid") + target_recall_tags = recall_tags or (defaults.recall_tags if defaults else None) + target_recall_tags_match = recall_tags_match or (defaults.recall_tags_match if defaults else "any") if not api_url or not target_bank_id: raise RuntimeError("Hindsight not configured. Call configure() or provide bank_id and hindsight_api_url.") @@ -329,7 +348,7 @@ def reflect( client = _get_client(api_url, config.api_key if config else None) # Call reflect API - reflect_kwargs = { + reflect_kwargs: dict = { "bank_id": target_bank_id, "query": query, "budget": target_budget, @@ -338,6 +357,9 @@ def reflect( reflect_kwargs["context"] = context if response_schema is not None: reflect_kwargs["response_schema"] = response_schema + if target_recall_tags: + reflect_kwargs["tags"] = target_recall_tags + reflect_kwargs["tags_match"] = target_recall_tags_match result = client.reflect(**reflect_kwargs) # Convert to ReflectResult diff --git a/hindsight-integrations/litellm/tests/test_integration.py b/hindsight-integrations/litellm/tests/test_integration.py index d938ec382..2dbc926b8 100644 --- a/hindsight-integrations/litellm/tests/test_integration.py +++ b/hindsight-integrations/litellm/tests/test_integration.py @@ -1002,3 +1002,201 @@ def __init__(self, content): call_kwargs = mock_hindsight_client.retain.call_args[1] assert "USER: Hello" in call_kwargs["content"] assert "ASSISTANT: Hello world!" in call_kwargs["content"] + + +class TestRecallNewParams: + """Tests for new parameters added to recall().""" + + def setup_method(self): + reset_config() + + def teardown_method(self): + cleanup() + + def test_recall_passes_include_entities(self): + """recall() should forward include_entities to the client.""" + from unittest.mock import MagicMock, patch + from hindsight_litellm import recall + + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="test-agent") + + mock_client = MagicMock() + mock_client.recall.return_value = [] + + with patch("hindsight_litellm.wrappers._get_client", return_value=mock_client): + recall("test query", include_entities=False) + call_kwargs = mock_client.recall.call_args[1] + assert call_kwargs["include_entities"] is False + + def test_recall_passes_trace(self): + """recall() should forward trace to the client.""" + from unittest.mock import MagicMock, patch + from hindsight_litellm import recall + + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="test-agent") + + mock_client = MagicMock() + mock_client.recall.return_value = [] + + with patch("hindsight_litellm.wrappers._get_client", return_value=mock_client): + recall("test query", trace=True) + call_kwargs = mock_client.recall.call_args[1] + assert call_kwargs["trace"] is True + + def test_recall_passes_recall_tags(self): + """recall() should forward recall_tags and recall_tags_match to the client.""" + from unittest.mock import MagicMock, patch + from hindsight_litellm import recall + + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="test-agent") + + mock_client = MagicMock() + mock_client.recall.return_value = [] + + with patch("hindsight_litellm.wrappers._get_client", return_value=mock_client): + recall("test query", recall_tags=["user:alice"], recall_tags_match="any_strict") + call_kwargs = mock_client.recall.call_args[1] + assert call_kwargs["tags"] == ["user:alice"] + assert call_kwargs["tags_match"] == "any_strict" + + def test_recall_no_tags_key_when_empty(self): + """recall() should not pass tags key when recall_tags is None.""" + from unittest.mock import MagicMock, patch + from hindsight_litellm import recall + + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="test-agent") + + mock_client = MagicMock() + mock_client.recall.return_value = [] + + with patch("hindsight_litellm.wrappers._get_client", return_value=mock_client): + recall("test query") + call_kwargs = mock_client.recall.call_args[1] + assert "tags" not in call_kwargs + + def test_recall_inherits_include_entities_from_defaults(self): + """recall() should inherit include_entities from set_defaults() if not overridden.""" + from unittest.mock import MagicMock, patch + from hindsight_litellm import recall + + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="test-agent", include_entities=False) + + mock_client = MagicMock() + mock_client.recall.return_value = [] + + with patch("hindsight_litellm.wrappers._get_client", return_value=mock_client): + recall("test query") + call_kwargs = mock_client.recall.call_args[1] + assert call_kwargs["include_entities"] is False + + +class TestReflectNewParams: + """Tests for new parameters added to reflect().""" + + def setup_method(self): + reset_config() + + def teardown_method(self): + cleanup() + + def test_reflect_passes_recall_tags(self): + """reflect() should forward recall_tags and recall_tags_match to the client.""" + from unittest.mock import MagicMock, patch + from hindsight_litellm import reflect + + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="test-agent") + + mock_result = MagicMock() + mock_result.text = "some reflection" + mock_result.based_on = None + mock_client = MagicMock() + mock_client.reflect.return_value = mock_result + + with patch("hindsight_litellm.wrappers._get_client", return_value=mock_client): + reflect("test query", recall_tags=["user:bob"], recall_tags_match="all") + call_kwargs = mock_client.reflect.call_args[1] + assert call_kwargs["tags"] == ["user:bob"] + assert call_kwargs["tags_match"] == "all" + + def test_reflect_no_tags_key_when_empty(self): + """reflect() should not pass tags key when recall_tags is None.""" + from unittest.mock import MagicMock, patch + from hindsight_litellm import reflect + + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="test-agent") + + mock_result = MagicMock() + mock_result.text = "some reflection" + mock_result.based_on = None + mock_client = MagicMock() + mock_client.reflect.return_value = mock_result + + with patch("hindsight_litellm.wrappers._get_client", return_value=mock_client): + reflect("test query") + call_kwargs = mock_client.reflect.call_args[1] + assert "tags" not in call_kwargs + + +class TestHindsightMemoryNewParams: + """Tests for new parameters added to hindsight_memory() context manager.""" + + def setup_method(self): + cleanup() + + def teardown_method(self): + cleanup() + + def test_hindsight_memory_session_id(self): + """hindsight_memory() should pass session_id through to defaults.""" + from hindsight_litellm import hindsight_memory + + with hindsight_memory(bank_id="test-agent", session_id="conv-123"): + defaults = get_defaults() + assert defaults.session_id == "conv-123" + + def test_hindsight_memory_use_reflect(self): + """hindsight_memory() should pass use_reflect through to defaults.""" + from hindsight_litellm import hindsight_memory + + with hindsight_memory(bank_id="test-agent", use_reflect=True): + defaults = get_defaults() + assert defaults.use_reflect is True + + def test_hindsight_memory_tags(self): + """hindsight_memory() should pass tags and recall_tags through to defaults.""" + from hindsight_litellm import hindsight_memory + + with hindsight_memory( + bank_id="test-agent", + tags=["session:abc"], + recall_tags=["session:abc"], + recall_tags_match="any_strict", + ): + defaults = get_defaults() + assert defaults.tags == ["session:abc"] + assert defaults.recall_tags == ["session:abc"] + assert defaults.recall_tags_match == "any_strict" + + def test_hindsight_memory_reflect_context(self): + """hindsight_memory() should pass reflect_context through to defaults.""" + from hindsight_litellm import hindsight_memory + + with hindsight_memory(bank_id="test-agent", reflect_context="I am an assistant."): + defaults = get_defaults() + assert defaults.reflect_context == "I am an assistant." + + def test_hindsight_memory_default_url_is_cloud(self): + """hindsight_memory() default URL should be the cloud endpoint, not localhost.""" + from hindsight_litellm import hindsight_memory + from hindsight_litellm.config import DEFAULT_HINDSIGHT_API_URL + + with hindsight_memory(bank_id="test-agent"): + config = get_config() + assert config.hindsight_api_url == DEFAULT_HINDSIGHT_API_URL diff --git a/hindsight-integrations/litellm/uv.lock b/hindsight-integrations/litellm/uv.lock index 4bc635130..b804621cd 100644 --- a/hindsight-integrations/litellm/uv.lock +++ b/hindsight-integrations/litellm/uv.lock @@ -604,7 +604,7 @@ wheels = [ [[package]] name = "hindsight-litellm" -version = "0.5.2" +version = "0.5.3" source = { editable = "." } dependencies = [ { name = "aiohttp" }, From da39dd8422b1a965031dfe6f796fd0305a829f43 Mon Sep 17 00:00:00 2001 From: DK09876 Date: Thu, 21 May 2026 16:29:03 -0700 Subject: [PATCH 2/4] fix(litellm): strip hindsight_bank_id from kwargs before LiteLLM call and add sync param to aretain - hindsight_bank_id kwarg was leaking into LiteLLM as extra_body, causing OpenAI 400 errors; now popped in completion(), _wrapped_completion(), _wrapped_acompletion() and propagated as bank_id_override throughout injection and storage paths - _inject_memories() accepts bank_id_override to honour per-call bank without mutating globals - _store_conversation() and _store_conversation_from_text() accept bank_id_override for consistent per-call storage routing - _LiteLLMStreamWrapper and _LiteLLMAsyncStreamWrapper carry bank_id_override so streamed responses store to the right bank - aretain() now accepts sync=True, forwarding it to retain() Co-Authored-By: Claude Sonnet 4.6 --- .../litellm/hindsight_litellm/__init__.py | 101 ++++++++++-------- .../litellm/hindsight_litellm/wrappers.py | 2 + 2 files changed, 61 insertions(+), 42 deletions(-) diff --git a/hindsight-integrations/litellm/hindsight_litellm/__init__.py b/hindsight-integrations/litellm/hindsight_litellm/__init__.py index b777e1590..a70f1abbe 100644 --- a/hindsight-integrations/litellm/hindsight_litellm/__init__.py +++ b/hindsight-integrations/litellm/hindsight_litellm/__init__.py @@ -226,6 +226,7 @@ def _inject_memories( messages: List[dict], custom_query: Optional[str] = None, custom_reflect_context: Optional[str] = None, + bank_id_override: Optional[str] = None, ) -> List[dict]: """Inject memories into messages list. @@ -238,6 +239,7 @@ def _inject_memories( messages: List of message dicts to inject memories into custom_query: Optional custom query to use for memory lookup instead of user message custom_reflect_context: Optional context to pass to reflect API (overrides defaults.reflect_context) + bank_id_override: Optional bank_id that overrides the default for this call """ global _last_injection_debug import logging @@ -251,7 +253,7 @@ def _inject_memories( if not config or not config.inject_memories: return messages - if not defaults or not defaults.bank_id: + if not bank_id_override and (not defaults or not defaults.bank_id): raise ValueError( "No bank_id configured. Either call set_defaults(bank_id=...) " "or pass hindsight_bank_id=... to the completion call." @@ -283,8 +285,8 @@ def _inject_memories( if not user_query: return messages - # Use bank_id from defaults - bank_id = defaults.bank_id + # Use bank_id_override if provided, otherwise fall back to defaults + bank_id = bank_id_override or (defaults.bank_id if defaults else None) # Track debug info mode = "reflect" if defaults.use_reflect else "recall" @@ -529,9 +531,10 @@ def _wrapped_completion(*args, **kwargs): """ config = get_config() - # Extract hindsight-specific kwargs + # Extract hindsight-specific kwargs (must be popped before calling LiteLLM) custom_query = kwargs.pop("hindsight_query", None) custom_reflect_context = kwargs.pop("hindsight_reflect_context", None) + bank_id_override = kwargs.pop("hindsight_bank_id", None) # Extract messages from kwargs or args messages = kwargs.get("messages") @@ -549,6 +552,7 @@ def _wrapped_completion(*args, **kwargs): messages, custom_query=custom_query, custom_reflect_context=custom_reflect_context, + bank_id_override=bank_id_override, ) kwargs["messages"] = injected_messages except Exception as e: @@ -562,8 +566,8 @@ def _wrapped_completion(*args, **kwargs): final_messages = kwargs.get("messages", messages) if final_messages: if _is_streaming_response(response): - return _LiteLLMStreamWrapper(response, final_messages, model or "unknown") - _store_conversation(final_messages, response, model or "unknown") + return _LiteLLMStreamWrapper(response, final_messages, model or "unknown", bank_id_override=bank_id_override) + _store_conversation(final_messages, response, model or "unknown", bank_id_override=bank_id_override) return response @@ -578,9 +582,10 @@ async def _wrapped_acompletion(*args, **kwargs): """ config = get_config() - # Extract hindsight-specific kwargs + # Extract hindsight-specific kwargs (must be popped before calling LiteLLM) custom_query = kwargs.pop("hindsight_query", None) custom_reflect_context = kwargs.pop("hindsight_reflect_context", None) + bank_id_override = kwargs.pop("hindsight_bank_id", None) # Extract messages from kwargs or args messages = kwargs.get("messages") @@ -598,6 +603,7 @@ async def _wrapped_acompletion(*args, **kwargs): messages, custom_query=custom_query, custom_reflect_context=custom_reflect_context, + bank_id_override=bank_id_override, ) kwargs["messages"] = injected_messages except Exception as e: @@ -611,8 +617,8 @@ async def _wrapped_acompletion(*args, **kwargs): final_messages = kwargs.get("messages", messages) if final_messages: if _is_streaming_response(response): - return _LiteLLMAsyncStreamWrapper(response, final_messages, model or "unknown") - _store_conversation(final_messages, response, model or "unknown") + return _LiteLLMAsyncStreamWrapper(response, final_messages, model or "unknown", bank_id_override=bank_id_override) + _store_conversation(final_messages, response, model or "unknown", bank_id_override=bank_id_override) return response @@ -799,10 +805,11 @@ def _format_messages_for_storage(messages: List[dict]) -> List[str]: class _LiteLLMStreamWrapper: """Wraps a LiteLLM sync streaming response to collect chunks and store conversation on completion.""" - def __init__(self, stream, messages: List[dict], model: str): + def __init__(self, stream, messages: List[dict], model: str, bank_id_override: Optional[str] = None): self._stream = stream self._messages = messages self._model = model + self._bank_id_override = bank_id_override self._collected_content: List[str] = [] self._finished = False @@ -844,7 +851,7 @@ def _store_if_needed(self): items.append(f"ASSISTANT: {assistant_output}") conversation_text = "\n\n".join(items) if conversation_text: - _store_conversation_from_text(conversation_text, self._model) + _store_conversation_from_text(conversation_text, self._model, bank_id_override=self._bank_id_override) def close(self): self._store_if_needed() @@ -858,10 +865,11 @@ def __getattr__(self, name: str): class _LiteLLMAsyncStreamWrapper: """Wraps a LiteLLM async streaming response to collect chunks and store conversation on completion.""" - def __init__(self, stream, messages: List[dict], model: str): + def __init__(self, stream, messages: List[dict], model: str, bank_id_override: Optional[str] = None): self._stream = stream self._messages = messages self._model = model + self._bank_id_override = bank_id_override self._collected_content: List[str] = [] self._finished = False @@ -903,7 +911,7 @@ def _store_if_needed(self): items.append(f"ASSISTANT: {assistant_output}") conversation_text = "\n\n".join(items) if conversation_text: - _store_conversation_from_text(conversation_text, self._model) + _store_conversation_from_text(conversation_text, self._model, bank_id_override=self._bank_id_override) async def aclose(self): self._store_if_needed() @@ -1073,7 +1081,7 @@ def get_pending_storage_errors() -> List[Exception]: return errors -def _store_conversation_from_text(conversation_text: str, model: str) -> None: +def _store_conversation_from_text(conversation_text: str, model: str, bank_id_override: Optional[str] = None) -> None: """Store pre-formatted conversation text to Hindsight. Used by stream wrappers which collect chunks and format the conversation themselves. @@ -1084,7 +1092,8 @@ def _store_conversation_from_text(conversation_text: str, model: str) -> None: if not config or not config.store_conversations: return - if not defaults or not defaults.bank_id: + effective_bank_id = bank_id_override or (defaults.bank_id if defaults else None) + if not effective_bank_id: _storage_logger.warning("No bank_id configured for storage. Call set_defaults(bank_id=...).") return if not conversation_text: @@ -1093,23 +1102,24 @@ def _store_conversation_from_text(conversation_text: str, model: str) -> None: if config.sync_storage: try: content_to_store = conversation_text - if defaults.effective_document_id: + effective_doc_id = defaults.effective_document_id if defaults else None + if effective_doc_id: existing_content = _get_existing_document_content( - defaults.bank_id, defaults.effective_document_id, config.verbose + effective_bank_id, effective_doc_id, config.verbose ) if existing_content: content_to_store = f"{existing_content}\n\n{conversation_text}" retain( content=content_to_store, - bank_id=defaults.bank_id, + bank_id=effective_bank_id, context=f"conversation:litellm:{model}", - document_id=defaults.effective_document_id, - tags=defaults.tags, + document_id=effective_doc_id, + tags=defaults.tags if defaults else None, metadata={"source": "litellm", "model": model}, ) if config.verbose: - _storage_logger.info(f"Stored streamed conversation to bank: {defaults.bank_id}") + _storage_logger.info(f"Stored streamed conversation to bank: {effective_bank_id}") except Exception as e: raise HindsightError(f"Failed to store conversation: {e}") from e return @@ -1118,9 +1128,9 @@ def _store_conversation_from_text(conversation_text: str, model: str) -> None: target=_store_conversation_sync, args=( conversation_text, - defaults.bank_id, - defaults.effective_document_id, - defaults.tags, + effective_bank_id, + defaults.effective_document_id if defaults else None, + defaults.tags if defaults else None, model, config.verbose, ), @@ -1133,6 +1143,7 @@ def _store_conversation( messages: List[dict], response, model: str, + bank_id_override: Optional[str] = None, ) -> None: """Store conversation to Hindsight. @@ -1146,7 +1157,8 @@ def _store_conversation( if not config or not config.store_conversations: return - if not defaults or not defaults.bank_id: + effective_bank_id = bank_id_override or (defaults.bank_id if defaults else None) + if not effective_bank_id: _storage_logger.warning("No bank_id configured for storage. Call set_defaults(bank_id=...).") return @@ -1156,30 +1168,31 @@ def _store_conversation( if not conversation_text: return + effective_doc_id = defaults.effective_document_id if defaults else None + # Sync mode: run directly and raise errors if config.sync_storage: try: - # If document_id is set, fetch existing content and append content_to_store = conversation_text - if defaults.effective_document_id: + if effective_doc_id: existing_content = _get_existing_document_content( - defaults.bank_id, defaults.effective_document_id, config.verbose + effective_bank_id, effective_doc_id, config.verbose ) if existing_content: content_to_store = f"{existing_content}\n\n{conversation_text}" if config.verbose: - _storage_logger.debug(f"Appending to existing document: {defaults.effective_document_id}") + _storage_logger.debug(f"Appending to existing document: {effective_doc_id}") retain( content=content_to_store, - bank_id=defaults.bank_id, + bank_id=effective_bank_id, context=f"conversation:litellm:{model}", - document_id=defaults.effective_document_id, - tags=defaults.tags, + document_id=effective_doc_id, + tags=defaults.tags if defaults else None, metadata={"source": "litellm", "model": model}, ) if config.verbose: - _storage_logger.info(f"Stored conversation to bank: {defaults.bank_id}") + _storage_logger.info(f"Stored conversation to bank: {effective_bank_id}") except Exception as e: raise HindsightError(f"Failed to store conversation: {e}") from e return @@ -1189,9 +1202,9 @@ def _store_conversation( target=_store_conversation_sync, args=( conversation_text, - defaults.bank_id, - defaults.effective_document_id, - defaults.tags, + effective_bank_id, + effective_doc_id, + defaults.tags if defaults else None, model, config.verbose, ), @@ -1250,9 +1263,10 @@ def completion(*args, **kwargs): """ config = get_config() - # Extract hindsight-specific kwargs + # Extract hindsight-specific kwargs (must be popped before calling LiteLLM) custom_query = kwargs.pop("hindsight_query", None) custom_reflect_context = kwargs.pop("hindsight_reflect_context", None) + bank_id_override = kwargs.pop("hindsight_bank_id", None) # Extract messages from kwargs or args messages = kwargs.get("messages") @@ -1270,6 +1284,7 @@ def completion(*args, **kwargs): messages, custom_query=custom_query, custom_reflect_context=custom_reflect_context, + bank_id_override=bank_id_override, ) kwargs["messages"] = injected_messages except Exception as e: @@ -1283,8 +1298,8 @@ def completion(*args, **kwargs): final_messages = kwargs.get("messages", messages) if final_messages: if _is_streaming_response(response): - return _LiteLLMStreamWrapper(response, final_messages, model or "unknown") - _store_conversation(final_messages, response, model or "unknown") + return _LiteLLMStreamWrapper(response, final_messages, model or "unknown", bank_id_override=bank_id_override) + _store_conversation(final_messages, response, model or "unknown", bank_id_override=bank_id_override) return response @@ -1328,9 +1343,10 @@ async def acompletion(*args, **kwargs): """ config = get_config() - # Extract hindsight-specific kwargs + # Extract hindsight-specific kwargs (must be popped before calling LiteLLM) custom_query = kwargs.pop("hindsight_query", None) custom_reflect_context = kwargs.pop("hindsight_reflect_context", None) + bank_id_override = kwargs.pop("hindsight_bank_id", None) # Extract messages from kwargs or args messages = kwargs.get("messages") @@ -1348,6 +1364,7 @@ async def acompletion(*args, **kwargs): messages, custom_query=custom_query, custom_reflect_context=custom_reflect_context, + bank_id_override=bank_id_override, ) kwargs["messages"] = injected_messages except Exception as e: @@ -1361,8 +1378,8 @@ async def acompletion(*args, **kwargs): final_messages = kwargs.get("messages", messages) if final_messages: if _is_streaming_response(response): - return _LiteLLMAsyncStreamWrapper(response, final_messages, model or "unknown") - _store_conversation(final_messages, response, model or "unknown") + return _LiteLLMAsyncStreamWrapper(response, final_messages, model or "unknown", bank_id_override=bank_id_override) + _store_conversation(final_messages, response, model or "unknown", bank_id_override=bank_id_override) return response diff --git a/hindsight-integrations/litellm/hindsight_litellm/wrappers.py b/hindsight-integrations/litellm/hindsight_litellm/wrappers.py index 268b4f5c1..9dd2b7343 100644 --- a/hindsight-integrations/litellm/hindsight_litellm/wrappers.py +++ b/hindsight-integrations/litellm/hindsight_litellm/wrappers.py @@ -673,6 +673,7 @@ async def aretain( tags: Optional[List[str]] = None, metadata: Optional[Dict[str, str]] = None, hindsight_api_url: Optional[str] = None, + sync: bool = False, ) -> RetainResult: """Async version of retain(). @@ -691,6 +692,7 @@ async def aretain( tags=tags, metadata=metadata, hindsight_api_url=hindsight_api_url, + sync=sync, ), ) From 47ff864d00b07f7e69d57eaa936d482253e5f1a2 Mon Sep 17 00:00:00 2001 From: DK09876 Date: Fri, 22 May 2026 08:16:13 -0700 Subject: [PATCH 3/4] =?UTF-8?q?fix(litellm):=20design=20review=20fixes=20?= =?UTF-8?q?=E2=80=94=20injection=5Fmode,=20context=20manager=20restore,=20?= =?UTF-8?q?validation,=20error=20consistency?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - config.py: remove DEFAULT_BANK_ID footgun (configure() without bank_id now leaves bank_id=None; is_configured() and enable() correctly require explicit bank_id). Add _restore_config() for atomic state restoration. Add budget/recall_tags_match validation in configure() and set_defaults(). Emit DeprecationWarning for document_id usage. - __init__.py: _inject_memories() now respects injection_mode (PREPEND_USER prepends to last user message; SYSTEM_MESSAGE keeps existing behaviour). Wire up defaults.query as fallback recall query. Fix ValueError → HindsightError for missing bank_id. hindsight_memory() finally block now calls _restore_config() to atomically restore all settings (previously lost: sync_storage, tags, recall_tags, recall_tags_match, reflect_context, reflect_response_schema). Add _enabled_lock and _debug_lock for thread safety on shared mutable state. - callbacks.py: ValueError → HindsightError in log_pre_api_call and async_log_pre_api_call for missing bank_id, consistent with __init__.py. - tests: update tests that relied on DEFAULT_BANK_ID behaviour; add TestValidation, TestInjectionMode, TestQueryField, TestHindsightErrorConsistency, TestContextManagerFullRestore (83 tests, all passing). Co-Authored-By: Claude Sonnet 4.6 --- .../litellm/hindsight_litellm/__init__.py | 152 ++++---- .../litellm/hindsight_litellm/callbacks.py | 4 +- .../litellm/hindsight_litellm/config.py | 40 +- .../litellm/tests/test_integration.py | 355 +++++++++++++++++- 4 files changed, 466 insertions(+), 85 deletions(-) diff --git a/hindsight-integrations/litellm/hindsight_litellm/__init__.py b/hindsight-integrations/litellm/hindsight_litellm/__init__.py index a70f1abbe..298f302e1 100644 --- a/hindsight-integrations/litellm/hindsight_litellm/__init__.py +++ b/hindsight-integrations/litellm/hindsight_litellm/__init__.py @@ -118,6 +118,7 @@ HindsightConfig, HindsightDefaults, MemoryInjectionMode, + _restore_config, configure, get_config, get_defaults, @@ -152,11 +153,15 @@ # Track whether we've registered with LiteLLM _enabled = False +_enabled_lock = threading.Lock() # Store original functions for restoration _original_completion = None _original_acompletion = None +# Lock protecting _last_injection_debug writes +_debug_lock = threading.Lock() + @dataclass class InjectionDebugInfo: @@ -254,7 +259,7 @@ def _inject_memories( return messages if not bank_id_override and (not defaults or not defaults.bank_id): - raise ValueError( + raise HindsightError( "No bank_id configured. Either call set_defaults(bank_id=...) " "or pass hindsight_bank_id=... to the completion call." ) @@ -262,9 +267,11 @@ def _inject_memories( if not messages: return messages - # Use custom_query if provided, otherwise fall back to the last user message + # Resolve query: custom_query arg > defaults.query > last user message if custom_query: user_query = custom_query + elif defaults and defaults.query: + user_query = defaults.query else: user_query = None for msg in reversed(messages): @@ -447,23 +454,41 @@ def _inject_memories( "The following information from memory may be relevant:\n\n" + "\n".join(memory_lines) ) - # Inject into messages + # Inject into messages based on injection_mode updated_messages = list(messages) + injection_mode = defaults.injection_mode if defaults else MemoryInjectionMode.SYSTEM_MESSAGE + + if injection_mode == MemoryInjectionMode.PREPEND_USER: + # Prepend memory context to the last user message + for i in range(len(updated_messages) - 1, -1, -1): + if updated_messages[i].get("role") == "user": + existing_content = updated_messages[i].get("content", "") + if isinstance(existing_content, str): + updated_messages[i] = { + **updated_messages[i], + "content": f"{memory_context}\n\n{existing_content}", + } + elif isinstance(existing_content, list): + updated_messages[i] = { + **updated_messages[i], + "content": [{"type": "text", "text": memory_context}] + existing_content, + } + break + else: + # SYSTEM_MESSAGE mode (default): add to/create system message + found_system = False + for i, msg in enumerate(updated_messages): + if msg.get("role") == "system": + existing_content = msg.get("content", "") + updated_messages[i] = { + **msg, + "content": f"{existing_content}\n\n{memory_context}", + } + found_system = True + break - # Find existing system message or create new one - found_system = False - for i, msg in enumerate(updated_messages): - if msg.get("role") == "system": - existing_content = msg.get("content", "") - updated_messages[i] = { - **msg, - "content": f"{existing_content}\n\n{memory_context}", - } - found_system = True - break - - if not found_system: - updated_messages.insert(0, {"role": "system", "content": memory_context}) + if not found_system: + updated_messages.insert(0, {"role": "system", "content": memory_context}) # Store debug info when verbose if config.verbose: @@ -664,28 +689,30 @@ def enable() -> None: """ global _enabled, _original_completion, _original_acompletion - if _enabled: - return # Already enabled + with _enabled_lock: + if _enabled: + return # Already enabled - config = get_config() - defaults = get_defaults() + config = get_config() + defaults = get_defaults() - if not config: - raise RuntimeError("Hindsight not configured. Call configure() before enable().") + if not config: + raise RuntimeError("Hindsight not configured. Call configure() before enable().") - if not defaults or not defaults.bank_id: - raise RuntimeError("Hindsight bank_id not set. Call set_defaults(bank_id=...) before enable().") + if not defaults or not defaults.bank_id: + raise RuntimeError("Hindsight bank_id not set. Call set_defaults(bank_id=...) before enable().") - # Store original functions and monkeypatch for memory injection + storage - _original_completion = litellm.completion - _original_acompletion = litellm.acompletion - litellm.completion = _wrapped_completion - litellm.acompletion = _wrapped_acompletion + # Store original functions and monkeypatch for memory injection + storage + _original_completion = litellm.completion + _original_acompletion = litellm.acompletion + litellm.completion = _wrapped_completion + litellm.acompletion = _wrapped_acompletion - _enabled = True + _enabled = True - if config.verbose: - print(f"Hindsight memory enabled for bank: {defaults.bank_id}") + if get_config() and get_config().verbose: + defaults = get_defaults() + print(f"Hindsight memory enabled for bank: {defaults.bank_id if defaults else 'unknown'}") def disable() -> None: @@ -700,21 +727,22 @@ def disable() -> None: """ global _enabled, _original_completion, _original_acompletion - if not _enabled: - return # Already disabled + with _enabled_lock: + if not _enabled: + return # Already disabled - # Restore original functions - if _original_completion is not None: - litellm.completion = _original_completion - _original_completion = None - if _original_acompletion is not None: - litellm.acompletion = _original_acompletion - _original_acompletion = None + # Restore original functions + if _original_completion is not None: + litellm.completion = _original_completion + _original_completion = None + if _original_acompletion is not None: + litellm.acompletion = _original_acompletion + _original_acompletion = None - # Close cached HTTP client to avoid "Unclosed client session" warnings - _close_client() + # Close cached HTTP client to avoid "Unclosed client session" warnings + _close_client() - _enabled = False + _enabled = False config = get_config() if config and config.verbose: @@ -1485,36 +1513,12 @@ def hindsight_memory( enable() yield finally: - # Restore previous state + # Atomically restore previous state, bypassing all side effects (warnings, + # bank creation, validation) that configure/set_defaults would trigger. disable() - if previous_config: - configure( - hindsight_api_url=previous_config.hindsight_api_url, - api_key=previous_config.api_key, - store_conversations=previous_config.store_conversations, - inject_memories=previous_config.inject_memories, - injection_mode=previous_config.injection_mode, - excluded_models=previous_config.excluded_models, - verbose=previous_config.verbose, - ) - if previous_defaults: - set_defaults( - bank_id=previous_defaults.bank_id, - session_id=previous_defaults.session_id, - document_id=previous_defaults.document_id, - budget=previous_defaults.budget, - fact_types=previous_defaults.fact_types, - max_memories=previous_defaults.max_memories, - max_memory_tokens=previous_defaults.max_memory_tokens, - use_reflect=previous_defaults.use_reflect, - reflect_include_facts=previous_defaults.reflect_include_facts, - include_entities=previous_defaults.include_entities, - trace=previous_defaults.trace, - ) - if was_enabled: - enable() - else: - reset_config() + _restore_config(previous_config) + if was_enabled and previous_config is not None: + enable() __all__ = [ diff --git a/hindsight-integrations/litellm/hindsight_litellm/callbacks.py b/hindsight-integrations/litellm/hindsight_litellm/callbacks.py index 506a0f9d9..5753aa8f8 100644 --- a/hindsight-integrations/litellm/hindsight_litellm/callbacks.py +++ b/hindsight-integrations/litellm/hindsight_litellm/callbacks.py @@ -703,7 +703,7 @@ def log_pre_api_call( # Get effective settings (kwargs override defaults) settings = self._get_effective_settings(kwargs) if not settings.bank_id: - raise ValueError( + raise HindsightError( "No bank_id configured. Either call set_defaults(bank_id=...) " "or pass hindsight_bank_id=... to the completion call." ) @@ -762,7 +762,7 @@ async def async_log_pre_api_call( # Get effective settings (kwargs override defaults) settings = self._get_effective_settings(kwargs) if not settings.bank_id: - raise ValueError( + raise HindsightError( "No bank_id configured. Either call set_defaults(bank_id=...) " "or pass hindsight_bank_id=... to the completion call." ) diff --git a/hindsight-integrations/litellm/hindsight_litellm/config.py b/hindsight-integrations/litellm/hindsight_litellm/config.py index 9ad66f190..4895c01f4 100644 --- a/hindsight-integrations/litellm/hindsight_litellm/config.py +++ b/hindsight-integrations/litellm/hindsight_litellm/config.py @@ -15,6 +15,7 @@ """ import os +import warnings from dataclasses import asdict, dataclass, field, fields from enum import Enum from importlib import metadata @@ -31,6 +32,9 @@ DEFAULT_BANK_ID = "default" HINDSIGHT_API_KEY_ENV = "HINDSIGHT_API_KEY" +VALID_BUDGETS = frozenset({"low", "mid", "high"}) +VALID_TAGS_MATCH = frozenset({"any", "all", "any_strict", "all_strict"}) + class MemoryInjectionMode(str, Enum): """How memories should be injected into the prompt. @@ -317,10 +321,22 @@ def configure( """ global _global_config + # Validate per-call defaults + if budget not in VALID_BUDGETS: + raise ValueError(f"budget must be one of {sorted(VALID_BUDGETS)!r}, got {budget!r}") + if recall_tags_match not in VALID_TAGS_MATCH: + raise ValueError(f"recall_tags_match must be one of {sorted(VALID_TAGS_MATCH)!r}, got {recall_tags_match!r}") + if document_id is not None: + warnings.warn( + "document_id is deprecated; use session_id instead.", + DeprecationWarning, + stacklevel=2, + ) + # Apply connection-level defaults resolved_api_url = hindsight_api_url or DEFAULT_HINDSIGHT_API_URL resolved_api_key = api_key or os.environ.get(HINDSIGHT_API_KEY_ENV) - resolved_bank_id = bank_id or DEFAULT_BANK_ID + resolved_bank_id = bank_id # Build default settings default_settings = HindsightCallSettings( @@ -429,6 +445,18 @@ def set_defaults( """ global _global_config + # Validate values when explicitly provided + if budget is not None and budget not in VALID_BUDGETS: + raise ValueError(f"budget must be one of {sorted(VALID_BUDGETS)!r}, got {budget!r}") + if recall_tags_match is not None and recall_tags_match not in VALID_TAGS_MATCH: + raise ValueError(f"recall_tags_match must be one of {sorted(VALID_TAGS_MATCH)!r}, got {recall_tags_match!r}") + if document_id is not None: + warnings.warn( + "document_id is deprecated; use session_id instead.", + DeprecationWarning, + stacklevel=2, + ) + # Ensure configure() was called if _global_config is None: # Auto-configure with defaults if not configured @@ -556,3 +584,13 @@ def reset_config() -> None: """Reset all global configuration to None.""" global _global_config _global_config = None + + +def _restore_config(saved_config: Optional[HindsightConfig]) -> None: + """Directly restore global config from a saved snapshot, bypassing all side effects. + + Used by hindsight_memory() context manager to atomically restore state on exit + without triggering warnings, bank creation, or other configure() side effects. + """ + global _global_config + _global_config = saved_config diff --git a/hindsight-integrations/litellm/tests/test_integration.py b/hindsight-integrations/litellm/tests/test_integration.py index 2dbc926b8..0c7605ab3 100644 --- a/hindsight-integrations/litellm/tests/test_integration.py +++ b/hindsight-integrations/litellm/tests/test_integration.py @@ -13,6 +13,7 @@ get_config, is_configured, reset_config, + hindsight_memory, MemoryInjectionMode, ) from hindsight_litellm.callbacks import HindsightCallback @@ -82,9 +83,9 @@ def test_configure_with_all_options(self): assert defaults.document_id == "doc-123" def test_is_configured_with_defaults(self): - """Test is_configured returns True with default bank_id.""" - configure() # Uses default bank_id="default" - assert is_configured() is True + """configure() alone (no bank_id) leaves is_configured() False.""" + configure() + assert is_configured() is False def test_is_configured_with_bank_id_in_defaults(self): """Test is_configured returns True with bank_id in defaults.""" @@ -125,12 +126,11 @@ def test_enable_without_config_raises(self): with pytest.raises(RuntimeError, match="not configured"): enable() - def test_enable_with_default_bank_id_works(self): - """Test enable works with default bank_id (no explicit bank_id required).""" + def test_enable_without_bank_id_raises(self): + """enable() raises RuntimeError when no bank_id has been set.""" configure(hindsight_api_url="http://localhost:8888") - # Should work - configure() provides default bank_id="default" - enable() - assert is_enabled() is True + with pytest.raises(RuntimeError, match="bank_id"): + enable() def test_enable_sets_enabled_flag(self): """Test enable sets the enabled flag.""" @@ -1200,3 +1200,342 @@ def test_hindsight_memory_default_url_is_cloud(self): with hindsight_memory(bank_id="test-agent"): config = get_config() assert config.hindsight_api_url == DEFAULT_HINDSIGHT_API_URL + + +class TestValidation: + """Tests for input validation in configure() and set_defaults().""" + + def setup_method(self): + reset_config() + + def teardown_method(self): + cleanup() + + def test_configure_invalid_budget_raises(self): + """configure() raises ValueError for invalid budget.""" + with pytest.raises(ValueError, match="budget"): + configure(budget="extreme") + + def test_set_defaults_invalid_budget_raises(self): + """set_defaults() raises ValueError for invalid budget.""" + configure(hindsight_api_url="http://localhost:8888") + with pytest.raises(ValueError, match="budget"): + set_defaults(bank_id="test", budget="extreme") + + def test_configure_invalid_recall_tags_match_raises(self): + """configure() raises ValueError for invalid recall_tags_match.""" + with pytest.raises(ValueError, match="recall_tags_match"): + configure(recall_tags_match="fuzzy") + + def test_set_defaults_invalid_recall_tags_match_raises(self): + """set_defaults() raises ValueError for invalid recall_tags_match.""" + configure(hindsight_api_url="http://localhost:8888") + with pytest.raises(ValueError, match="recall_tags_match"): + set_defaults(bank_id="test", recall_tags_match="fuzzy") + + def test_configure_valid_budgets_accepted(self): + """configure() accepts all valid budget values.""" + for budget in ("low", "mid", "high"): + configure(hindsight_api_url="http://localhost:8888", budget=budget) + assert get_config().default_settings.budget == budget + + def test_configure_valid_tags_match_accepted(self): + """configure() accepts all valid recall_tags_match values.""" + for match in ("any", "all", "any_strict", "all_strict"): + configure(hindsight_api_url="http://localhost:8888", recall_tags_match=match) + assert get_config().default_settings.recall_tags_match == match + + def test_configure_document_id_emits_deprecation_warning(self): + """configure() with document_id emits DeprecationWarning.""" + with pytest.warns(DeprecationWarning, match="document_id"): + configure(hindsight_api_url="http://localhost:8888", document_id="doc-123") + + def test_set_defaults_document_id_emits_deprecation_warning(self): + """set_defaults() with document_id emits DeprecationWarning.""" + configure(hindsight_api_url="http://localhost:8888") + with pytest.warns(DeprecationWarning, match="document_id"): + set_defaults(bank_id="test", document_id="doc-123") + + +class TestInjectionMode: + """Tests for injection_mode in _inject_memories().""" + + def setup_method(self): + reset_config() + + def teardown_method(self): + cleanup() + + def test_system_message_injection_appends_to_existing_system(self): + """SYSTEM_MESSAGE mode appends memories to an existing system message.""" + from unittest.mock import MagicMock, patch + from hindsight_litellm import _inject_memories + + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="test", injection_mode=MemoryInjectionMode.SYSTEM_MESSAGE) + + mock_result = MagicMock() + mock_result.text = "User likes Rust" + mock_result.type = "world" + + mock_client = MagicMock() + mock_client.recall.return_value = [mock_result] + + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "What do I like?"}, + ] + + with patch("hindsight_litellm._get_client", return_value=mock_client): + result = _inject_memories(messages) + + assert result[0]["role"] == "system" + assert "You are helpful." in result[0]["content"] + assert "Rust" in result[0]["content"] + + def test_system_message_injection_creates_system_when_absent(self): + """SYSTEM_MESSAGE mode creates a new system message when none exists.""" + from unittest.mock import MagicMock, patch + from hindsight_litellm import _inject_memories + + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="test", injection_mode=MemoryInjectionMode.SYSTEM_MESSAGE) + + mock_result = MagicMock() + mock_result.text = "User likes Python" + mock_result.type = "world" + + mock_client = MagicMock() + mock_client.recall.return_value = [mock_result] + + messages = [{"role": "user", "content": "Hello"}] + + with patch("hindsight_litellm._get_client", return_value=mock_client): + result = _inject_memories(messages) + + assert result[0]["role"] == "system" + assert "Python" in result[0]["content"] + + def test_prepend_user_injection_prepends_to_last_user_message(self): + """PREPEND_USER mode prepends memories to the last user message.""" + from unittest.mock import MagicMock, patch + from hindsight_litellm import _inject_memories + + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="test", injection_mode=MemoryInjectionMode.PREPEND_USER) + + mock_result = MagicMock() + mock_result.text = "User likes Go" + mock_result.type = "world" + + mock_client = MagicMock() + mock_client.recall.return_value = [mock_result] + + messages = [{"role": "user", "content": "What do I like?"}] + + with patch("hindsight_litellm._get_client", return_value=mock_client): + result = _inject_memories(messages) + + # No system message should be created + assert all(m["role"] != "system" for m in result) + user_msg = next(m for m in result if m["role"] == "user") + content = user_msg["content"] + assert "Go" in content + assert "What do I like?" in content + # Memory context should come before user text + assert content.index("Go") < content.index("What do I like?") + + def test_prepend_user_does_not_touch_system_message(self): + """PREPEND_USER mode leaves existing system message untouched.""" + from unittest.mock import MagicMock, patch + from hindsight_litellm import _inject_memories + + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="test", injection_mode=MemoryInjectionMode.PREPEND_USER) + + mock_result = MagicMock() + mock_result.text = "User likes TypeScript" + mock_result.type = "world" + + mock_client = MagicMock() + mock_client.recall.return_value = [mock_result] + + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + ] + + with patch("hindsight_litellm._get_client", return_value=mock_client): + result = _inject_memories(messages) + + system_msg = next(m for m in result if m["role"] == "system") + assert system_msg["content"] == "You are helpful." + assert "TypeScript" not in system_msg["content"] + + +class TestQueryField: + """Tests for the query field in HindsightCallSettings.""" + + def setup_method(self): + reset_config() + + def teardown_method(self): + cleanup() + + def test_defaults_query_used_when_no_custom_query(self): + """defaults.query is used as recall query when no hindsight_query kwarg given.""" + from unittest.mock import MagicMock, patch + from hindsight_litellm import _inject_memories + from hindsight_litellm.config import HindsightCallSettings + + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="test") + + # Manually set query on defaults + import hindsight_litellm.config as cfg + cfg._global_config.default_settings.query = "favorite language" + + mock_result = MagicMock() + mock_result.text = "User likes Rust" + mock_result.type = "world" + + mock_client = MagicMock() + mock_client.recall.return_value = [mock_result] + + messages = [{"role": "user", "content": "Tell me something"}] + + with patch("hindsight_litellm._get_client", return_value=mock_client): + _inject_memories(messages) + + call_kwargs = mock_client.recall.call_args[1] + assert call_kwargs["query"] == "favorite language" + + def test_custom_query_overrides_defaults_query(self): + """custom_query parameter overrides defaults.query.""" + from unittest.mock import MagicMock, patch + from hindsight_litellm import _inject_memories + + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="test") + + import hindsight_litellm.config as cfg + cfg._global_config.default_settings.query = "default query" + + mock_client = MagicMock() + mock_client.recall.return_value = [] + + messages = [{"role": "user", "content": "Hello"}] + + with patch("hindsight_litellm._get_client", return_value=mock_client): + _inject_memories(messages, custom_query="override query") + + call_kwargs = mock_client.recall.call_args[1] + assert call_kwargs["query"] == "override query" + + +class TestHindsightErrorConsistency: + """Tests that HindsightError (not ValueError) is raised for config errors.""" + + def setup_method(self): + reset_config() + + def teardown_method(self): + cleanup() + + def test_inject_memories_raises_hindsight_error_when_no_bank_id(self): + """_inject_memories raises HindsightError (not ValueError) when bank_id missing.""" + from hindsight_litellm import HindsightError, _inject_memories + from hindsight_litellm.callbacks import HindsightError as CallbackHindsightError + + configure(hindsight_api_url="http://localhost:8888") + # No bank_id set + + messages = [{"role": "user", "content": "Hello"}] + + with pytest.raises(HindsightError): + _inject_memories(messages) + + def test_callback_log_pre_api_call_raises_hindsight_error_when_no_bank_id(self): + """HindsightCallback.log_pre_api_call raises HindsightError (not ValueError).""" + from hindsight_litellm.callbacks import HindsightCallback, HindsightError + + configure(hindsight_api_url="http://localhost:8888") + # No bank_id set + + callback = HindsightCallback() + messages = [{"role": "user", "content": "Hello"}] + + with pytest.raises(HindsightError): + callback.log_pre_api_call("gpt-4o-mini", messages, {}) + + +class TestContextManagerFullRestore: + """Tests that hindsight_memory() restores ALL settings on exit.""" + + def setup_method(self): + cleanup() + + def teardown_method(self): + cleanup() + + def test_restores_sync_storage(self): + """hindsight_memory() restores sync_storage after exit.""" + configure(hindsight_api_url="http://localhost:8888", sync_storage=True) + set_defaults(bank_id="original") + + with hindsight_memory(bank_id="temp", hindsight_api_url="http://localhost:8888"): + assert get_config().sync_storage is False # default inside context + + assert get_config().sync_storage is True # restored + + def test_restores_tags(self): + """hindsight_memory() restores tags after exit.""" + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="original", tags=["env:prod"]) + + with hindsight_memory(bank_id="temp", hindsight_api_url="http://localhost:8888", tags=["env:test"]): + assert get_defaults().tags == ["env:test"] + + assert get_defaults().tags == ["env:prod"] + + def test_restores_recall_tags(self): + """hindsight_memory() restores recall_tags after exit.""" + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="original", recall_tags=["user:alice"], recall_tags_match="any_strict") + + with hindsight_memory(bank_id="temp", hindsight_api_url="http://localhost:8888"): + assert get_defaults().recall_tags is None + + assert get_defaults().recall_tags == ["user:alice"] + assert get_defaults().recall_tags_match == "any_strict" + + def test_restores_reflect_context(self): + """hindsight_memory() restores reflect_context after exit.""" + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="original", reflect_context="Be concise.") + + with hindsight_memory(bank_id="temp", hindsight_api_url="http://localhost:8888"): + assert get_defaults().reflect_context is None + + assert get_defaults().reflect_context == "Be concise." + + def test_restores_to_none_when_no_prior_config(self): + """hindsight_memory() resets config to None if none was set before.""" + assert get_config() is None + + with hindsight_memory(bank_id="temp"): + assert get_config() is not None + + assert get_config() is None + + def test_re_enables_if_was_enabled_before(self): + """hindsight_memory() re-enables if integration was enabled before entering.""" + configure(hindsight_api_url="http://localhost:8888") + set_defaults(bank_id="original") + enable() + assert is_enabled() is True + + with hindsight_memory(bank_id="temp", hindsight_api_url="http://localhost:8888"): + assert is_enabled() is True + + assert is_enabled() is True # still enabled after context From bcf5205e74684c380acbf54e301d5352c6142e90 Mon Sep 17 00:00:00 2001 From: DK09876 Date: Fri, 22 May 2026 08:39:03 -0700 Subject: [PATCH 4/4] fix(litellm): run ruff format and update test_config.py for no-default-bank-id behaviour - Run ruff format on __init__.py and wrappers.py to match CI lint expectations - test_config.py: update test_configure_with_no_arguments to assert bank_id is None (not DEFAULT_BANK_ID) and rename test_is_configured_true_with_defaults to test_is_configured_false_without_explicit_bank_id with corrected assertion, matching the removed DEFAULT_BANK_ID footgun Co-Authored-By: Claude Sonnet 4.6 --- .../litellm/hindsight_litellm/__init__.py | 24 +++++++++++-------- .../litellm/hindsight_litellm/wrappers.py | 4 +++- .../litellm/tests/test_config.py | 10 ++++---- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/hindsight-integrations/litellm/hindsight_litellm/__init__.py b/hindsight-integrations/litellm/hindsight_litellm/__init__.py index 298f302e1..7f1a0aee5 100644 --- a/hindsight-integrations/litellm/hindsight_litellm/__init__.py +++ b/hindsight-integrations/litellm/hindsight_litellm/__init__.py @@ -591,7 +591,9 @@ def _wrapped_completion(*args, **kwargs): final_messages = kwargs.get("messages", messages) if final_messages: if _is_streaming_response(response): - return _LiteLLMStreamWrapper(response, final_messages, model or "unknown", bank_id_override=bank_id_override) + return _LiteLLMStreamWrapper( + response, final_messages, model or "unknown", bank_id_override=bank_id_override + ) _store_conversation(final_messages, response, model or "unknown", bank_id_override=bank_id_override) return response @@ -642,7 +644,9 @@ async def _wrapped_acompletion(*args, **kwargs): final_messages = kwargs.get("messages", messages) if final_messages: if _is_streaming_response(response): - return _LiteLLMAsyncStreamWrapper(response, final_messages, model or "unknown", bank_id_override=bank_id_override) + return _LiteLLMAsyncStreamWrapper( + response, final_messages, model or "unknown", bank_id_override=bank_id_override + ) _store_conversation(final_messages, response, model or "unknown", bank_id_override=bank_id_override) return response @@ -1132,9 +1136,7 @@ def _store_conversation_from_text(conversation_text: str, model: str, bank_id_ov content_to_store = conversation_text effective_doc_id = defaults.effective_document_id if defaults else None if effective_doc_id: - existing_content = _get_existing_document_content( - effective_bank_id, effective_doc_id, config.verbose - ) + existing_content = _get_existing_document_content(effective_bank_id, effective_doc_id, config.verbose) if existing_content: content_to_store = f"{existing_content}\n\n{conversation_text}" @@ -1203,9 +1205,7 @@ def _store_conversation( try: content_to_store = conversation_text if effective_doc_id: - existing_content = _get_existing_document_content( - effective_bank_id, effective_doc_id, config.verbose - ) + existing_content = _get_existing_document_content(effective_bank_id, effective_doc_id, config.verbose) if existing_content: content_to_store = f"{existing_content}\n\n{conversation_text}" if config.verbose: @@ -1326,7 +1326,9 @@ def completion(*args, **kwargs): final_messages = kwargs.get("messages", messages) if final_messages: if _is_streaming_response(response): - return _LiteLLMStreamWrapper(response, final_messages, model or "unknown", bank_id_override=bank_id_override) + return _LiteLLMStreamWrapper( + response, final_messages, model or "unknown", bank_id_override=bank_id_override + ) _store_conversation(final_messages, response, model or "unknown", bank_id_override=bank_id_override) return response @@ -1406,7 +1408,9 @@ async def acompletion(*args, **kwargs): final_messages = kwargs.get("messages", messages) if final_messages: if _is_streaming_response(response): - return _LiteLLMAsyncStreamWrapper(response, final_messages, model or "unknown", bank_id_override=bank_id_override) + return _LiteLLMAsyncStreamWrapper( + response, final_messages, model or "unknown", bank_id_override=bank_id_override + ) _store_conversation(final_messages, response, model or "unknown", bank_id_override=bank_id_override) return response diff --git a/hindsight-integrations/litellm/hindsight_litellm/wrappers.py b/hindsight-integrations/litellm/hindsight_litellm/wrappers.py index 9dd2b7343..2926419ef 100644 --- a/hindsight-integrations/litellm/hindsight_litellm/wrappers.py +++ b/hindsight-integrations/litellm/hindsight_litellm/wrappers.py @@ -156,7 +156,9 @@ def recall( target_fact_types = fact_types or (defaults.fact_types if defaults else None) target_budget = budget or (defaults.budget if defaults else "mid") target_max_tokens = max_tokens or (defaults.max_memory_tokens if defaults else 4096) - target_include_entities = include_entities if include_entities is not None else (defaults.include_entities if defaults else True) + target_include_entities = ( + include_entities if include_entities is not None else (defaults.include_entities if defaults else True) + ) target_trace = trace if trace is not None else (defaults.trace if defaults else False) target_recall_tags = recall_tags or (defaults.recall_tags if defaults else None) target_recall_tags_match = recall_tags_match or (defaults.recall_tags_match if defaults else "any") diff --git a/hindsight-integrations/litellm/tests/test_config.py b/hindsight-integrations/litellm/tests/test_config.py index eae92b714..5e03d71fe 100644 --- a/hindsight-integrations/litellm/tests/test_config.py +++ b/hindsight-integrations/litellm/tests/test_config.py @@ -49,11 +49,11 @@ def teardown_method(self): reset_config() def test_configure_with_no_arguments(self): - """Test configure() with no arguments uses defaults.""" + """Test configure() with no arguments uses production URL and leaves bank_id unset.""" config = configure() assert config.hindsight_api_url == DEFAULT_HINDSIGHT_API_URL - assert config.bank_id == DEFAULT_BANK_ID + assert config.bank_id is None def test_configure_reads_api_key_from_env(self): """Test configure() reads API key from environment variable.""" @@ -81,10 +81,10 @@ def test_configure_explicit_values_override_defaults(self): assert config.bank_id == "custom-bank" assert config.api_key == "custom-key" - def test_is_configured_true_with_defaults(self): - """Test is_configured() returns True with default config.""" + def test_is_configured_false_without_explicit_bank_id(self): + """Test is_configured() returns False when configure() is called without bank_id.""" configure() - assert is_configured() is True + assert is_configured() is False def test_is_configured_false_when_not_configured(self): """Test is_configured() returns False when not configured."""