diff --git a/client/webui/frontend/src/lib/components/chat/artifact/ArtifactDeleteAllDialog.tsx b/client/webui/frontend/src/lib/components/chat/artifact/ArtifactDeleteAllDialog.tsx index b876c4a892..c7cae5e98a 100644 --- a/client/webui/frontend/src/lib/components/chat/artifact/ArtifactDeleteAllDialog.tsx +++ b/client/webui/frontend/src/lib/components/chat/artifact/ArtifactDeleteAllDialog.tsx @@ -18,17 +18,20 @@ export const ArtifactDeleteAllDialog: React.FC = () => { return null; } - const hasProjectArtifacts = artifacts.some(artifact => artifact.source === "project"); - const projectArtifactsCount = artifacts.filter(artifact => artifact.source === "project").length; - const regularArtifactsCount = artifacts.length - projectArtifactsCount; + // Check for read-only artifacts (project or agent_default) + const isReadOnlyArtifact = (artifact: { source?: string }) => artifact.source === "project" || artifact.source === "agent_default"; + + const hasReadOnlyArtifacts = artifacts.some(isReadOnlyArtifact); + const readOnlyArtifactsCount = artifacts.filter(isReadOnlyArtifact).length; + const regularArtifactsCount = artifacts.length - readOnlyArtifactsCount; const getDescription = () => { - if (hasProjectArtifacts && regularArtifactsCount === 0) { - // All are project artifacts - return `${artifacts.length === 1 ? "This file" : `All ${artifacts.length} files`} will be removed from this chat session. ${artifacts.length === 1 ? "The file" : "These files"} will remain in ${artifacts.length === 1 ? "the" : "their"} project${artifacts.length === 1 ? "" : "s"}.`; - } else if (hasProjectArtifacts && regularArtifactsCount > 0) { - // Mixed: some project, some regular - return `${regularArtifactsCount} ${regularArtifactsCount === 1 ? "file" : "files"} will be permanently deleted. ${projectArtifactsCount} project ${projectArtifactsCount === 1 ? "file" : "files"} will be removed from this chat but will remain in ${projectArtifactsCount === 1 ? "the" : "their"} project${projectArtifactsCount === 1 ? "" : "s"}.`; + if (hasReadOnlyArtifacts && regularArtifactsCount === 0) { + // All are read-only artifacts (project or agent_default) + return `${artifacts.length === 1 ? "This file" : `All ${artifacts.length} files`} will be removed from this chat session. ${artifacts.length === 1 ? "The file" : "These files"} will remain available as ${artifacts.length === 1 ? "a default" : "defaults"}.`; + } else if (hasReadOnlyArtifacts && regularArtifactsCount > 0) { + // Mixed: some read-only, some regular + return `${regularArtifactsCount} ${regularArtifactsCount === 1 ? "file" : "files"} will be permanently deleted. ${readOnlyArtifactsCount} read-only ${readOnlyArtifactsCount === 1 ? "file" : "files"} will be removed from this chat but will remain available.`; } else { // All are regular artifacts return `${artifacts.length === 1 ? "One file" : `All ${artifacts.length} files`} will be permanently deleted.`; diff --git a/client/webui/frontend/src/lib/components/chat/artifact/ArtifactPanel.tsx b/client/webui/frontend/src/lib/components/chat/artifact/ArtifactPanel.tsx index 7ddb5e0349..14666f8cab 100644 --- a/client/webui/frontend/src/lib/components/chat/artifact/ArtifactPanel.tsx +++ b/client/webui/frontend/src/lib/components/chat/artifact/ArtifactPanel.tsx @@ -34,9 +34,9 @@ export const ArtifactPanel: React.FC = () => { return artifacts ? [...artifacts].sort(sortFunctions[sortOption]) : []; }, [artifacts, artifactsLoading, sortOption]); - // Check if there are any deletable artifacts (not from projects) + // Check if there are any deletable artifacts (not from projects or agent defaults) const hasDeletableArtifacts = useMemo(() => { - return sortedArtifacts.some(artifact => artifact.source !== "project"); + return sortedArtifacts.some(artifact => artifact.source !== "project" && artifact.source !== "agent_default"); }, [sortedArtifacts]); const header = useMemo(() => { @@ -106,7 +106,7 @@ export const ArtifactPanel: React.FC = () => { isPreview={true} isExpanded={isPreviewInfoExpanded} setIsExpanded={setIsPreviewInfoExpanded} - onDelete={previewArtifact.source === "project" ? undefined : () => openDeleteModal(previewArtifact)} + onDelete={previewArtifact.source === "project" || previewArtifact.source === "agent_default" ? undefined : () => openDeleteModal(previewArtifact)} onDownload={() => onDownload(previewArtifact)} /> @@ -126,7 +126,7 @@ export const ArtifactPanel: React.FC = () => {
Type: -
{previewArtifact.mime_type || 'Unknown'}
+
{previewArtifact.mime_type || "Unknown"}
diff --git a/client/webui/frontend/src/lib/components/chat/file/ArtifactMessage.tsx b/client/webui/frontend/src/lib/components/chat/file/ArtifactMessage.tsx index 88cc43806f..24973392e0 100644 --- a/client/webui/frontend/src/lib/components/chat/file/ArtifactMessage.tsx +++ b/client/webui/frontend/src/lib/components/chat/file/ArtifactMessage.tsx @@ -52,8 +52,10 @@ export const ArtifactMessage: React.FC = props => { const context = props.context || "chat"; const isStreaming = props.isStreaming; - // Check if this artifact is from a project (should not be deletable) + // Check if this artifact is from a project or agent default (should not be deletable) const isProjectArtifact = artifact?.source === "project"; + const isAgentDefaultArtifact = artifact?.source === "agent_default"; + const isReadOnlyArtifact = isProjectArtifact || isAgentDefaultArtifact; // Extract version from URI if available const version = useMemo(() => { @@ -326,8 +328,8 @@ export const ArtifactMessage: React.FC = props => { return { onInfo: handleInfoClick, onDownload: props.status === "completed" ? handleDownloadClick : undefined, - // Hide delete button for artifacts with source="project" (they came from project files) - onDelete: artifact && props.status === "completed" && !isProjectArtifact ? handleDeleteClick : undefined, + // Hide delete button for artifacts with source="project" or "agent_default" (read-only artifacts) + onDelete: artifact && props.status === "completed" && !isReadOnlyArtifact ? handleDeleteClick : undefined, }; } else { // In chat context, show preview, download, and info actions @@ -338,7 +340,7 @@ export const ArtifactMessage: React.FC = props => { onInfo: handleInfoClick, }; } - }, [props.status, context, handleDownloadClick, artifact, handleDeleteClick, handleInfoClick, handlePreviewClick, isProjectArtifact]); + }, [props.status, context, handleDownloadClick, artifact, handleDeleteClick, handleInfoClick, handlePreviewClick, isReadOnlyArtifact]); // Get description from global artifacts instead of message parts const artifactFromGlobal = useMemo(() => artifacts.find(art => art.filename === props.name), [artifacts, props.name]); diff --git a/client/webui/frontend/src/lib/hooks/useArtifacts.ts b/client/webui/frontend/src/lib/hooks/useArtifacts.ts index 8f2744ffd0..89d3a069ba 100644 --- a/client/webui/frontend/src/lib/hooks/useArtifacts.ts +++ b/client/webui/frontend/src/lib/hooks/useArtifacts.ts @@ -24,7 +24,7 @@ const isIntermediateWebContentArtifact = (filename: string | undefined): boolean return filename.startsWith("web_content_"); }; -export const useArtifacts = (sessionId?: string): UseArtifactsReturn => { +export const useArtifacts = (sessionId?: string, agentName?: string): UseArtifactsReturn => { const { activeProject } = useProjectContext(); const [artifacts, setArtifacts] = useState([]); const [isLoading, setIsLoading] = useState(true); @@ -36,17 +36,30 @@ export const useArtifacts = (sessionId?: string): UseArtifactsReturn => { try { let endpoint: string; + const params = new URLSearchParams(); if (sessionId && sessionId.trim() && sessionId !== "null" && sessionId !== "undefined") { endpoint = `/api/v1/artifacts/${sessionId}`; } else if (activeProject?.id) { - endpoint = `/api/v1/artifacts/null?project_id=${activeProject.id}`; + endpoint = `/api/v1/artifacts/null`; + params.append("project_id", activeProject.id); } else { setArtifacts([]); setIsLoading(false); return; } + // Add agent_name parameter to include agent's default artifacts + if (agentName) { + params.append("agent_name", agentName); + } + + // Append query parameters if any + const queryString = params.toString(); + if (queryString) { + endpoint += `?${queryString}`; + } + const data: ArtifactInfo[] = await api.webui.get(endpoint); // Filter out intermediate web content artifacts from deep research const filteredData = data.filter(artifact => !isIntermediateWebContentArtifact(artifact.filename)); @@ -62,7 +75,7 @@ export const useArtifacts = (sessionId?: string): UseArtifactsReturn => { } finally { setIsLoading(false); } - }, [sessionId, activeProject?.id]); + }, [sessionId, activeProject?.id, agentName]); useEffect(() => { fetchArtifacts(); diff --git a/client/webui/frontend/src/lib/providers/ChatProvider.tsx b/client/webui/frontend/src/lib/providers/ChatProvider.tsx index 94c61a4d10..8f7eef8d0a 100644 --- a/client/webui/frontend/src/lib/providers/ChatProvider.tsx +++ b/client/webui/frontend/src/lib/providers/ChatProvider.tsx @@ -113,7 +113,8 @@ export const ChatProvider: React.FC = ({ children }) => { const { agents, agentNameMap: agentNameDisplayNameMap, error: agentsError, isLoading: agentsLoading, refetch: agentsRefetch } = useAgentCards(); // Chat Side Panel State - const { artifacts, isLoading: artifactsLoading, refetch: artifactsRefetch, setArtifacts } = useArtifacts(sessionId); + // Pass selectedAgentName to include agent's default artifacts when agent changes + const { artifacts, isLoading: artifactsLoading, refetch: artifactsRefetch, setArtifacts } = useArtifacts(sessionId, selectedAgentName); // Title Generation const { generateTitle } = useTitleGeneration(); diff --git a/src/solace_agent_mesh/agent/adk/services.py b/src/solace_agent_mesh/agent/adk/services.py index 6872a703b5..e0084e043e 100644 --- a/src/solace_agent_mesh/agent/adk/services.py +++ b/src/solace_agent_mesh/agent/adk/services.py @@ -46,12 +46,22 @@ TestInMemoryArtifactService = None +# Constants for agent-level default artifacts +AGENT_DEFAULTS_USER_ID = "__agent_defaults__" + + class ScopedArtifactServiceWrapper(BaseArtifactService): """ A wrapper for an artifact service that transparently applies a configured scope. This ensures all artifact operations respect either 'namespace' or 'app' scoping without requiring changes at the call site. It dynamically checks the component's configuration on each call to support test-specific overrides. + + Additionally, this wrapper supports agent-level default artifacts that are: + - Loaded at agent startup from the `default_artifacts` configuration + - Stored in a special scope using AGENT_DEFAULTS_USER_ID + - Accessible to all users with access to the agent (read-only) + - Used as a fallback when an artifact is not found in the user's session """ def __init__( @@ -68,6 +78,8 @@ def __init__( """ self.wrapped_service = wrapped_service self.component = component + # Cache of default artifact filenames for quick lookup + self._default_artifact_filenames: Optional[set] = None def _get_scoped_app_name(self, app_name: str) -> str: """ @@ -86,6 +98,28 @@ def _get_scoped_app_name(self, app_name: str) -> str: # typically the agent_name or gateway_id. return app_name + def _has_default_artifacts(self) -> bool: + """Check if the agent has default artifacts configured.""" + default_artifacts = self.component.get_config("default_artifacts", []) + return bool(default_artifacts) + + def _get_default_artifact_filenames(self) -> set: + """Get the set of default artifact filenames for quick lookup.""" + if self._default_artifact_filenames is None: + default_artifacts = self.component.get_config("default_artifacts", []) + self._default_artifact_filenames = { + artifact.get("filename") for artifact in default_artifacts if artifact.get("filename") + } + return self._default_artifact_filenames + + def _is_default_artifact(self, filename: str) -> bool: + """Check if a filename corresponds to a default artifact.""" + return filename in self._get_default_artifact_filenames() + + def _get_agent_name(self) -> str: + """Get the agent name from the component.""" + return self.component.get_config("agent_name", self.component.agent_name) + @override async def save_artifact( self, @@ -97,6 +131,28 @@ async def save_artifact( artifact: adk_types.Part, ) -> int: scoped_app_name = self._get_scoped_app_name(app_name) + + # Allow saving to agent defaults scope (used during agent initialization) + if user_id == AGENT_DEFAULTS_USER_ID: + return await self.wrapped_service.save_artifact( + app_name=scoped_app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=artifact, + ) + + # For regular users, check if they're trying to overwrite a default artifact + # Users can create their own version of a default artifact in their session + # (this allows overrides), but we log a warning for visibility + if self._is_default_artifact(filename): + log.debug( + "User '%s' is saving artifact '%s' which shadows a default artifact. " + "The user's version will take precedence in their session.", + user_id, + filename, + ) + return await self.wrapped_service.save_artifact( app_name=scoped_app_name, user_id=user_id, @@ -116,28 +172,104 @@ async def load_artifact( version: Optional[int] = None, ) -> Optional[adk_types.Part]: scoped_app_name = self._get_scoped_app_name(app_name) - return await self.wrapped_service.load_artifact( + + # First, try to load from the user's session + result = await self.wrapped_service.load_artifact( app_name=scoped_app_name, user_id=user_id, session_id=session_id, filename=filename, version=version, ) + + if result is not None: + return result + + # If not found and we have default artifacts configured, try the agent defaults + if self._has_default_artifacts() and user_id != AGENT_DEFAULTS_USER_ID: + agent_name = self._get_agent_name() + log.debug( + "Artifact '%s' not found in user session, checking agent defaults for agent '%s'.", + filename, + agent_name, + ) + result = await self.wrapped_service.load_artifact( + app_name=scoped_app_name, + user_id=AGENT_DEFAULTS_USER_ID, + session_id=agent_name, + filename=filename, + version=version, + ) + if result is not None: + log.debug( + "Loaded artifact '%s' from agent defaults for agent '%s'.", + filename, + agent_name, + ) + + return result @override async def list_artifact_keys( self, *, app_name: str, user_id: str, session_id: str ) -> List[str]: scoped_app_name = self._get_scoped_app_name(app_name) - return await self.wrapped_service.list_artifact_keys( + + # Get user's session artifacts + user_artifacts = await self.wrapped_service.list_artifact_keys( app_name=scoped_app_name, user_id=user_id, session_id=session_id ) + + # If we have default artifacts and this is not the defaults user, include them + if self._has_default_artifacts() and user_id != AGENT_DEFAULTS_USER_ID: + agent_name = self._get_agent_name() + default_artifacts = await self.wrapped_service.list_artifact_keys( + app_name=scoped_app_name, + user_id=AGENT_DEFAULTS_USER_ID, + session_id=agent_name, + ) + + # Merge lists (user artifacts take precedence, so we use a set) + all_artifacts = list(set(user_artifacts + default_artifacts)) + return all_artifacts + + return user_artifacts @override async def delete_artifact( self, *, app_name: str, user_id: str, session_id: str, filename: str ) -> None: scoped_app_name = self._get_scoped_app_name(app_name) + + # Prevent users from deleting default artifacts + if user_id != AGENT_DEFAULTS_USER_ID and self._is_default_artifact(filename): + # Check if the user has their own version of this artifact + user_artifact = await self.wrapped_service.load_artifact( + app_name=scoped_app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + if user_artifact is None: + # User is trying to delete a default artifact they don't own + log.warning( + "User '%s' attempted to delete default artifact '%s'. " + "Default artifacts are read-only.", + user_id, + filename, + ) + raise PermissionError( + f"Cannot delete default artifact '{filename}'. " + "Default artifacts are read-only." + ) + # User has their own version, allow deletion of their version + log.debug( + "User '%s' is deleting their own version of artifact '%s' " + "(default artifact will still be accessible).", + user_id, + filename, + ) + await self.wrapped_service.delete_artifact( app_name=scoped_app_name, user_id=user_id, diff --git a/src/solace_agent_mesh/agent/sac/app.py b/src/solace_agent_mesh/agent/sac/app.py index f416384bff..f16e28e4a3 100644 --- a/src/solace_agent_mesh/agent/sac/app.py +++ b/src/solace_agent_mesh/agent/sac/app.py @@ -197,6 +197,36 @@ class McpProcessingConfig(SamConfigBase): ) +class DefaultArtifactConfig(SamConfigBase): + """Configuration for a default artifact to be pre-loaded when the agent starts. + + Default artifacts are loaded at agent startup and made available to all users + with access to the agent. They are read-only for users and stored in a special + agent-level scope that is accessible from all sessions. + """ + + path: str = Field( + ..., + description="Path to the artifact file. Supports local filesystem paths, " + "environment variable substitution (e.g., ${SAM_PROJECT_ROOT}/data/file.pdf), " + "and remote sources (s3://, https://).", + ) + filename: str = Field( + ..., + description="Filename to use when storing the artifact. This is the name " + "that will be used to reference the artifact in tools and prompts.", + ) + description: Optional[str] = Field( + default=None, + description="Human-readable description of the artifact's purpose and contents.", + ) + mime_type: Optional[str] = Field( + default=None, + description="MIME type of the artifact. If not specified, it will be " + "auto-detected from the file extension.", + ) + + class ArtifactServiceConfig(SamConfigBase): """Configuration for the ADK Artifact Service.""" @@ -457,6 +487,13 @@ class SamAgentAppConfig(SamConfigBase): default_factory=McpProcessingConfig, description="Configuration for intelligent processing of MCP tool responses.", ) + default_artifacts: List[DefaultArtifactConfig] = Field( + default_factory=list, + description="List of artifacts to pre-load when the agent starts. " + "These artifacts are available to all users with access to the agent " + "and are read-only. They are stored in a special agent-level scope " + "that is accessible from all sessions.", + ) class SamAgentApp(SamAppBase): diff --git a/src/solace_agent_mesh/agent/sac/component.py b/src/solace_agent_mesh/agent/sac/component.py index c01e0685c7..9cb1a1de45 100644 --- a/src/solace_agent_mesh/agent/sac/component.py +++ b/src/solace_agent_mesh/agent/sac/component.py @@ -3199,6 +3199,9 @@ async def _perform_async_init(self): ) self.runner = initialize_adk_runner(self) + # Load default artifacts if configured + await self._load_default_artifacts() + log.info("%s Populating agent card tool manifest...", self.log_identifier) tool_manifest = [] for tool in loaded_tools: @@ -3311,6 +3314,195 @@ async def _perform_async_init(self): ) raise e + async def _load_default_artifacts(self): + """ + Loads default artifacts configured for this agent into the artifact service. + + Default artifacts are stored with a special user_id marker (AGENT_DEFAULTS_USER_ID) + so they can be accessed by all users of this agent as a fallback when the user + doesn't have their own version of the artifact. + + This method is called during async initialization and reads files from the + configured paths, storing them in the artifact service. + """ + from .app import DefaultArtifactConfig + from ...agent.adk.services import AGENT_DEFAULTS_USER_ID + from ...agent.utils.artifact_helpers import save_artifact_with_metadata + from datetime import datetime, timezone + import mimetypes + import os + + log.info( + "%s Checking for default artifacts configuration...", + self.log_identifier, + ) + + default_artifacts_config = self.get_config("default_artifacts", []) + + log.info( + "%s default_artifacts_config = %s (type: %s)", + self.log_identifier, + default_artifacts_config, + type(default_artifacts_config).__name__, + ) + + if not default_artifacts_config: + log.info( + "%s No default artifacts configured for this agent.", + self.log_identifier, + ) + return + + if not self.artifact_service: + log.warning( + "%s Artifact service not available. Cannot load default artifacts.", + self.log_identifier, + ) + return + + log.info( + "%s Loading %d default artifact(s) for agent '%s'...", + self.log_identifier, + len(default_artifacts_config), + self.agent_name, + ) + + loaded_count = 0 + for artifact_config in default_artifacts_config: + try: + # Extract configuration + if isinstance(artifact_config, dict): + file_path = artifact_config.get("path") + filename = artifact_config.get("filename") + mime_type = artifact_config.get("mime_type") + description = artifact_config.get("description") + else: + # Pydantic model + file_path = artifact_config.path + filename = artifact_config.filename + mime_type = artifact_config.mime_type + description = artifact_config.description + + if not file_path: + log.warning( + "%s Default artifact config missing 'path'. Skipping.", + self.log_identifier, + ) + continue + + # Resolve the file path (support relative paths from config directory) + if not os.path.isabs(file_path): + # Try to resolve relative to the config file location + base_path = self.get_config("base_path", ".") + file_path = os.path.join(base_path, file_path) + + if not os.path.exists(file_path): + log.error( + "%s Default artifact file not found: %s", + self.log_identifier, + file_path, + ) + continue + + # Use the filename from config or derive from path + if not filename: + filename = os.path.basename(file_path) + + # Detect MIME type if not specified + if not mime_type: + mime_type, _ = mimetypes.guess_type(file_path) + if not mime_type: + mime_type = "application/octet-stream" + + # Read the file content + with open(file_path, "rb") as f: + file_content = f.read() + + # Check if this artifact already exists to avoid creating duplicate versions on restart + try: + existing_versions = await self.artifact_service.list_versions( + app_name=self.agent_name, + user_id=AGENT_DEFAULTS_USER_ID, + session_id=self.agent_name, + filename=filename, + ) + if existing_versions: + log.info( + "%s Default artifact '%s' already exists (v%s), skipping to avoid duplicate versions.", + self.log_identifier, + filename, + existing_versions[-1] if existing_versions else "?", + ) + loaded_count += 1 # Count as loaded since it exists + continue + except Exception as e: + # If list_versions fails (e.g., artifact doesn't exist), proceed with saving + log.debug( + "%s Could not check existing versions for '%s': %s. Proceeding with save.", + self.log_identifier, + filename, + e, + ) + + # Build metadata dictionary for the default artifact + metadata_dict = { + "source": "agent_default", # Mark as agent default + } + if description: + metadata_dict["description"] = description + + # Save the artifact WITH metadata using save_artifact_with_metadata + # This ensures the .metadata.json file is also created, which is required + # for the artifact to be properly listed and loaded later + save_result = await save_artifact_with_metadata( + artifact_service=self.artifact_service, + app_name=self.agent_name, + user_id=AGENT_DEFAULTS_USER_ID, + session_id=self.agent_name, # Use agent name as session for defaults + filename=filename, + content_bytes=file_content, + mime_type=mime_type, + metadata_dict=metadata_dict, + timestamp=datetime.now(timezone.utc), + suppress_visualization_signal=True, # Don't send signals during init + ) + + if save_result.get("status") in ["success", "partial_success"]: + version = save_result.get("data_version") + loaded_count += 1 + log.info( + "%s Loaded default artifact '%s' (v%s) from '%s' [%s]%s", + self.log_identifier, + filename, + version, + file_path, + mime_type, + f" - {description}" if description else "", + ) + else: + log.error( + "%s Failed to save default artifact '%s': %s", + self.log_identifier, + filename, + save_result.get("message", "Unknown error"), + ) + + except Exception as e: + log.exception( + "%s Error loading default artifact from config %s: %s", + self.log_identifier, + artifact_config, + e, + ) + + log.info( + "%s Successfully loaded %d/%d default artifact(s) for agent '%s'.", + self.log_identifier, + loaded_count, + len(default_artifacts_config), + self.agent_name, + ) + def cleanup(self): """Clean up resources on component shutdown.""" log.info("%s Cleaning up A2A ADK Host Component.", self.log_identifier) diff --git a/src/solace_agent_mesh/agent/tools/builtin_artifact_tools.py b/src/solace_agent_mesh/agent/tools/builtin_artifact_tools.py index 72691c7e98..26ef5f8f23 100644 --- a/src/solace_agent_mesh/agent/tools/builtin_artifact_tools.py +++ b/src/solace_agent_mesh/agent/tools/builtin_artifact_tools.py @@ -26,6 +26,10 @@ is_filename_safe, METADATA_SUFFIX, DEFAULT_SCHEMA_MAX_KEYS, + AGENT_DEFAULTS_USER_ID, + _get_agent_default_artifacts, + get_latest_artifact_version, + load_artifact_with_fallback_to_defaults, ) from ...common.utils.embeds import ( evaluate_embed, @@ -196,6 +200,7 @@ async def list_artifacts(tool_context: ToolContext = None) -> Dict[str, Any]: """ Lists all available data artifact filenames and their versions for the current session. Includes a summary of the latest version's metadata for each artifact. + Also includes agent-level default artifacts if configured. Args: tool_context: The context provided by the ADK framework. @@ -214,6 +219,12 @@ async def list_artifacts(tool_context: ToolContext = None) -> Dict[str, Any]: app_name = tool_context._invocation_context.app_name user_id = tool_context._invocation_context.user_id session_id = get_original_session_id(tool_context._invocation_context) + + # Get agent name for default artifacts lookup + agent = getattr(tool_context._invocation_context, "agent", None) + host_component = getattr(agent, "host_component", None) if agent else None + agent_name = host_component.agent_name if host_component else None + list_keys_method = getattr(artifact_service, "list_artifact_keys") all_keys = await list_keys_method( app_name=app_name, user_id=user_id, session_id=session_id @@ -346,8 +357,78 @@ async def list_artifacts(tool_context: ToolContext = None) -> Dict[str, Any]: } ) processed_data_files.add(filename) + + # Include agent-level default artifacts if agent_name is available + if agent_name: + log.debug( + "%s Fetching agent default artifacts for agent '%s'.", + log_identifier, + agent_name, + ) + try: + default_artifacts = await _get_agent_default_artifacts( + artifact_service=artifact_service, + app_name=app_name, + agent_name=agent_name, + exclude_filenames=processed_data_files, + log_prefix=log_identifier, + ) + + # Convert ArtifactInfo objects to the response format + for artifact_info in default_artifacts: + # Get versions for the default artifact + default_versions = [] + try: + default_versions_list = await artifact_service.list_versions( + app_name=agent_name, # Default artifacts use agent_name as app_name + user_id=AGENT_DEFAULTS_USER_ID, + session_id=agent_name, + filename=artifact_info.filename, + ) + default_versions = list(default_versions_list) if default_versions_list else [] + except Exception as ver_err: + log.warning( + "%s Failed to list versions for default artifact '%s': %s", + log_identifier, + artifact_info.filename, + ver_err, + ) + default_versions = [artifact_info.version] if artifact_info.version is not None else [0] + + metadata_summary = { + "description": artifact_info.description, + "source": "agent_default", # Mark as agent default + "type": artifact_info.mime_type, + "size": artifact_info.size, + } + # Remove None values + metadata_summary = {k: v for k, v in metadata_summary.items() if v is not None} + + response_files.append( + { + "filename": artifact_info.filename, + "versions": default_versions, + "metadata_summary": metadata_summary, + "is_agent_default": True, # Flag to indicate this is a default artifact + } + ) + processed_data_files.add(artifact_info.filename) + + log.info( + "%s Added %d agent default artifacts for agent '%s'.", + log_identifier, + len(default_artifacts), + agent_name, + ) + except Exception as default_err: + log.warning( + "%s Failed to fetch agent default artifacts: %s", + log_identifier, + default_err, + ) + log.info( - "%s Found %d data artifacts for session %s.", + "%s Found %d total artifacts for session %s (including defaults).", log_identifier, len(response_files), session_id, @@ -369,6 +450,7 @@ async def load_artifact( """ Loads the content or metadata of a specific artifact version. Early-stage embeds in the filename argument are resolved. + Falls back to agent-level default artifacts if not found in user session. If load_metadata_only is True, loads the full metadata dictionary. Otherwise, loads text content (potentially truncated) or binary metadata summary. @@ -411,6 +493,9 @@ async def load_artifact( session_id = get_original_session_id(tool_context._invocation_context) agent = getattr(tool_context._invocation_context, "agent", None) host_component = getattr(agent, "host_component", None) if agent else None + agent_name = host_component.agent_name if host_component else None + + # First, try to load from user's session result = await load_artifact_content_or_metadata( artifact_service=artifact_service, app_name=app_name, @@ -424,6 +509,37 @@ async def load_artifact( component=host_component, log_identifier_prefix="[BuiltinArtifactTool:load_artifact]", ) + + # If not found and we have an agent name, try agent defaults + if result.get("status") == "not_found" and agent_name: + log.debug( + "%s Artifact not found in user session, trying agent defaults for agent '%s'.", + log_identifier, + agent_name, + ) + result = await load_artifact_content_or_metadata( + artifact_service=artifact_service, + app_name=agent_name, # Default artifacts use agent_name as app_name + user_id=AGENT_DEFAULTS_USER_ID, + session_id=agent_name, + filename=filename, + version=version, + load_metadata_only=load_metadata_only, + max_content_length=max_content_length, + include_line_numbers=include_line_numbers, + component=host_component, + log_identifier_prefix="[BuiltinArtifactTool:load_artifact:agent_default]", + ) + if result.get("status") == "success": + # Mark as loaded from agent defaults + result["source"] = "agent_default" + log.info( + "%s Loaded artifact '%s' from agent defaults for agent '%s'.", + log_identifier, + filename, + agent_name, + ) + return result except FileNotFoundError as fnf_err: log.warning( diff --git a/src/solace_agent_mesh/agent/utils/artifact_helpers.py b/src/solace_agent_mesh/agent/utils/artifact_helpers.py index 308d575ff2..fae8440f91 100644 --- a/src/solace_agent_mesh/agent/utils/artifact_helpers.py +++ b/src/solace_agent_mesh/agent/utils/artifact_helpers.py @@ -33,6 +33,9 @@ log = logging.getLogger(__name__) METADATA_SUFFIX = ".metadata.json" + +# Constant for agent-level default artifacts (must match services.py) +AGENT_DEFAULTS_USER_ID = "__agent_defaults__" DEFAULT_SCHEMA_MAX_KEYS = 20 DEFAULT_SCHEMA_INFERENCE_DEPTH = 4 @@ -1018,6 +1021,8 @@ async def get_artifact_info_list( app_name: str, user_id: str, session_id: str, + include_agent_defaults: bool = False, + agent_name: Optional[str] = None, ) -> List[ArtifactInfo]: """ Retrieves detailed information for all artifacts using the artifact service. @@ -1027,12 +1032,15 @@ async def get_artifact_info_list( app_name: The application name. user_id: The user ID. session_id: The session ID. + include_agent_defaults: If True, also include agent-level default artifacts. + agent_name: The agent name (required if include_agent_defaults is True). Returns: A list of ArtifactInfo objects. """ log_prefix = f"[ArtifactHelper:get_info_list] App={app_name}, User={user_id}, Session={session_id} -" artifact_info_list: List[ArtifactInfo] = [] + user_artifact_filenames: set = set() try: list_keys_method = getattr(artifact_service, "list_artifact_keys") @@ -1047,6 +1055,7 @@ async def get_artifact_info_list( if filename.endswith(METADATA_SUFFIX): continue + user_artifact_filenames.add(filename) log_identifier_item = f"{log_prefix} [{filename}]" try: @@ -1142,6 +1151,17 @@ async def get_artifact_info_list( ) ) + # Include agent-level default artifacts if requested + if include_agent_defaults and agent_name: + default_artifacts = await _get_agent_default_artifacts( + artifact_service=artifact_service, + app_name=app_name, + agent_name=agent_name, + exclude_filenames=user_artifact_filenames, + log_prefix=log_prefix, + ) + artifact_info_list.extend(default_artifacts) + except Exception as e: log.exception( "%s Error listing artifact keys or processing list: %s", log_prefix, e @@ -1150,6 +1170,240 @@ async def get_artifact_info_list( return artifact_info_list +async def _get_agent_default_artifacts( + artifact_service: BaseArtifactService, + app_name: str, + agent_name: str, + exclude_filenames: set, + log_prefix: str, +) -> List[ArtifactInfo]: + """ + Retrieves agent-level default artifacts that are not already in the user's session. + + Args: + artifact_service: The artifact service instance. + app_name: The application name (namespace) - NOTE: This is ignored for default artifacts. + Default artifacts are stored with agent_name as the app_name. + agent_name: The agent name (used as both app_name and session_id for defaults). + exclude_filenames: Set of filenames to exclude (user already has these). + log_prefix: Prefix for log messages. + + Returns: + A list of ArtifactInfo objects for default artifacts. + """ + default_artifacts: List[ArtifactInfo] = [] + + # NOTE: Default artifacts are stored with agent_name as the app_name (not the gateway's app_name) + # This is because the agent loads them during initialization using self.agent_name as app_name + default_app_name = agent_name + + try: + list_keys_method = getattr(artifact_service, "list_artifact_keys") + default_keys = await list_keys_method( + app_name=default_app_name, + user_id=AGENT_DEFAULTS_USER_ID, + session_id=agent_name, + ) + + log.debug( + "%s Found %d agent default artifact keys for agent '%s'.", + log_prefix, + len(default_keys), + agent_name, + ) + + for filename in default_keys: + if filename.endswith(METADATA_SUFFIX): + continue + + # Skip if user already has this artifact (user's version takes precedence) + if filename in exclude_filenames: + log.debug( + "%s Skipping default artifact '%s' - user has their own version.", + log_prefix, + filename, + ) + continue + + log_identifier_item = f"{log_prefix} [default:{filename}]" + try: + version_count: int = 0 + latest_version_num: Optional[int] = await get_latest_artifact_version( + artifact_service, default_app_name, AGENT_DEFAULTS_USER_ID, agent_name, filename + ) + + if hasattr(artifact_service, "list_versions"): + try: + available_versions = await artifact_service.list_versions( + app_name=default_app_name, + user_id=AGENT_DEFAULTS_USER_ID, + session_id=agent_name, + filename=filename, + ) + version_count = len(available_versions) + except Exception as list_ver_err: + log.error( + "%s Error listing versions for default artifact: %s.", + log_identifier_item, + list_ver_err, + ) + + data = await load_artifact_content_or_metadata( + artifact_service=artifact_service, + app_name=default_app_name, + user_id=AGENT_DEFAULTS_USER_ID, + session_id=agent_name, + filename=filename, + version="latest", + load_metadata_only=True, + log_identifier_prefix=log_identifier_item, + ) + + metadata = data.get("metadata", {}) + mime_type = metadata.get("mime_type", "application/data") + size = metadata.get("size_bytes", 0) + schema_definition = metadata.get("schema", {}) + description = metadata.get("description", "No description provided") + loaded_version_num = data.get("version", latest_version_num) + + last_modified_ts = metadata.get("timestamp_utc") + last_modified_iso = ( + datetime.fromtimestamp( + last_modified_ts, tz=timezone.utc + ).isoformat() + if last_modified_ts + else None + ) + + # Mark as agent default artifact + default_artifacts.append( + ArtifactInfo( + filename=filename, + mime_type=mime_type, + size=size, + last_modified=last_modified_iso, + schema_definition=schema_definition, + description=description, + version=loaded_version_num, + version_count=version_count, + source="agent_default", # Mark as agent default + ) + ) + log.debug( + "%s Successfully processed default artifact info.", + log_identifier_item, + ) + + except FileNotFoundError: + log.warning( + "%s Default artifact file not found for key '%s'. Skipping.", + log_prefix, + filename, + ) + except Exception as detail_e: + log.error( + "%s Error processing default artifact '%s': %s\n%s", + log_prefix, + filename, + detail_e, + traceback.format_exc(), + ) + + except Exception as e: + log.warning( + "%s Error listing agent default artifacts for agent '%s': %s", + log_prefix, + agent_name, + e, + ) + + return default_artifacts + + +async def load_artifact_with_fallback_to_defaults( + artifact_service: BaseArtifactService, + app_name: str, + user_id: str, + session_id: str, + filename: str, + version: Optional[int] = None, + agent_name: Optional[str] = None, +) -> Optional[Any]: + """ + Loads an artifact, falling back to agent defaults if not found in user session. + + Args: + artifact_service: The artifact service instance. + app_name: The application name (namespace). + user_id: The user ID. + session_id: The session ID. + filename: The name of the artifact to load. + version: Optional specific version to load. + agent_name: The agent name for fallback lookup (if None, no fallback). + + Returns: + The artifact Part object, or None if not found. + """ + log_prefix = f"[ArtifactHelper:load_with_fallback] App={app_name}, User={user_id}, Session={session_id}, File={filename} -" + + # First, try to load from the user's session + try: + artifact_part = await artifact_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + version=version, + ) + + if artifact_part is not None: + log.debug("%s Found artifact in user session.", log_prefix) + return artifact_part + except FileNotFoundError: + log.debug("%s Artifact not found in user session.", log_prefix) + except Exception as e: + log.warning("%s Error loading from user session: %s", log_prefix, e) + + # If not found and we have an agent name, try the agent defaults + # NOTE: Default artifacts are stored with agent_name as the app_name (not the gateway's app_name) + if agent_name: + log.debug( + "%s Trying agent defaults for agent '%s'.", + log_prefix, + agent_name, + ) + try: + artifact_part = await artifact_service.load_artifact( + app_name=agent_name, # Use agent_name as app_name for defaults + user_id=AGENT_DEFAULTS_USER_ID, + session_id=agent_name, + filename=filename, + version=version, + ) + + if artifact_part is not None: + log.debug( + "%s Found artifact in agent defaults for agent '%s'.", + log_prefix, + agent_name, + ) + return artifact_part + except FileNotFoundError: + log.debug( + "%s Artifact not found in agent defaults for agent '%s'.", + log_prefix, + agent_name, + ) + except Exception as e: + log.warning( + "%s Error loading from agent defaults: %s", + log_prefix, + e, + ) + + return None + + async def load_artifact_content_or_metadata( artifact_service: BaseArtifactService, app_name: str, diff --git a/src/solace_agent_mesh/gateway/http_sse/routers/artifacts.py b/src/solace_agent_mesh/gateway/http_sse/routers/artifacts.py index 913c6b6b5f..7b930672c9 100644 --- a/src/solace_agent_mesh/gateway/http_sse/routers/artifacts.py +++ b/src/solace_agent_mesh/gateway/http_sse/routers/artifacts.py @@ -65,6 +65,7 @@ class BaseArtifactService: from ....agent.utils.artifact_helpers import ( get_artifact_info_list, load_artifact_content_or_metadata, + load_artifact_with_fallback_to_defaults, process_artifact_upload, ) @@ -475,11 +476,16 @@ async def list_artifact_versions( validate_session: Callable[[str, str], bool] = Depends(get_session_validator), component: "WebUIBackendComponent" = Depends(get_sac_component), project_service: ProjectService | None = Depends(get_project_service_optional), + session_service: SessionService | None = Depends( + get_session_business_service_optional + ), + db: Session | None = Depends(get_db_optional), user_config: dict = Depends(ValidatedUserConfig(["tool:artifact:list"])), ): """ Lists the available integer versions for a given artifact filename associated with the specified context (session or project). + Falls back to agent default artifacts if not found in user session. """ log_prefix = f"[ArtifactRouter:ListVersions:{filename}] User={user_id}, Session={session_id} -" @@ -511,15 +517,80 @@ async def list_artifact_versions( try: app_name = component.get_config("name", "A2A_WebUI_App") - log.info("%s Using %s context: storage_user_id=%s, storage_session_id=%s", + log.info("%s Using %s context: storage_user_id=%s, storage_session_id=%s", log_prefix, context_type, storage_user_id, storage_session_id) + # Try to get the agent_id for the session/project to enable fallback to default artifacts + agent_name: Optional[str] = None + if context_type == "session" and session_service and db: + try: + session_domain = session_service.get_session_details(db, session_id, user_id) + if session_domain and session_domain.agent_id: + agent_name = session_domain.agent_id + log.debug( + "%s Session is associated with agent '%s', will try fallback if needed.", + log_prefix, + agent_name, + ) + except Exception as session_err: + log.warning( + "%s Could not retrieve session agent info: %s", + log_prefix, + session_err, + ) + elif context_type == "project" and project_service and project_id: + # For project context, look up the project's default_agent_id + from ....gateway.http_sse.dependencies import SessionLocal + + if SessionLocal is not None: + project_db = SessionLocal() + try: + project = project_service.get_project(project_db, project_id, user_id) + if project and project.default_agent_id: + agent_name = project.default_agent_id + log.debug( + "%s Project has default agent '%s', will try fallback if needed.", + log_prefix, + agent_name, + ) + except Exception as project_err: + log.warning( + "%s Could not retrieve project agent info: %s", + log_prefix, + project_err, + ) + finally: + project_db.close() + versions = await artifact_service.list_versions( app_name=app_name, user_id=storage_user_id, session_id=storage_session_id, filename=filename, ) + + # If no versions found and we have an agent name, try agent defaults + if not versions and agent_name: + log.debug( + "%s No versions found in user session, trying agent defaults for agent '%s'.", + log_prefix, + agent_name, + ) + from ....agent.utils.artifact_helpers import AGENT_DEFAULTS_USER_ID + versions = await artifact_service.list_versions( + app_name=agent_name, # Default artifacts use agent_name as app_name + user_id=AGENT_DEFAULTS_USER_ID, + session_id=agent_name, + filename=filename, + ) + if versions: + log.info( + "%s Found versions in agent defaults for agent '%s': %s", + log_prefix, + agent_name, + versions, + ) + log.info("%s Found versions: %s", log_prefix, versions) return versions except FileNotFoundError: @@ -553,16 +624,22 @@ async def list_artifacts( ..., title="Session ID", description="The session ID to list artifacts for (or 'null' for project context)" ), project_id: Optional[str] = Query(None, description="Project ID for project context"), + agent_name: Optional[str] = Query(None, description="Agent name to include default artifacts for (used when switching agents)"), artifact_service: BaseArtifactService = Depends(get_shared_artifact_service), user_id: str = Depends(get_user_id), validate_session: Callable[[str, str], bool] = Depends(get_session_validator), component: "WebUIBackendComponent" = Depends(get_sac_component), project_service: ProjectService | None = Depends(get_project_service_optional), + session_service: SessionService | None = Depends( + get_session_business_service_optional + ), + db: Session | None = Depends(get_db_optional), user_config: dict = Depends(ValidatedUserConfig(["tool:artifact:list"])), ): """ Lists detailed information (filename, size, type, modified date, uri) for all artifacts associated with the specified context (session or project). + Includes agent-level default artifacts if the session is associated with an agent. """ log_prefix = f"[ArtifactRouter:ListInfo] User={user_id}, Session={session_id} -" @@ -588,14 +665,105 @@ async def list_artifacts( try: app_name = component.get_config("name", "A2A_WebUI_App") - log.info("%s Using %s context: storage_user_id=%s, storage_session_id=%s", + log.info("%s Using %s context: storage_user_id=%s, storage_session_id=%s", log_prefix, context_type, storage_user_id, storage_session_id) + # Determine agent name for default artifacts + # Priority 1: Use agent_name from query parameter (when switching agents mid-session) + # Priority 2: Look up from session/project context + effective_agent_name: Optional[str] = agent_name # From query parameter + include_agent_defaults = False + + log.debug( + "%s Checking for agent defaults: query_agent_name=%s, context_type=%s, session_service=%s, db=%s, project_service=%s", + log_prefix, + agent_name, + context_type, + "available" if session_service else "None", + "available" if db else "None", + "available" if project_service else "None", + ) + + # If agent_name was provided via query parameter, use it directly + if effective_agent_name: + include_agent_defaults = True + log.info( + "%s Using agent_name from query parameter: '%s', will include default artifacts.", + log_prefix, + effective_agent_name, + ) + elif context_type == "session" and session_service and db: + # Fall back to looking up agent from session + try: + log.debug("%s Looking up session details for session_id=%s, user_id=%s", log_prefix, session_id, user_id) + session_domain = session_service.get_session_details(db, session_id, user_id) + log.debug("%s Session domain result: %s, agent_id=%s", log_prefix, session_domain, session_domain.agent_id if session_domain else "N/A") + if session_domain and session_domain.agent_id: + effective_agent_name = session_domain.agent_id + include_agent_defaults = True + log.info( + "%s Session is associated with agent '%s', will include default artifacts.", + log_prefix, + effective_agent_name, + ) + else: + log.debug( + "%s Session has no agent_id set yet (session_domain=%s).", + log_prefix, + session_domain, + ) + except Exception as session_err: + log.warning( + "%s Could not retrieve session agent info: %s", + log_prefix, + session_err, + ) + elif context_type == "project" and project_service and project_id: + # For project context, look up the project's default_agent_id + from ....gateway.http_sse.dependencies import SessionLocal + + if SessionLocal is not None: + project_db = SessionLocal() + try: + project = project_service.get_project(project_db, project_id, user_id) + if project and project.default_agent_id: + effective_agent_name = project.default_agent_id + include_agent_defaults = True + log.info( + "%s Project has default agent '%s', will include default artifacts.", + log_prefix, + effective_agent_name, + ) + else: + log.debug( + "%s Project has no default_agent_id set (project=%s).", + log_prefix, + project, + ) + except Exception as project_err: + log.warning( + "%s Could not retrieve project agent info: %s", + log_prefix, + project_err, + ) + finally: + project_db.close() + else: + log.debug( + "%s Skipping agent defaults lookup: context_type=%s, session_service=%s, db=%s", + log_prefix, + context_type, + "available" if session_service else "None", + "available" if db else "None", + ) + artifact_info_list = await get_artifact_info_list( artifact_service=artifact_service, app_name=app_name, user_id=storage_user_id, session_id=storage_session_id, + include_agent_defaults=include_agent_defaults, + agent_name=effective_agent_name, ) log.info("%s Returning %d artifact details.", log_prefix, len(artifact_info_list)) @@ -625,11 +793,16 @@ async def get_latest_artifact( validate_session: Callable[[str, str], bool] = Depends(get_session_validator), component: "WebUIBackendComponent" = Depends(get_sac_component), project_service: ProjectService | None = Depends(get_project_service_optional), + session_service: SessionService | None = Depends( + get_session_business_service_optional + ), + db: Session | None = Depends(get_db_optional), user_config: dict = Depends(ValidatedUserConfig(["tool:artifact:load"])), ): """ Retrieves the content of the latest version of the specified artifact associated with the specified context (session or project). + Falls back to agent default artifacts if not found in user session. """ log_prefix = ( f"[ArtifactRouter:GetLatest:{filename}] User={user_id}, Session={session_id} -" @@ -651,14 +824,59 @@ async def get_latest_artifact( try: app_name = component.get_config("name", "A2A_WebUI_App") - log.info("%s Using %s context: storage_user_id=%s, storage_session_id=%s", + log.info("%s Using %s context: storage_user_id=%s, storage_session_id=%s", log_prefix, context_type, storage_user_id, storage_session_id) - artifact_part = await artifact_service.load_artifact( + # Try to get the agent_id for the session/project to enable fallback to default artifacts + agent_name: Optional[str] = None + if context_type == "session" and session_service and db: + try: + session_domain = session_service.get_session_details(db, session_id, user_id) + if session_domain and session_domain.agent_id: + agent_name = session_domain.agent_id + log.debug( + "%s Session is associated with agent '%s', will try fallback if needed.", + log_prefix, + agent_name, + ) + except Exception as session_err: + log.warning( + "%s Could not retrieve session agent info: %s", + log_prefix, + session_err, + ) + elif context_type == "project" and project_service and project_id: + # For project context, look up the project's default_agent_id + from ....gateway.http_sse.dependencies import SessionLocal + + if SessionLocal is not None: + project_db = SessionLocal() + try: + project = project_service.get_project(project_db, project_id, user_id) + if project and project.default_agent_id: + agent_name = project.default_agent_id + log.debug( + "%s Project has default agent '%s', will try fallback if needed.", + log_prefix, + agent_name, + ) + except Exception as project_err: + log.warning( + "%s Could not retrieve project agent info: %s", + log_prefix, + project_err, + ) + finally: + project_db.close() + + # Use the fallback function to load artifact (tries user session first, then agent defaults) + artifact_part = await load_artifact_with_fallback_to_defaults( + artifact_service=artifact_service, app_name=app_name, user_id=storage_user_id, session_id=storage_session_id, filename=filename, + agent_name=agent_name, ) if artifact_part is None or artifact_part.inline_data is None: @@ -791,11 +1009,16 @@ async def get_specific_artifact_version( validate_session: Callable[[str, str], bool] = Depends(get_session_validator), component: "WebUIBackendComponent" = Depends(get_sac_component), project_service: ProjectService | None = Depends(get_project_service_optional), + session_service: SessionService | None = Depends( + get_session_business_service_optional + ), + db: Session | None = Depends(get_db_optional), user_config: dict = Depends(ValidatedUserConfig(["tool:artifact:load"])), ): """ Retrieves the content of a specific version of the specified artifact associated with the specified context (session or project). + Falls back to agent default artifacts if not found in user session. """ log_prefix = f"[ArtifactRouter:GetVersion:{filename} v{version}] User={user_id}, Session={session_id} -" log.info("%s Request received.", log_prefix) @@ -815,9 +1038,51 @@ async def get_specific_artifact_version( try: app_name = component.get_config("name", "A2A_WebUI_App") - log.info("%s Using %s context: storage_user_id=%s, storage_session_id=%s", + log.info("%s Using %s context: storage_user_id=%s, storage_session_id=%s", log_prefix, context_type, storage_user_id, storage_session_id) + # Try to get the agent_id for the session/project to enable fallback to default artifacts + agent_name: Optional[str] = None + if context_type == "session" and session_service and db: + try: + session_domain = session_service.get_session_details(db, session_id, user_id) + if session_domain and session_domain.agent_id: + agent_name = session_domain.agent_id + log.debug( + "%s Session is associated with agent '%s', will try fallback if needed.", + log_prefix, + agent_name, + ) + except Exception as session_err: + log.warning( + "%s Could not retrieve session agent info: %s", + log_prefix, + session_err, + ) + elif context_type == "project" and project_service and project_id: + # For project context, look up the project's default_agent_id + from ....gateway.http_sse.dependencies import SessionLocal + + if SessionLocal is not None: + project_db = SessionLocal() + try: + project = project_service.get_project(project_db, project_id, user_id) + if project and project.default_agent_id: + agent_name = project.default_agent_id + log.debug( + "%s Project has default agent '%s', will try fallback if needed.", + log_prefix, + agent_name, + ) + except Exception as project_err: + log.warning( + "%s Could not retrieve project agent info: %s", + log_prefix, + project_err, + ) + finally: + project_db.close() + load_result = await load_artifact_content_or_metadata( artifact_service=artifact_service, app_name=app_name, @@ -830,6 +1095,33 @@ async def get_specific_artifact_version( log_identifier_prefix="[ArtifactRouter:GetVersion]", ) + # If not found and we have an agent name, try agent defaults + if load_result.get("status") == "not_found" and agent_name: + log.debug( + "%s Artifact not found in user session, trying agent defaults for agent '%s'.", + log_prefix, + agent_name, + ) + from ....agent.utils.artifact_helpers import AGENT_DEFAULTS_USER_ID + load_result = await load_artifact_content_or_metadata( + artifact_service=artifact_service, + app_name=agent_name, # Default artifacts use agent_name as app_name + user_id=AGENT_DEFAULTS_USER_ID, + session_id=agent_name, + filename=filename, + version=version, + load_metadata_only=False, + return_raw_bytes=True, + log_identifier_prefix="[ArtifactRouter:GetVersion:agent_default]", + ) + if load_result.get("status") == "success": + log.info( + "%s Loaded artifact '%s' from agent defaults for agent '%s'.", + log_prefix, + filename, + agent_name, + ) + if load_result.get("status") != "success": error_message = load_result.get( "message", f"Failed to load artifact '{filename}' version '{version}'." diff --git a/tests/unit/agent/test_default_artifacts.py b/tests/unit/agent/test_default_artifacts.py new file mode 100644 index 0000000000..ebadb6cb46 --- /dev/null +++ b/tests/unit/agent/test_default_artifacts.py @@ -0,0 +1,436 @@ +""" +Unit tests for the default artifacts feature. + +This module tests the ability to configure default artifacts for agents +that are automatically available to all users without requiring upload. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +import os +import tempfile + +from google.genai import types as adk_types + + +class TestDefaultArtifactConfig: + """Tests for the DefaultArtifactConfig schema.""" + + def test_default_artifact_config_minimal(self): + """Test creating a DefaultArtifactConfig with only required fields.""" + from solace_agent_mesh.agent.sac.app import DefaultArtifactConfig + + # Both path and filename are required + config = DefaultArtifactConfig(path="/path/to/file.txt", filename="file.txt") + assert config.path == "/path/to/file.txt" + assert config.filename == "file.txt" + assert config.mime_type is None + assert config.description is None + + def test_default_artifact_config_full(self): + """Test creating a DefaultArtifactConfig with all fields.""" + from solace_agent_mesh.agent.sac.app import DefaultArtifactConfig + + config = DefaultArtifactConfig( + path="/path/to/file.txt", + filename="custom_name.txt", + mime_type="text/plain", + description="A test file", + ) + assert config.path == "/path/to/file.txt" + assert config.filename == "custom_name.txt" + assert config.mime_type == "text/plain" + assert config.description == "A test file" + + def test_default_artifact_config_missing_required_fields(self): + """Test that missing required fields raise validation error.""" + from solace_agent_mesh.agent.sac.app import DefaultArtifactConfig + from pydantic import ValidationError + + # Missing filename should raise error + with pytest.raises(ValidationError): + DefaultArtifactConfig(path="/path/to/file.txt") + + # Missing path should raise error + with pytest.raises(ValidationError): + DefaultArtifactConfig(filename="file.txt") + + +class TestScopedArtifactServiceWrapper: + """Tests for the ScopedArtifactServiceWrapper with default artifacts support.""" + + @pytest.fixture + def mock_wrapped_service(self): + """Create a mock wrapped artifact service.""" + service = AsyncMock() + service.save_artifact = AsyncMock(return_value=1) + service.load_artifact = AsyncMock(return_value=None) + service.list_artifact_keys = AsyncMock(return_value=[]) + service.delete_artifact = AsyncMock(return_value=None) + service.list_versions = AsyncMock(return_value=[]) + return service + + @pytest.fixture + def mock_component(self): + """Create a mock component.""" + component = MagicMock() + component.agent_name = "test_agent" + component.namespace = "test_namespace" + component.log_identifier = "[TestAgent]" + component.get_config = MagicMock(side_effect=lambda key, default=None: { + "default_artifacts": [], + "artifact_scope": "namespace", + "agent_name": "test_agent", + }.get(key, default)) + return component + + @pytest.mark.asyncio + async def test_load_artifact_fallback_to_defaults(self, mock_wrapped_service, mock_component): + """Test that load_artifact falls back to agent defaults when user artifact not found.""" + from solace_agent_mesh.agent.adk.services import ( + ScopedArtifactServiceWrapper, + AGENT_DEFAULTS_USER_ID, + ) + + # Configure default artifacts + mock_component.get_config = MagicMock(side_effect=lambda key, default=None: { + "default_artifacts": [{"path": "/test/file.txt", "filename": "test.txt"}], + "artifact_scope": "namespace", + "agent_name": "test_agent", + }.get(key, default)) + + # Create wrapper with new API + wrapper = ScopedArtifactServiceWrapper( + wrapped_service=mock_wrapped_service, + component=mock_component, + ) + + # Mock: user artifact not found, but default exists + default_artifact = adk_types.Part.from_text(text="default content") + mock_wrapped_service.load_artifact = AsyncMock( + side_effect=[None, default_artifact] # First call (user) returns None, second (default) returns artifact + ) + + # Load artifact + result = await wrapper.load_artifact( + app_name="test_agent", + user_id="user123", + session_id="session456", + filename="test.txt", + ) + + # Verify fallback was used + assert result == default_artifact + assert mock_wrapped_service.load_artifact.call_count == 2 + + # Verify first call was for user + first_call = mock_wrapped_service.load_artifact.call_args_list[0] + assert first_call.kwargs["user_id"] == "user123" + + # Verify second call was for defaults + second_call = mock_wrapped_service.load_artifact.call_args_list[1] + assert second_call.kwargs["user_id"] == AGENT_DEFAULTS_USER_ID + + @pytest.mark.asyncio + async def test_load_artifact_user_takes_precedence(self, mock_wrapped_service, mock_component): + """Test that user's artifact takes precedence over default.""" + from solace_agent_mesh.agent.adk.services import ScopedArtifactServiceWrapper + + # Configure default artifacts + mock_component.get_config = MagicMock(side_effect=lambda key, default=None: { + "default_artifacts": [{"path": "/test/file.txt", "filename": "test.txt"}], + "artifact_scope": "namespace", + "agent_name": "test_agent", + }.get(key, default)) + + wrapper = ScopedArtifactServiceWrapper( + wrapped_service=mock_wrapped_service, + component=mock_component, + ) + + # Mock: user artifact exists + user_artifact = adk_types.Part.from_text(text="user content") + mock_wrapped_service.load_artifact = AsyncMock(return_value=user_artifact) + + # Load artifact + result = await wrapper.load_artifact( + app_name="test_agent", + user_id="user123", + session_id="session456", + filename="test.txt", + ) + + # Verify user artifact was returned + assert result == user_artifact + # Should only call once (no fallback needed) + assert mock_wrapped_service.load_artifact.call_count == 1 + + @pytest.mark.asyncio + async def test_list_artifact_keys_merges_defaults(self, mock_wrapped_service, mock_component): + """Test that list_artifact_keys includes both user and default artifacts.""" + from solace_agent_mesh.agent.adk.services import ( + ScopedArtifactServiceWrapper, + AGENT_DEFAULTS_USER_ID, + ) + + # Configure default artifacts + mock_component.get_config = MagicMock(side_effect=lambda key, default=None: { + "default_artifacts": [ + {"path": "/test/default1.txt", "filename": "default1.txt"}, + {"path": "/test/default2.txt", "filename": "default2.txt"}, + ], + "artifact_scope": "namespace", + "agent_name": "test_agent", + }.get(key, default)) + + wrapper = ScopedArtifactServiceWrapper( + wrapped_service=mock_wrapped_service, + component=mock_component, + ) + + # Mock: user has one artifact, defaults have two + mock_wrapped_service.list_artifact_keys = AsyncMock( + side_effect=[ + ["user_file.txt"], # User artifacts + ["default1.txt", "default2.txt"], # Default artifacts + ] + ) + + # List artifacts + result = await wrapper.list_artifact_keys( + app_name="test_agent", + user_id="user123", + session_id="session456", + ) + + # Verify merged list (unique keys) + assert set(result) == {"user_file.txt", "default1.txt", "default2.txt"} + + @pytest.mark.asyncio + async def test_delete_artifact_prevents_default_deletion(self, mock_wrapped_service, mock_component): + """Test that deleting a default artifact raises PermissionError.""" + from solace_agent_mesh.agent.adk.services import ScopedArtifactServiceWrapper + + # Configure default artifacts + mock_component.get_config = MagicMock(side_effect=lambda key, default=None: { + "default_artifacts": [{"path": "/test/protected.txt", "filename": "protected.txt"}], + "artifact_scope": "namespace", + "agent_name": "test_agent", + }.get(key, default)) + + wrapper = ScopedArtifactServiceWrapper( + wrapped_service=mock_wrapped_service, + component=mock_component, + ) + + # Mock: user doesn't have this artifact (it's a default) + mock_wrapped_service.load_artifact = AsyncMock(return_value=None) + + # Attempt to delete default artifact should raise PermissionError + with pytest.raises(PermissionError, match="Cannot delete default artifact"): + await wrapper.delete_artifact( + app_name="test_agent", + user_id="user123", + session_id="session456", + filename="protected.txt", + ) + + @pytest.mark.asyncio + async def test_save_artifact_allows_shadowing(self, mock_wrapped_service, mock_component): + """Test that users can save their own version of a default artifact (shadowing).""" + from solace_agent_mesh.agent.adk.services import ScopedArtifactServiceWrapper + + # Configure default artifacts + mock_component.get_config = MagicMock(side_effect=lambda key, default=None: { + "default_artifacts": [{"path": "/test/file.txt", "filename": "file.txt"}], + "artifact_scope": "namespace", + "agent_name": "test_agent", + }.get(key, default)) + + wrapper = ScopedArtifactServiceWrapper( + wrapped_service=mock_wrapped_service, + component=mock_component, + ) + + # Save user's version of the artifact + user_artifact = adk_types.Part.from_text(text="user's version") + mock_wrapped_service.save_artifact = AsyncMock(return_value=1) + + result = await wrapper.save_artifact( + app_name="test_agent", + user_id="user123", + session_id="session456", + filename="file.txt", + artifact=user_artifact, + ) + + # Verify save was called with user's credentials + assert result == 1 + mock_wrapped_service.save_artifact.assert_called_once() + call_kwargs = mock_wrapped_service.save_artifact.call_args.kwargs + assert call_kwargs["user_id"] == "user123" + + +class TestLoadDefaultArtifacts: + """Tests for the _load_default_artifacts method in SamAgentComponent.""" + + @pytest.mark.asyncio + async def test_load_default_artifacts_from_file(self): + """Test loading default artifacts from actual files.""" + from solace_agent_mesh.agent.adk.services import AGENT_DEFAULTS_USER_ID + + # Create a temporary file + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: + f.write("Test content for default artifact") + temp_file_path = f.name + + try: + # Create mock component + mock_component = MagicMock() + mock_component.agent_name = "test_agent" + mock_component.namespace = "test_namespace" + mock_component.log_identifier = "[TestAgent]" + mock_component.artifact_service = AsyncMock() + + # Mock list_versions to return empty (artifact doesn't exist yet) + mock_component.artifact_service.list_versions = AsyncMock(return_value=[]) + + mock_component.get_config = MagicMock( + side_effect=lambda key, default=None: { + "default_artifacts": [ + { + "path": temp_file_path, + "filename": "test_default.txt", + "mime_type": "text/plain", + "description": "A test default artifact", + } + ], + "base_path": ".", + "artifact_scope": "namespace", + }.get(key, default) + ) + + # Import and call the method + from solace_agent_mesh.agent.sac.component import SamAgentComponent + + # Mock save_artifact_with_metadata - patch where it's used (in component module) + with patch('solace_agent_mesh.agent.utils.artifact_helpers.save_artifact_with_metadata', new_callable=AsyncMock) as mock_save: + mock_save.return_value = {"status": "success", "version": 0} + + # Call the method directly (we need to bind it to our mock) + await SamAgentComponent._load_default_artifacts(mock_component) + + # Verify save_artifact_with_metadata was called + mock_save.assert_called_once() + call_kwargs = mock_save.call_args.kwargs + assert call_kwargs["app_name"] == "test_agent" + assert call_kwargs["user_id"] == AGENT_DEFAULTS_USER_ID + assert call_kwargs["filename"] == "test_default.txt" + + finally: + # Cleanup + os.unlink(temp_file_path) + + @pytest.mark.asyncio + async def test_load_default_artifacts_handles_missing_file(self): + """Test that missing files are handled gracefully.""" + mock_component = MagicMock() + mock_component.agent_name = "test_agent" + mock_component.namespace = "test_namespace" + mock_component.log_identifier = "[TestAgent]" + mock_component.artifact_service = AsyncMock() + mock_component.get_config = MagicMock( + side_effect=lambda key, default=None: { + "default_artifacts": [ + {"path": "/nonexistent/file.txt", "filename": "missing.txt"} + ], + "base_path": ".", + }.get(key, default) + ) + + from solace_agent_mesh.agent.sac.component import SamAgentComponent + + # Should not raise, just log error + await SamAgentComponent._load_default_artifacts(mock_component) + + # Verify save was NOT called (file doesn't exist) + mock_component.artifact_service.save_artifact.assert_not_called() + + @pytest.mark.asyncio + async def test_load_default_artifacts_no_config(self): + """Test that no action is taken when no default artifacts are configured.""" + mock_component = MagicMock() + mock_component.agent_name = "test_agent" + mock_component.namespace = "test_namespace" + mock_component.log_identifier = "[TestAgent]" + mock_component.artifact_service = AsyncMock() + mock_component.get_config = MagicMock(return_value=[]) + + from solace_agent_mesh.agent.sac.component import SamAgentComponent + + await SamAgentComponent._load_default_artifacts(mock_component) + + # Verify save was NOT called + mock_component.artifact_service.save_artifact.assert_not_called() + + @pytest.mark.asyncio + async def test_load_default_artifacts_skips_existing(self): + """Test that existing artifacts are not re-loaded.""" + from solace_agent_mesh.agent.adk.services import AGENT_DEFAULTS_USER_ID + + # Create a temporary file + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: + f.write("Test content for default artifact") + temp_file_path = f.name + + try: + mock_component = MagicMock() + mock_component.agent_name = "test_agent" + mock_component.namespace = "test_namespace" + mock_component.log_identifier = "[TestAgent]" + mock_component.artifact_service = AsyncMock() + + # Mock list_versions to return existing version (artifact already exists) + mock_component.artifact_service.list_versions = AsyncMock(return_value=[0]) + + mock_component.get_config = MagicMock( + side_effect=lambda key, default=None: { + "default_artifacts": [ + { + "path": temp_file_path, + "filename": "test_default.txt", + "mime_type": "text/plain", + "description": "A test default artifact", + } + ], + "base_path": ".", + "artifact_scope": "namespace", + }.get(key, default) + ) + + from solace_agent_mesh.agent.sac.component import SamAgentComponent + + # Mock save_artifact_with_metadata - patch where it's used (in artifact_helpers module) + with patch('solace_agent_mesh.agent.utils.artifact_helpers.save_artifact_with_metadata', new_callable=AsyncMock) as mock_save: + mock_save.return_value = {"status": "success", "version": 0} + + await SamAgentComponent._load_default_artifacts(mock_component) + + # Verify save_artifact_with_metadata was NOT called (artifact already exists) + mock_save.assert_not_called() + + finally: + os.unlink(temp_file_path) + + +class TestAgentDefaultsUserIdConstant: + """Tests for the AGENT_DEFAULTS_USER_ID constant.""" + + def test_constant_value(self): + """Test that the constant has the expected value.""" + from solace_agent_mesh.agent.adk.services import AGENT_DEFAULTS_USER_ID + + assert AGENT_DEFAULTS_USER_ID == "__agent_defaults__" + # Ensure it's a string that won't conflict with real user IDs + assert AGENT_DEFAULTS_USER_ID.startswith("__") + assert AGENT_DEFAULTS_USER_ID.endswith("__")