diff --git a/py/src/braintrust/logger.py b/py/src/braintrust/logger.py index 31fcfaee2..1904d74b0 100644 --- a/py/src/braintrust/logger.py +++ b/py/src/braintrust/logger.py @@ -1678,13 +1678,14 @@ def compute_metadata(): eprint(f"Failed to load prompt, attempting to fall back to cache: {server_error}") try: if id: - return _state._prompt_cache.get(id=id) + return _state._prompt_cache.get(id=id, org_id=_state.org_id) else: return _state._prompt_cache.get( slug, version=str(version) if version else "latest", project_id=project_id, project_name=project, + org_id=_state.org_id, ) except Exception as cache_error: if id: @@ -1714,6 +1715,7 @@ def compute_metadata(): _state._prompt_cache.set( prompt, id=id, + org_id=_state.org_id, ) elif slug: _state._prompt_cache.set( @@ -1722,6 +1724,7 @@ def compute_metadata(): version=str(version) if version else "latest", project_id=project_id, project_name=project, + org_id=_state.org_id, ) except Exception as e: eprint(f"Failed to store prompt in cache: {e}") diff --git a/py/src/braintrust/prompt_cache/prompt_cache.py b/py/src/braintrust/prompt_cache/prompt_cache.py index 273138497..ef33bc046 100644 --- a/py/src/braintrust/prompt_cache/prompt_cache.py +++ b/py/src/braintrust/prompt_cache/prompt_cache.py @@ -6,7 +6,7 @@ 2. A persistent disk-based cache that serves as a backing store This allows for efficient prompt retrieval while maintaining persistence across sessions. -The cache is keyed by project identifier (ID or name), prompt slug, and version. +The cache is keyed by organization ID, project identifier (ID or name), prompt slug, and version. """ @@ -15,15 +15,21 @@ def _create_cache_key( + org_id: str | None, project_id: str | None, project_name: str | None, slug: str | None, version: str = "latest", id: str | None = None, ) -> str: - """Creates a unique cache key from project identifier, slug and version, or from ID.""" + """Creates a unique cache key from org ID, project identifier, slug and version, or from prompt ID. + + The org_id is included to ensure cache isolation between organizations. Without it, + two organizations with the same project name and prompt slug could get each other's + cached prompts, leading to incorrect prompt retrieval. + """ if id: - # When caching by ID, we don't need project or slug + # When caching by ID, we don't need project or slug (IDs are globally unique) return f"id:{id}" prefix = project_id or project_name @@ -31,6 +37,10 @@ def _create_cache_key( raise ValueError("Either project_id or project_name must be provided") if not slug: raise ValueError("Slug must be provided when not using ID") + + # Include org_id in cache key if available to ensure cross-org isolation + if org_id: + return f"{org_id}:{prefix}:{slug}:{version}" return f"{prefix}:{slug}:{version}" @@ -65,6 +75,7 @@ def get( project_id: str | None = None, project_name: str | None = None, id: str | None = None, + org_id: str | None = None, ) -> prompt.PromptSchema: """ Retrieve a prompt from the cache. @@ -75,6 +86,7 @@ def get( project_id: The ID of the project containing the prompt. project_name: The name of the project containing the prompt. id: The ID of a specific prompt. If provided, slug and project parameters are ignored. + org_id: The ID of the organization. Used to ensure cache isolation between orgs. Returns: The cached Prompt object. @@ -83,7 +95,7 @@ def get( ValueError: If neither project_id nor project_name is provided (when not using id). KeyError: If the prompt is not found in the cache. """ - cache_key = _create_cache_key(project_id, project_name, slug, version, id) + cache_key = _create_cache_key(org_id, project_id, project_name, slug, version, id) # First check memory cache. try: @@ -111,6 +123,7 @@ def set( project_id: str | None = None, project_name: str | None = None, id: str | None = None, + org_id: str | None = None, ) -> None: """ Store a prompt in the cache. @@ -122,12 +135,13 @@ def set( project_id: The ID of the project containing the prompt. project_name: The name of the project containing the prompt. id: The ID of a specific prompt. If provided, slug and project parameters are ignored. + org_id: The ID of the organization. Used to ensure cache isolation between orgs. Raises: ValueError: If neither project_id nor project_name is provided (when not using id). RuntimeError: If there is an error writing to the disk cache. """ - cache_key = _create_cache_key(project_id, project_name, slug, version, id) + cache_key = _create_cache_key(org_id, project_id, project_name, slug, version, id) # Update memory cache. self.memory_cache.set(cache_key, value) diff --git a/py/src/braintrust/prompt_cache/test_prompt_cache.py b/py/src/braintrust/prompt_cache/test_prompt_cache.py index 0e0d70c8f..1468e8b3f 100644 --- a/py/src/braintrust/prompt_cache/test_prompt_cache.py +++ b/py/src/braintrust/prompt_cache/test_prompt_cache.py @@ -188,6 +188,112 @@ def test_id_cache_with_disk_persistence(self): result = self.cache.get(id=prompt_id) self.assertEqual(result.as_dict(), self.test_prompt.as_dict()) + def test_handle_different_orgs_with_same_project_and_slug(self): + """Test that prompts from different orgs with same project/slug are isolated. + + This test verifies the fix for the cross-org cache collision bug where + two organizations with the same project name and prompt slug could get + each other's cached prompts. + """ + org1_prompt = prompt.PromptSchema( + id="org1-prompt-id", + project_id="shared-project-id", + _xact_id="111", + name="shared-prompt", + slug="shared-prompt", + description="This is Org 1's prompt", + prompt_data=prompt.PromptData(), + tags=None, + ) + + org2_prompt = prompt.PromptSchema( + id="org2-prompt-id", + project_id="shared-project-id", + _xact_id="222", + name="shared-prompt", + slug="shared-prompt", + description="This is Org 2's prompt", + prompt_data=prompt.PromptData(), + tags=None, + ) + + # Store prompts from different orgs with same project_name and slug + self.cache.set( + org1_prompt, + slug="shared-prompt", + version="latest", + project_name="MyProject", + org_id="org-111", + ) + self.cache.set( + org2_prompt, + slug="shared-prompt", + version="latest", + project_name="MyProject", + org_id="org-222", + ) + + # Retrieve each org's prompt - should get the correct one + result1 = self.cache.get( + slug="shared-prompt", + version="latest", + project_name="MyProject", + org_id="org-111", + ) + result2 = self.cache.get( + slug="shared-prompt", + version="latest", + project_name="MyProject", + org_id="org-222", + ) + + # Verify org isolation - each org gets their own prompt + self.assertEqual(result1.description, "This is Org 1's prompt") + self.assertEqual(result2.description, "This is Org 2's prompt") + self.assertEqual(result1.id, "org1-prompt-id") + self.assertEqual(result2.id, "org2-prompt-id") + + def test_org_id_isolation_with_disk_cache(self): + """Test that org_id isolation works after memory eviction (via disk cache).""" + org1_prompt = prompt.PromptSchema( + id="disk-org1-id", + project_id="project", + _xact_id="111", + name="prompt", + slug="prompt", + description="Org 1 disk prompt", + prompt_data=prompt.PromptData(), + tags=None, + ) + + org2_prompt = prompt.PromptSchema( + id="disk-org2-id", + project_id="project", + _xact_id="222", + name="prompt", + slug="prompt", + description="Org 2 disk prompt", + prompt_data=prompt.PromptData(), + tags=None, + ) + + # Store org1's prompt + self.cache.set(org1_prompt, slug="prompt", version="v1", project_name="proj", org_id="org1") + + # Fill memory cache to evict org1's prompt (memory cache max_size=2) + self.cache.set(self.test_prompt, slug="filler1", version="v1", project_id="123") + self.cache.set(self.test_prompt, slug="filler2", version="v1", project_id="123") + + # Store org2's prompt (should not overwrite org1's cached prompt) + self.cache.set(org2_prompt, slug="prompt", version="v1", project_name="proj", org_id="org2") + + # Both should be retrievable with correct isolation + result1 = self.cache.get(slug="prompt", version="v1", project_name="proj", org_id="org1") + result2 = self.cache.get(slug="prompt", version="v1", project_name="proj", org_id="org2") + + self.assertEqual(result1.description, "Org 1 disk prompt") + self.assertEqual(result2.description, "Org 2 disk prompt") + if __name__ == "__main__": unittest.main()