diff --git a/TESTING.md b/TESTING.md index ccf64cfa0..bf4d0c291 100644 --- a/TESTING.md +++ b/TESTING.md @@ -291,7 +291,7 @@ class TestExampleService: def test_with_database(db_session): """Test using database session fixture.""" # db_session is automatically provided by conftest.py - from mcpgateway.models import Tool + from mcpgateway.common.models import Tool tool = Tool(name="test_tool") db_session.add(tool) db_session.commit() diff --git a/docs/docs/architecture/multitenancy.md b/docs/docs/architecture/multitenancy.md index 01389d295..f7083c266 100644 --- a/docs/docs/architecture/multitenancy.md +++ b/docs/docs/architecture/multitenancy.md @@ -652,7 +652,7 @@ For emergency password resets, you can update the database directly: python3 -c " from mcpgateway.services.argon2_service import Argon2PasswordService from mcpgateway.db import SessionLocal -from mcpgateway.models import EmailUser +from mcpgateway.common.models import EmailUser service = Argon2PasswordService() hashed = service.hash_password('new_password') diff --git a/mcp-servers/templates/go/copier.yaml b/mcp-servers/templates/go/copier.yaml index e6d615ea5..61d2cb223 100644 --- a/mcp-servers/templates/go/copier.yaml +++ b/mcp-servers/templates/go/copier.yaml @@ -45,4 +45,3 @@ include_container: type: bool help: Include Dockerfile for a minimal runtime image default: true - diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index a7c5571f3..d02fd920c 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -51,12 +51,12 @@ from starlette.datastructures import UploadFile as StarletteUploadFile # First-Party +from mcpgateway.common.models import LogLevel from mcpgateway.config import settings from mcpgateway.db import get_db, GlobalConfig, ObservabilitySavedQuery, ObservabilitySpan, ObservabilityTrace from mcpgateway.db import Tool as DbTool from mcpgateway.db import utc_now from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission -from mcpgateway.models import LogLevel from mcpgateway.schemas import ( A2AAgentCreate, A2AAgentRead, diff --git a/mcpgateway/cache/session_registry.py b/mcpgateway/cache/session_registry.py index db26ed2c2..07468f8f6 100644 --- a/mcpgateway/cache/session_registry.py +++ b/mcpgateway/cache/session_registry.py @@ -65,9 +65,9 @@ # First-Party from mcpgateway import __version__ +from mcpgateway.common.models import Implementation, InitializeResult, ServerCapabilities from mcpgateway.config import settings from mcpgateway.db import get_db, SessionMessageRecord, SessionRecord -from mcpgateway.models import Implementation, InitializeResult, ServerCapabilities from mcpgateway.services import PromptService, ResourceService, ToolService from mcpgateway.services.logging_service import LoggingService from mcpgateway.transports import SSETransport diff --git a/mcpgateway/common/__init__.py b/mcpgateway/common/__init__.py new file mode 100644 index 000000000..2f4c65db1 --- /dev/null +++ b/mcpgateway/common/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/common/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Common ContextForge package for shared classes and functions. +""" diff --git a/mcpgateway/common/config.py b/mcpgateway/common/config.py new file mode 100644 index 000000000..5ab271fb2 --- /dev/null +++ b/mcpgateway/common/config.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/config.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti, Manav Gupta + +Common MCP Gateway Configuration settings used across subpackages. +This module defines configuration settings for the MCP Gateway using Pydantic. +It loads configuration from environment variables with sensible defaults. +""" + +# Standard +from functools import lru_cache + +# Third-Party +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + """Validation settings for the security validator.""" + + # Validation patterns for safe display (configurable) + validation_dangerous_html_pattern: str = ( + r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|" + ) + + validation_dangerous_js_pattern: str = r"(?i)(?:^|\s|[\"'`<>=])(javascript:|vbscript:|data:\s*[^,]*[;\s]*(javascript|vbscript)|\bon[a-z]+\s*=|<\s*script\b)" + + validation_allowed_url_schemes: list[str] = ["http://", "https://", "ws://", "wss://"] + + # Character validation patterns + validation_name_pattern: str = r"^[a-zA-Z0-9_.\-\s]+$" # Allow spaces for names + validation_identifier_pattern: str = r"^[a-zA-Z0-9_\-\.]+$" # No spaces for IDs + validation_safe_uri_pattern: str = r"^[a-zA-Z0-9_\-.:/?=&%]+$" + validation_unsafe_uri_pattern: str = r'[<>"\'\\]' + validation_tool_name_pattern: str = r"^[a-zA-Z][a-zA-Z0-9._-]*$" # MCP tool naming + validation_tool_method_pattern: str = r"^[a-zA-Z][a-zA-Z0-9_\./-]*$" + + # MCP-compliant size limits (configurable via env) + validation_max_name_length: int = 255 + validation_max_description_length: int = 8192 # 8KB + validation_max_template_length: int = 65536 # 64KB + validation_max_content_length: int = 1048576 # 1MB + validation_max_json_depth: int = 10 + validation_max_url_length: int = 2048 + validation_max_rpc_param_size: int = 262144 # 256KB + + validation_max_method_length: int = 128 + + # Allowed MIME types + validation_allowed_mime_types: list[str] = [ + "text/plain", + "text/html", + "text/css", + "text/markdown", + "text/javascript", + "application/json", + "application/xml", + "application/pdf", + "image/png", + "image/jpeg", + "image/gif", + "image/svg+xml", + "application/octet-stream", + ] + + # Rate limiting + validation_max_requests_per_minute: int = 60 + + # CLI settings + plugins_cli_markup_mode: str | None = None + plugins_cli_completion: bool = True + + +@lru_cache() +def get_settings() -> Settings: + """Get cached settings instance. + + Returns: + Settings: A cached instance of the Settings class. + + Examples: + >>> settings = get_settings() + >>> isinstance(settings, Settings) + True + >>> # Second call returns the same cached instance + >>> settings2 = get_settings() + >>> settings is settings2 + True + """ + # Instantiate a fresh Pydantic Settings object, + # loading from env vars or .env exactly once. + cfg = Settings() + # Validate that transport_type is correct; will + # raise if mis-configured. + # cfg.validate_transport() + # Ensure sqlite DB directories exist if needed. + # cfg.validate_database() + # Return the one-and-only Settings instance (cached). + return cfg + + +# Create settings instance +settings = get_settings() diff --git a/mcpgateway/models.py b/mcpgateway/common/models.py similarity index 98% rename from mcpgateway/models.py rename to mcpgateway/common/models.py index b868d9d18..f8704e917 100644 --- a/mcpgateway/models.py +++ b/mcpgateway/common/models.py @@ -16,7 +16,7 @@ - Capability definitions Examples: - >>> from mcpgateway.models import Role, LogLevel, TextContent + >>> from mcpgateway.common.models import Role, LogLevel, TextContent >>> Role.USER.value 'user' >>> Role.ASSISTANT.value @@ -1360,3 +1360,20 @@ class PermissionAudit(BaseModel): # Permission constants are imported from db.py to avoid duplication # Use Permissions class from mcpgateway.db instead of duplicate SystemPermissions + + +class TransportType(str, Enum): + """ + Enumeration of supported transport mechanisms for communication between components. + + Attributes: + SSE (str): Server-Sent Events transport. + HTTP (str): Standard HTTP-based transport. + STDIO (str): Standard input/output transport. + STREAMABLEHTTP (str): HTTP transport with streaming. + """ + + SSE = "SSE" + HTTP = "HTTP" + STDIO = "STDIO" + STREAMABLEHTTP = "STREAMABLEHTTP" diff --git a/mcpgateway/common/validators.py b/mcpgateway/common/validators.py new file mode 100644 index 000000000..4e8f2fa11 --- /dev/null +++ b/mcpgateway/common/validators.py @@ -0,0 +1,1190 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/common/validators.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti, Madhav Kandukuri + +SecurityValidator for MCP Gateway +This module defines the `SecurityValidator` class, which provides centralized, configurable +validation logic for user-generated content in MCP-based applications. + +The validator enforces strict security and structural rules across common input types such as: +- Display text (e.g., names, descriptions) +- Identifiers and tool names +- URIs and URLs +- JSON object depth +- Templates (including limited HTML/Jinja2) +- MIME types + +Key Features: +- Pattern-based validation using settings-defined regex for HTML/script safety +- Configurable max lengths and depth limits +- Whitelist-based URL scheme and MIME type validation +- Safe escaping of user-visible text fields +- Reusable static/class methods for field-level and form-level validation + +Intended to be used with Pydantic or similar schema-driven systems to validate and sanitize +user input in a consistent, centralized way. + +Dependencies: +- Standard Library: re, html, logging, urllib.parse +- First-party: `settings` from `mcpgateway.config` + +Example usage: + SecurityValidator.validate_name("my_tool", field_name="Tool Name") + SecurityValidator.validate_url("https://example.com") + SecurityValidator.validate_json_depth({...}) + +Examples: + >>> from mcpgateway.common.validators import SecurityValidator + >>> SecurityValidator.sanitize_display_text('Test', 'test') + '<b>Test</b>' + >>> SecurityValidator.validate_name('valid_name-123', 'test') + 'valid_name-123' + >>> SecurityValidator.validate_identifier('my.test.id_123', 'test') + 'my.test.id_123' + >>> SecurityValidator.validate_json_depth({'a': {'b': 1}}) + >>> SecurityValidator.validate_json_depth({'a': 1}) +""" + +# Standard +import html +import logging +import re +from urllib.parse import urlparse +import uuid + +# First-Party +from mcpgateway.common.config import settings + +logger = logging.getLogger(__name__) + + +class SecurityValidator: + """Configurable validation with MCP-compliant limits""" + + # Configurable patterns (from settings) + DANGEROUS_HTML_PATTERN = ( + settings.validation_dangerous_html_pattern + ) # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|' + DANGEROUS_JS_PATTERN = settings.validation_dangerous_js_pattern # Default: javascript:|vbscript:|on\w+\s*=|data:.*script + ALLOWED_URL_SCHEMES = settings.validation_allowed_url_schemes # Default: ["http://", "https://", "ws://", "wss://"] + + # Character type patterns + NAME_PATTERN = settings.validation_name_pattern # Default: ^[a-zA-Z0-9_\-\s]+$ + IDENTIFIER_PATTERN = settings.validation_identifier_pattern # Default: ^[a-zA-Z0-9_\-\.]+$ + VALIDATION_SAFE_URI_PATTERN = settings.validation_safe_uri_pattern # Default: ^[a-zA-Z0-9_\-.:/?=&%]+$ + VALIDATION_UNSAFE_URI_PATTERN = settings.validation_unsafe_uri_pattern # Default: [<>"\'\\] + TOOL_NAME_PATTERN = settings.validation_tool_name_pattern # Default: ^[a-zA-Z][a-zA-Z0-9_-]*$ + + # MCP-compliant limits (configurable) + MAX_NAME_LENGTH = settings.validation_max_name_length # Default: 255 + MAX_DESCRIPTION_LENGTH = settings.validation_max_description_length # Default: 8192 (8KB) + MAX_TEMPLATE_LENGTH = settings.validation_max_template_length # Default: 65536 + MAX_CONTENT_LENGTH = settings.validation_max_content_length # Default: 1048576 (1MB) + MAX_JSON_DEPTH = settings.validation_max_json_depth # Default: 10 + MAX_URL_LENGTH = settings.validation_max_url_length # Default: 2048 + + @classmethod + def sanitize_display_text(cls, value: str, field_name: str) -> str: + """Ensure text is safe for display in UI by escaping special characters + + Args: + value (str): Value to validate + field_name (str): Name of field being validated + + Returns: + str: Value if acceptable + + Raises: + ValueError: When input is not acceptable + + Examples: + Basic HTML escaping: + + >>> SecurityValidator.sanitize_display_text('Hello World', 'test') + 'Hello World' + >>> SecurityValidator.sanitize_display_text('Hello World', 'test') + 'Hello <b>World</b>' + + Empty/None handling: + + >>> SecurityValidator.sanitize_display_text('', 'test') + '' + >>> SecurityValidator.sanitize_display_text(None, 'test') #doctest: +SKIP + + Dangerous script patterns: + + >>> SecurityValidator.sanitize_display_text('alert();', 'test') + 'alert();' + >>> SecurityValidator.sanitize_display_text('javascript:alert(1)', 'test') + Traceback (most recent call last): + ... + ValueError: test contains script patterns that may cause display issues + + Polyglot attack patterns: + + >>> SecurityValidator.sanitize_display_text('"; alert()', 'test') + Traceback (most recent call last): + ... + ValueError: test contains potentially dangerous character sequences + >>> SecurityValidator.sanitize_display_text('-->test', 'test') + '-->test' + >>> SecurityValidator.sanitize_display_text('-->') + Traceback (most recent call last): + ... + ValueError: Template contains HTML tags that may interfere with proper display + >>> SecurityValidator.validate_template('Test ') + Traceback (most recent call last): + ... + ValueError: Template contains HTML tags that may interfere with proper display + >>> SecurityValidator.validate_template('
') + Traceback (most recent call last): + ... + ValueError: Template contains HTML tags that may interfere with proper display + + Event handlers blocked: + + >>> SecurityValidator.validate_template('
Test
') + Traceback (most recent call last): + ... + ValueError: Template contains event handlers that may cause display issues + >>> SecurityValidator.validate_template('onload = "alert(1)"') + Traceback (most recent call last): + ... + ValueError: Template contains event handlers that may cause display issues + + SSTI prevention patterns: + + >>> SecurityValidator.validate_template('{{ __import__ }}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('{{ config }}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('{% import os %}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('{{ 7*7 }}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('{{ 10/2 }}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('{{ 5+5 }}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('{{ 10-5 }}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + + Other template injection patterns: + + >>> SecurityValidator.validate_template('${evil}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('#{evil}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('%{evil}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + + Length limit testing: + + >>> long_template = 'a' * 65537 + >>> SecurityValidator.validate_template(long_template) + Traceback (most recent call last): + ... + ValueError: Template exceeds maximum length of 65536 + """ + if not value: + return value + + if len(value) > cls.MAX_TEMPLATE_LENGTH: + raise ValueError(f"Template exceeds maximum length of {cls.MAX_TEMPLATE_LENGTH}") + + # Block dangerous tags but allow Jinja2 syntax {{ }} and {% %} + dangerous_tags = r"<(script|iframe|object|embed|link|meta|base|form)\b" + if re.search(dangerous_tags, value, re.IGNORECASE): + raise ValueError("Template contains HTML tags that may interfere with proper display") + + # Check for event handlers that could cause issues + if re.search(r"on\w+\s*=", value, re.IGNORECASE): + raise ValueError("Template contains event handlers that may cause display issues") + + # SSTI Prevention - block dangerous template expressions + ssti_patterns = [ + r"\{\{.*(__|\.|config|self|request|application|globals|builtins|import).*\}\}", # Jinja2 dangerous patterns + r"\{%.*(__|\.|config|self|request|application|globals|builtins|import).*%\}", # Jinja2 tags + r"\$\{.*\}", # ${} expressions + r"#\{.*\}", # #{} expressions + r"%\{.*\}", # %{} expressions + r"\{\{.*\*.*\}\}", # Math operations in templates (like {{7*7}}) + r"\{\{.*\/.*\}\}", # Division operations + r"\{\{.*\+.*\}\}", # Addition operations + r"\{\{.*\-.*\}\}", # Subtraction operations + ] + + for pattern in ssti_patterns: + if re.search(pattern, value, re.IGNORECASE): + raise ValueError("Template contains potentially dangerous expressions") + + return value + + @classmethod + def validate_url(cls, value: str, field_name: str = "URL") -> str: + """Validate URLs for allowed schemes and safe display + + Args: + value (str): Value to validate + field_name (str): Name of field being validated + + Returns: + str: Value if acceptable + + Raises: + ValueError: When input is not acceptable + + Examples: + Valid URLs: + + >>> SecurityValidator.validate_url('https://example.com') + 'https://example.com' + >>> SecurityValidator.validate_url('http://example.com') + 'http://example.com' + >>> SecurityValidator.validate_url('ws://example.com') + 'ws://example.com' + >>> SecurityValidator.validate_url('wss://example.com') + 'wss://example.com' + >>> SecurityValidator.validate_url('https://example.com:8080/path') + 'https://example.com:8080/path' + >>> SecurityValidator.validate_url('https://example.com/path?query=value') + 'https://example.com/path?query=value' + + Empty URL handling: + + >>> SecurityValidator.validate_url('') + Traceback (most recent call last): + ... + ValueError: URL cannot be empty + + Length validation: + + >>> long_url = 'https://example.com/' + 'a' * 2100 + >>> SecurityValidator.validate_url(long_url) + Traceback (most recent call last): + ... + ValueError: URL exceeds maximum length of 2048 + + Scheme validation: + + >>> SecurityValidator.validate_url('ftp://example.com') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + >>> SecurityValidator.validate_url('file:///etc/passwd') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + >>> SecurityValidator.validate_url('javascript:alert(1)') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + >>> SecurityValidator.validate_url('data:text/plain,hello') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + >>> SecurityValidator.validate_url('vbscript:alert(1)') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + >>> SecurityValidator.validate_url('about:blank') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + >>> SecurityValidator.validate_url('chrome://settings') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + >>> SecurityValidator.validate_url('mailto:test@example.com') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + + IPv6 URL blocking: + + >>> SecurityValidator.validate_url('https://[::1]:8080/') + Traceback (most recent call last): + ... + ValueError: URL contains IPv6 address which is not supported + >>> SecurityValidator.validate_url('https://[2001:db8::1]/') + Traceback (most recent call last): + ... + ValueError: URL contains IPv6 address which is not supported + + Protocol-relative URL blocking: + + >>> SecurityValidator.validate_url('//example.com/path') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + + Line break injection: + + >>> SecurityValidator.validate_url('https://example.com\\rHost: evil.com') + Traceback (most recent call last): + ... + ValueError: URL contains line breaks which are not allowed + >>> SecurityValidator.validate_url('https://example.com\\nHost: evil.com') + Traceback (most recent call last): + ... + ValueError: URL contains line breaks which are not allowed + + Space validation: + + >>> SecurityValidator.validate_url('https://exam ple.com') + Traceback (most recent call last): + ... + ValueError: URL contains spaces which are not allowed in URLs + >>> SecurityValidator.validate_url('https://example.com/path?query=hello world') + 'https://example.com/path?query=hello world' + + Malformed URLs: + + >>> SecurityValidator.validate_url('https://') + Traceback (most recent call last): + ... + ValueError: URL is not a valid URL + >>> SecurityValidator.validate_url('not-a-url') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + + Restricted IP addresses: + + >>> SecurityValidator.validate_url('https://0.0.0.0/') + Traceback (most recent call last): + ... + ValueError: URL contains invalid IP address (0.0.0.0) + >>> SecurityValidator.validate_url('https://169.254.169.254/') + Traceback (most recent call last): + ... + ValueError: URL contains restricted IP address + + Invalid port numbers: + + >>> SecurityValidator.validate_url('https://example.com:0/') + Traceback (most recent call last): + ... + ValueError: URL contains invalid port number + >>> try: + ... SecurityValidator.validate_url('https://example.com:65536/') + ... except ValueError as e: + ... 'Port out of range' in str(e) or 'invalid port' in str(e) + True + + Credentials in URL: + + >>> SecurityValidator.validate_url('https://user:pass@example.com/') + Traceback (most recent call last): + ... + ValueError: URL contains credentials which are not allowed + >>> SecurityValidator.validate_url('https://user@example.com/') + Traceback (most recent call last): + ... + ValueError: URL contains credentials which are not allowed + + XSS patterns in URLs: + + >>> SecurityValidator.validate_url('https://example.com/', 'test_field') + Traceback (most recent call last): + ... + ValueError: test_field contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'content') + Traceback (most recent call last): + ... + ValueError: content contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'data') + Traceback (most recent call last): + ... + ValueError: data contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'embed') + Traceback (most recent call last): + ... + ValueError: embed contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'style') + Traceback (most recent call last): + ... + ValueError: style contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'meta') + Traceback (most recent call last): + ... + ValueError: meta contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'base') + Traceback (most recent call last): + ... + ValueError: base contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('
', 'form') + Traceback (most recent call last): + ... + ValueError: form contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'image') + Traceback (most recent call last): + ... + ValueError: image contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'svg') + Traceback (most recent call last): + ... + ValueError: svg contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'video') + Traceback (most recent call last): + ... + ValueError: video contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'audio') + Traceback (most recent call last): + ... + ValueError: audio contains HTML tags that may cause security issues + """ + if not value: + return # Empty values are considered safe + # Check for dangerous HTML tags + if re.search(cls.DANGEROUS_HTML_PATTERN, value, re.IGNORECASE): + raise ValueError(f"{field_name} contains HTML tags that may cause security issues") + + @classmethod + def validate_json_depth( + cls, + obj: object, + max_depth: int | None = None, + current_depth: int = 0, + ) -> None: + """Validate that a JSON‑like structure does not exceed a depth limit. + + A *depth* is counted **only** when we enter a container (`dict` or + `list`). Primitive values (`str`, `int`, `bool`, `None`, etc.) do not + increase the depth, but an *empty* container still counts as one level. + + Args: + obj: Any Python object to inspect recursively. + max_depth: Maximum allowed depth (defaults to + :pyattr:`SecurityValidator.MAX_JSON_DEPTH`). + current_depth: Internal recursion counter. **Do not** set this + from user code. + + Raises: + ValueError: If the nesting level exceeds *max_depth*. + + Examples: + Simple flat dictionary – depth 1: :: + + >>> SecurityValidator.validate_json_depth({'name': 'Alice'}) + + Nested dict – depth 2: :: + + >>> SecurityValidator.validate_json_depth( + ... {'user': {'name': 'Alice'}} + ... ) + + Mixed dict/list – depth 3: :: + + >>> SecurityValidator.validate_json_depth( + ... {'users': [{'name': 'Alice', 'meta': {'age': 30}}]} + ... ) + + Exactly at the default limit (10) – allowed: :: + + >>> deep_10 = {'1': {'2': {'3': {'4': {'5': {'6': {'7': {'8': + ... {'9': {'10': 'end'}}}}}}}}}} + >>> SecurityValidator.validate_json_depth(deep_10) + + One level deeper – rejected: :: + + >>> deep_11 = {'1': {'2': {'3': {'4': {'5': {'6': {'7': {'8': + ... {'9': {'10': {'11': 'end'}}}}}}}}}}} + >>> SecurityValidator.validate_json_depth(deep_11) + Traceback (most recent call last): + ... + ValueError: JSON structure exceeds maximum depth of 10 + """ + if max_depth is None: + max_depth = cls.MAX_JSON_DEPTH + + # Only containers count toward depth; primitives are ignored + if not isinstance(obj, (dict, list)): + return + + next_depth = current_depth + 1 + if next_depth > max_depth: + raise ValueError(f"JSON structure exceeds maximum depth of {max_depth}") + + if isinstance(obj, dict): + for value in obj.values(): + cls.validate_json_depth(value, max_depth, next_depth) + else: # obj is a list + for item in obj: + cls.validate_json_depth(item, max_depth, next_depth) + + @classmethod + def validate_mime_type(cls, value: str) -> str: + """Validate MIME type format + + Args: + value (str): Value to validate + + Returns: + str: Value if acceptable + + Raises: + ValueError: When input is not acceptable + + Examples: + Empty/None handling: + + >>> SecurityValidator.validate_mime_type('') + '' + >>> SecurityValidator.validate_mime_type(None) #doctest: +SKIP + + Valid standard MIME types: + + >>> SecurityValidator.validate_mime_type('text/plain') + 'text/plain' + >>> SecurityValidator.validate_mime_type('application/json') + 'application/json' + >>> SecurityValidator.validate_mime_type('image/jpeg') + 'image/jpeg' + >>> SecurityValidator.validate_mime_type('text/html') + 'text/html' + >>> SecurityValidator.validate_mime_type('application/pdf') + 'application/pdf' + + Valid vendor-specific MIME types: + + >>> SecurityValidator.validate_mime_type('application/x-custom') + 'application/x-custom' + >>> SecurityValidator.validate_mime_type('text/x-log') + 'text/x-log' + + Valid MIME types with suffixes: + + >>> SecurityValidator.validate_mime_type('application/vnd.api+json') + 'application/vnd.api+json' + >>> SecurityValidator.validate_mime_type('image/svg+xml') + 'image/svg+xml' + + Invalid MIME type formats: + + >>> SecurityValidator.validate_mime_type('invalid') + Traceback (most recent call last): + ... + ValueError: Invalid MIME type format + >>> SecurityValidator.validate_mime_type('text/') + Traceback (most recent call last): + ... + ValueError: Invalid MIME type format + >>> SecurityValidator.validate_mime_type('/plain') + Traceback (most recent call last): + ... + ValueError: Invalid MIME type format + >>> SecurityValidator.validate_mime_type('text//plain') + Traceback (most recent call last): + ... + ValueError: Invalid MIME type format + >>> SecurityValidator.validate_mime_type('text/plain/extra') + Traceback (most recent call last): + ... + ValueError: Invalid MIME type format + >>> SecurityValidator.validate_mime_type('text plain') + Traceback (most recent call last): + ... + ValueError: Invalid MIME type format + >>> SecurityValidator.validate_mime_type('') + Traceback (most recent call last): + ... + ValueError: Invalid MIME type format + + Disallowed MIME types (not in whitelist - line 620): + + >>> try: + ... SecurityValidator.validate_mime_type('application/evil') + ... except ValueError as e: + ... 'not in the allowed list' in str(e) + True + >>> try: + ... SecurityValidator.validate_mime_type('text/evil') + ... except ValueError as e: + ... 'not in the allowed list' in str(e) + True + + Test MIME type with parameters (line 618): + + >>> try: + ... SecurityValidator.validate_mime_type('application/evil; charset=utf-8') + ... except ValueError as e: + ... 'Invalid MIME type format' in str(e) + True + """ + if not value: + return value + + # Basic MIME type pattern + mime_pattern = r"^[a-zA-Z0-9][a-zA-Z0-9!#$&\-\^_+\.]*\/[a-zA-Z0-9][a-zA-Z0-9!#$&\-\^_+\.]*$" + if not re.match(mime_pattern, value): + raise ValueError("Invalid MIME type format") + + # Common safe MIME types + safe_mime_types = settings.validation_allowed_mime_types + if value not in safe_mime_types: + # Allow x- vendor types and + suffixes + base_type = value.split(";")[0].strip() + if not (base_type.startswith("application/x-") or base_type.startswith("text/x-") or "+" in base_type): + raise ValueError(f"MIME type '{value}' is not in the allowed list") + + return value diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 7c82266ef..9856358d4 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -935,7 +935,7 @@ def parse_issuers(cls, v: Any) -> set[str]: # Plugin CLI settings plugins_cli_completion: bool = Field(default=False, description="Enable auto-completion for plugins CLI") - plugins_cli_markup_mode: str | None = Field(default=None, description="Set markup mode for plugins CLI") + plugins_cli_markup_mode: Literal["markdown", "rich", "disabled"] | None = Field(default=None, description="Set markup mode for plugins CLI") # Development dev_mode: bool = False diff --git a/mcpgateway/db.py b/mcpgateway/db.py index 46dfb5251..a6450a394 100644 --- a/mcpgateway/db.py +++ b/mcpgateway/db.py @@ -38,16 +38,16 @@ from sqlalchemy.pool import QueuePool # First-Party +from mcpgateway.common.validators import SecurityValidator from mcpgateway.config import settings from mcpgateway.utils.create_slug import slugify from mcpgateway.utils.db_isready import wait_for_db_ready -from mcpgateway.validators import SecurityValidator logger = logging.getLogger(__name__) if TYPE_CHECKING: # First-Party - from mcpgateway.models import ResourceContent + from mcpgateway.common.models import ResourceContent # ResourceContent will be imported locally where needed to avoid circular imports # EmailUser models moved to this file to avoid circular imports @@ -2239,7 +2239,7 @@ def content(self) -> "ResourceContent": # Local import to avoid circular import # First-Party - from mcpgateway.models import ResourceContent # pylint: disable=import-outside-toplevel + from mcpgateway.common.models import ResourceContent # pylint: disable=import-outside-toplevel if self.text_content is not None: return ResourceContent( diff --git a/mcpgateway/federation/discovery.py b/mcpgateway/federation/discovery.py index e8d5409e0..c5d2890f7 100644 --- a/mcpgateway/federation/discovery.py +++ b/mcpgateway/federation/discovery.py @@ -78,8 +78,8 @@ # First-Party from mcpgateway import __version__ +from mcpgateway.common.models import ServerCapabilities from mcpgateway.config import settings -from mcpgateway.models import ServerCapabilities from mcpgateway.services.logging_service import LoggingService # Initialize logging service first diff --git a/mcpgateway/federation/forward.py b/mcpgateway/federation/forward.py index 4609cf311..cd3b106e4 100644 --- a/mcpgateway/federation/forward.py +++ b/mcpgateway/federation/forward.py @@ -36,11 +36,11 @@ from sqlalchemy.orm import Session # First-Party +from mcpgateway.common.models import ToolResult from mcpgateway.config import settings from mcpgateway.db import Gateway as DbGateway from mcpgateway.db import ServerMetric from mcpgateway.db import Tool as DbTool -from mcpgateway.models import ToolResult from mcpgateway.services.logging_service import LoggingService from mcpgateway.utils.passthrough_headers import get_passthrough_headers diff --git a/mcpgateway/handlers/sampling.py b/mcpgateway/handlers/sampling.py index ca0971d59..a30dce434 100644 --- a/mcpgateway/handlers/sampling.py +++ b/mcpgateway/handlers/sampling.py @@ -10,7 +10,7 @@ Examples: >>> import asyncio - >>> from mcpgateway.models import ModelPreferences + >>> from mcpgateway.common.models import ModelPreferences >>> handler = SamplingHandler() >>> asyncio.run(handler.initialize()) >>> @@ -48,7 +48,7 @@ from sqlalchemy.orm import Session # First-Party -from mcpgateway.models import CreateMessageResult, ModelPreferences, Role, TextContent +from mcpgateway.common.models import CreateMessageResult, ModelPreferences, Role, TextContent from mcpgateway.services.logging_service import LoggingService # Initialize logging service first @@ -247,7 +247,7 @@ def _select_model(self, preferences: ModelPreferences) -> str: SamplingError: If no suitable model found Examples: - >>> from mcpgateway.models import ModelPreferences, ModelHint + >>> from mcpgateway.common.models import ModelPreferences, ModelHint >>> handler = SamplingHandler() >>> >>> # Test intelligence priority diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 14b249568..559b6ad0a 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -63,6 +63,9 @@ from mcpgateway.auth import get_current_user from mcpgateway.bootstrap_db import main as bootstrap_db from mcpgateway.cache import ResourceCache, SessionRegistry +from mcpgateway.common.models import InitializeResult +from mcpgateway.common.models import JSONRPCError as PydanticJSONRPCError +from mcpgateway.common.models import ListResourceTemplatesResult, LogLevel, Root from mcpgateway.config import settings from mcpgateway.db import refresh_slugs_on_startup, SessionLocal from mcpgateway.db import Tool as DbTool @@ -72,9 +75,6 @@ from mcpgateway.middleware.request_logging_middleware import RequestLoggingMiddleware from mcpgateway.middleware.security_headers import SecurityHeadersMiddleware from mcpgateway.middleware.token_scoping import token_scoping_middleware -from mcpgateway.models import InitializeResult -from mcpgateway.models import JSONRPCError as PydanticJSONRPCError -from mcpgateway.models import ListResourceTemplatesResult, LogLevel, Root from mcpgateway.observability import init_telemetry from mcpgateway.plugins.framework import PluginError, PluginManager, PluginViolationError from mcpgateway.routers.well_known import router as well_known_router @@ -1840,7 +1840,7 @@ async def message_endpoint(request: Request, server_id: str, user=Depends(get_cu if request_id: # Try to complete the elicitation # First-Party - from mcpgateway.models import ElicitResult # pylint: disable=import-outside-toplevel + from mcpgateway.common.models import ElicitResult # pylint: disable=import-outside-toplevel from mcpgateway.services.elicitation_service import get_elicitation_service # pylint: disable=import-outside-toplevel elicitation_service = get_elicitation_service() @@ -2809,8 +2809,7 @@ async def read_resource(resource_id: str, request: Request, db: Session = Depend # Ensure a plain JSON-serializable structure try: # First-Party - # pylint: disable=import-outside-toplevel - from mcpgateway.models import ResourceContent, TextContent + from mcpgateway.common.models import ResourceContent, TextContent # pylint: disable=import-outside-toplevel # If already a ResourceContent, serialize directly if isinstance(content, ResourceContent): @@ -3824,7 +3823,7 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen # Validate params # First-Party - from mcpgateway.models import ElicitRequestParams # pylint: disable=import-outside-toplevel + from mcpgateway.common.models import ElicitRequestParams # pylint: disable=import-outside-toplevel from mcpgateway.services.elicitation_service import get_elicitation_service # pylint: disable=import-outside-toplevel try: diff --git a/mcpgateway/plugins/framework/__init__.py b/mcpgateway/plugins/framework/__init__.py index db61745c7..7783d788a 100644 --- a/mcpgateway/plugins/framework/__init__.py +++ b/mcpgateway/plugins/framework/__init__.py @@ -17,43 +17,47 @@ from mcpgateway.plugins.framework.base import Plugin from mcpgateway.plugins.framework.errors import PluginError, PluginViolationError from mcpgateway.plugins.framework.external.mcp.server import ExternalPluginServer +from mcpgateway.plugins.framework.hooks.registry import HookRegistry, get_hook_registry from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader from mcpgateway.plugins.framework.manager import PluginManager +from mcpgateway.plugins.framework.hooks.http import HttpHeaderPayload +from mcpgateway.plugins.framework.hooks.agents import AgentHookType, AgentPostInvokePayload, AgentPostInvokeResult, AgentPreInvokePayload, AgentPreInvokeResult +from mcpgateway.plugins.framework.hooks.resources import ResourceHookType, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, ResourcePreFetchResult +from mcpgateway.plugins.framework.hooks.prompts import ( + PromptHookType, + PromptPosthookPayload, + PromptPosthookResult, + PromptPrehookPayload, + PromptPrehookResult, +) +from mcpgateway.plugins.framework.hooks.tools import ToolHookType, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokeResult, ToolPreInvokePayload from mcpgateway.plugins.framework.models import ( GlobalContext, - HttpHeaderPayload, - HttpHeaderPayloadResult, - HookType, + MCPServerConfig, PluginCondition, PluginConfig, PluginContext, PluginErrorModel, PluginMode, + PluginPayload, PluginResult, PluginViolation, - PromptPosthookPayload, - PromptPosthookResult, - PromptPrehookPayload, - PromptPrehookResult, - PromptResult, - ResourcePostFetchPayload, - ResourcePostFetchResult, - ResourcePreFetchPayload, - ResourcePreFetchResult, - ToolPostInvokePayload, - ToolPostInvokeResult, - ToolPreInvokePayload, - ToolPreInvokeResult, ) __all__ = [ + "AgentHookType", + "AgentPostInvokePayload", + "AgentPostInvokeResult", + "AgentPreInvokePayload", + "AgentPreInvokeResult", "ConfigLoader", "ExternalPluginServer", "GlobalContext", - "HookType", + "HookRegistry", "HttpHeaderPayload", - "HttpHeaderPayloadResult", + "get_hook_registry", + "MCPServerConfig", "Plugin", "PluginCondition", "PluginConfig", @@ -63,20 +67,23 @@ "PluginLoader", "PluginManager", "PluginMode", + "PluginPayload", "PluginResult", "PluginViolation", "PluginViolationError", + "PromptHookType", "PromptPosthookPayload", "PromptPosthookResult", "PromptPrehookPayload", "PromptPrehookResult", - "PromptResult", + "ResourceHookType", "ResourcePostFetchPayload", "ResourcePostFetchResult", "ResourcePreFetchPayload", "ResourcePreFetchResult", + "ToolHookType", "ToolPostInvokePayload", "ToolPostInvokeResult", - "ToolPreInvokePayload", "ToolPreInvokeResult", + "ToolPreInvokePayload", ] diff --git a/mcpgateway/plugins/framework/base.py b/mcpgateway/plugins/framework/base.py index 28bd25481..1d3e221b9 100644 --- a/mcpgateway/plugins/framework/base.py +++ b/mcpgateway/plugins/framework/base.py @@ -2,57 +2,45 @@ """Location: ./mcpgateway/plugins/framework/base.py Copyright 2025 SPDX-License-Identifier: Apache-2.0 -Authors: Teryl Taylor, Mihai Criveti +-Authors: Teryl Taylor, Mihai Criveti Base plugin implementation. This module implements the base plugin object. -It supports pre and post hooks AI safety, security and business processing -for the following locations in the server: -server_pre_register / server_post_register - for virtual server verification -tool_pre_invoke / tool_post_invoke - for guardrails -prompt_pre_fetch / prompt_post_fetch - for prompt filtering -resource_pre_fetch / resource_post_fetch - for content filtering -auth_pre_check / auth_post_check - for custom auth logic -federation_pre_sync / federation_post_sync - for gateway federation """ # Standard +from abc import ABC +from typing import Awaitable, Callable, Optional, Union import uuid # First-Party +from mcpgateway.plugins.framework.errors import PluginError from mcpgateway.plugins.framework.models import ( - HookType, PluginCondition, PluginConfig, PluginContext, + PluginErrorModel, PluginMode, - PromptPosthookPayload, - PromptPosthookResult, - PromptPrehookPayload, - PromptPrehookResult, - ResourcePostFetchPayload, - ResourcePostFetchResult, - ResourcePreFetchPayload, - ResourcePreFetchResult, - ToolPostInvokePayload, - ToolPostInvokeResult, - ToolPreInvokePayload, - ToolPreInvokeResult, + PluginPayload, + PluginResult, ) +# pylint: disable=import-outside-toplevel -class Plugin: + +class Plugin(ABC): """Base plugin object for pre/post processing of inputs and outputs at various locations throughout the server. Examples: - >>> from mcpgateway.plugins.framework import PluginConfig, HookType, PluginMode + >>> from mcpgateway.plugins.framework import PluginConfig, PluginMode + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptHookType >>> config = PluginConfig( ... name="test_plugin", ... description="Test plugin", ... author="test", ... kind="mcpgateway.plugins.framework.Plugin", ... version="1.0.0", - ... hooks=[HookType.PROMPT_PRE_FETCH], + ... hooks=[PromptHookType.PROMPT_PRE_FETCH], ... tags=["test"], ... mode=PluginMode.ENFORCE, ... priority=50 @@ -64,25 +52,35 @@ class Plugin: 50 >>> plugin.mode - >>> HookType.PROMPT_PRE_FETCH in plugin.hooks + >>> PromptHookType.PROMPT_PRE_FETCH in plugin.hooks True """ - def __init__(self, config: PluginConfig) -> None: + def __init__( + self, + config: PluginConfig, + hook_payloads: Optional[dict[str, PluginPayload]] = None, + hook_results: Optional[dict[str, PluginResult]] = None, + ) -> None: """Initialize a plugin with a configuration and context. Args: config: The plugin configuration + hook_payloads: optional mapping of hookpoints to payloads for the plugin. + Used for external plugins for converting json to pydantic. + hook_results: optional mapping of hookpoints to result types for the plugin. + Used for external plugins for converting json to pydantic. Examples: - >>> from mcpgateway.plugins.framework import PluginConfig, HookType + >>> from mcpgateway.plugins.framework import PluginConfig + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptHookType >>> config = PluginConfig( ... name="simple_plugin", ... description="Simple test", ... author="test", ... kind="test.Plugin", ... version="1.0.0", - ... hooks=[HookType.PROMPT_POST_FETCH], + ... hooks=[PromptHookType.PROMPT_POST_FETCH], ... tags=["simple"] ... ) >>> plugin = Plugin(config) @@ -90,6 +88,8 @@ def __init__(self, config: PluginConfig) -> None: 'simple_plugin' """ self._config = config + self._hook_payloads = hook_payloads + self._hook_results = hook_results @property def priority(self) -> int: @@ -128,7 +128,7 @@ def name(self) -> str: return self._config.name @property - def hooks(self) -> list[HookType]: + def hooks(self) -> list[str]: """Return the plugin's currently configured hooks. Returns: @@ -157,118 +157,93 @@ def conditions(self) -> list[PluginCondition] | None: async def initialize(self) -> None: """Initialize the plugin.""" - async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: - """Plugin hook run before a prompt is retrieved and rendered. - - Args: - payload: The prompt payload to be analyzed. - context: contextual information about the hook call. Including why it was called. - - Raises: - NotImplementedError: needs to be implemented by sub class. - """ - raise NotImplementedError( - f"""'prompt_pre_fetch' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) + async def shutdown(self) -> None: + """Plugin cleanup code.""" - async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: - """Plugin hook run after a prompt is rendered. + def json_to_payload(self, hook: str, payload: Union[str | dict]) -> PluginPayload: + """Converts a json payload to the proper pydantic payload object given a hook type. Used + mainly for serialization/deserialization of external plugin payloads. Args: - payload: The prompt payload to be analyzed. - context: Contextual information about the hook call. + hook: the hook type for which the payload needs converting. + payload: the payload as a string or dict. + + Returns: + A pydantic payload object corresponding to the hook type. Raises: - NotImplementedError: needs to be implemented by sub class. + PluginError: if no payload type is defined. """ - raise NotImplementedError( - f"""'prompt_post_fetch' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) + hook_payload_type: type[PluginPayload] | None = None - async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: - """Plugin hook run before a tool is invoked. - - Args: - payload: The tool payload to be analyzed. - context: Contextual information about the hook call. + # First try instance-level hook_payloads + if self._hook_payloads: + hook_payload_type = self._hook_payloads.get(hook, None) # type: ignore[assignment] - Raises: - NotImplementedError: needs to be implemented by sub class. - """ - raise NotImplementedError( - f"""'tool_pre_invoke' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) + # Fall back to global registry + if not hook_payload_type: + # First-Party + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry - async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: - """Plugin hook run after a tool is invoked. + registry = get_hook_registry() + hook_payload_type = registry.get_payload_type(hook) - Args: - payload: The tool result payload to be analyzed. - context: Contextual information about the hook call. + if not hook_payload_type: + raise PluginError(error=PluginErrorModel(message=f"No payload defined for hook {hook}.", plugin_name=self.name)) - Raises: - NotImplementedError: needs to be implemented by sub class. - """ - raise NotImplementedError( - f"""'tool_post_invoke' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) + if isinstance(payload, str): + return hook_payload_type.model_validate_json(payload) + return hook_payload_type.model_validate(payload) - async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, context: PluginContext) -> ResourcePreFetchResult: - """Plugin hook run before a resource is fetched. + def json_to_result(self, hook: str, result: Union[str | dict]) -> PluginResult: + """Converts a json result to the proper pydantic result object given a hook type. Used + mainly for serialization/deserialization of external plugin results. Args: - payload: The resource payload to be analyzed. - context: Contextual information about the hook call. + hook: the hook type for which the result needs converting. + result: the result as a string or dict. + + Returns: + A pydantic result object corresponding to the hook type. Raises: - NotImplementedError: needs to be implemented by sub class. + PluginError: if no result type is defined. """ - raise NotImplementedError( - f"""'resource_pre_fetch' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) + hook_result_type: type[PluginResult] | None = None - async def resource_post_fetch(self, payload: ResourcePostFetchPayload, context: PluginContext) -> ResourcePostFetchResult: - """Plugin hook run after a resource is fetched. + # First try instance-level hook_results + if self._hook_results: + hook_result_type = self._hook_results.get(hook, None) # type: ignore[assignment] - Args: - payload: The resource content payload to be analyzed. - context: Contextual information about the hook call. + # Fall back to global registry + if not hook_result_type: + # First-Party + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry - Raises: - NotImplementedError: needs to be implemented by sub class. - """ - raise NotImplementedError( - f"""'resource_post_fetch' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) + registry = get_hook_registry() + hook_result_type = registry.get_result_type(hook) - async def shutdown(self) -> None: - """Plugin cleanup code.""" + if not hook_result_type: + raise PluginError(error=PluginErrorModel(message=f"No result defined for hook {hook}.", plugin_name=self.name)) + + if isinstance(result, str): + return hook_result_type.model_validate_json(result) + return hook_result_type.model_validate(result) class PluginRef: """Plugin reference which contains a uuid. Examples: - >>> from mcpgateway.plugins.framework import PluginConfig, HookType, PluginMode + >>> from mcpgateway.plugins.framework import PluginConfig, PluginMode + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptHookType >>> config = PluginConfig( ... name="ref_test", ... description="Reference test", ... author="test", ... kind="test.Plugin", ... version="1.0.0", - ... hooks=[HookType.PROMPT_PRE_FETCH], + ... hooks=[PromptHookType.PROMPT_PRE_FETCH], ... tags=["ref", "test"], ... mode=PluginMode.PERMISSIVE, ... priority=100 @@ -294,14 +269,15 @@ def __init__(self, plugin: Plugin): plugin: The plugin to reference. Examples: - >>> from mcpgateway.plugins.framework import PluginConfig, HookType + >>> from mcpgateway.plugins.framework import PluginConfig + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptHookType >>> config = PluginConfig( ... name="plugin_ref", ... description="Test", ... author="test", ... kind="test.Plugin", ... version="1.0.0", - ... hooks=[HookType.PROMPT_POST_FETCH], + ... hooks=[PromptHookType.PROMPT_POST_FETCH], ... tags=[] ... ) >>> plugin = Plugin(config) @@ -351,7 +327,7 @@ def name(self) -> str: return self._plugin.name @property - def hooks(self) -> list[HookType]: + def hooks(self) -> list[str]: """Returns the plugin's currently configured hooks. Returns: @@ -385,3 +361,236 @@ def mode(self) -> PluginMode: Plugin's mode. """ return self.plugin.mode + + +class HookRef: + """A Hook reference point with plugin and function.""" + + def __init__(self, hook: str, plugin_ref: PluginRef): + """Initialize a hook reference point. + + Discovers the hook method using either: + 1. Convention-based naming (method name matches hook type) + 2. Decorator-based (@hook decorator with matching hook_type) + + Args: + hook: name of the hook point (e.g., 'tool_pre_invoke'). + plugin_ref: The reference to the plugin to hook. + + Raises: + PluginError: If no method is found for the specified hook. + + Examples: + >>> from mcpgateway.plugins.framework import PluginConfig + >>> config = PluginConfig(name="test", kind="test", version="1.0", author="test", hooks=["tool_pre_invoke"]) + >>> plugin = Plugin(config) + >>> plugin_ref = PluginRef(plugin) + >>> # This would work if plugin has tool_pre_invoke method or @hook("tool_pre_invoke") decorator + """ + # Standard + import inspect + + # First-Party + from mcpgateway.plugins.framework.decorator import get_hook_metadata + + self._plugin_ref = plugin_ref + self._hook = hook + + # Try convention-based lookup first (method name matches hook type) + self._func: Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] | None = getattr(plugin_ref.plugin, hook, None) + + # If not found by convention, scan for @hook decorated methods + if self._func is None: + for name, method in inspect.getmembers(plugin_ref.plugin, predicate=inspect.ismethod): + # Skip private/magic methods + if name.startswith("_"): + continue + + # Check for @hook decorator metadata + metadata = get_hook_metadata(method) + if metadata and metadata.hook_type == hook: + self._func = method + break + + # Raise error if hook method not found by either approach + if not self._func: + raise PluginError( + error=PluginErrorModel( + message=f"Plugin '{plugin_ref.plugin.name}' has no hook: '{hook}'. " f"Method must either be named '{hook}' or decorated with @hook('{hook}')", + plugin_name=plugin_ref.plugin.name, + ) + ) + + # Validate hook method signature (parameter count and async) + self._validate_hook_signature(hook, self._func, plugin_ref.plugin.name) + + def _validate_hook_signature(self, hook: str, func: Callable, plugin_name: str) -> None: + """Validate that the hook method has the correct signature. + + Checks: + 1. Method accepts correct number of parameters (self, payload, context) + 2. Method is async (returns coroutine) + + Args: + hook: The hook type being validated + func: The hook method to validate + plugin_name: Name of the plugin (for error messages) + + Raises: + PluginError: If the signature is invalid + """ + # Standard + import inspect + + sig = inspect.signature(func) + params = list(sig.parameters.values()) + + # Check parameter count (should be: payload, context) + # Note: 'self' is not included in bound method signatures + if len(params) != 2: + raise PluginError( + error=PluginErrorModel( + message=f"Plugin '{plugin_name}' hook '{hook}' has invalid signature. " + f"Expected 2 parameters (payload, context), got {len(params)}: {list(sig.parameters.keys())}. " + f"Correct signature: async def {hook}(self, payload: PayloadType, context: PluginContext) -> ResultType", + plugin_name=plugin_name, + ) + ) + + # Check that method is async + if not inspect.iscoroutinefunction(func): + raise PluginError( + error=PluginErrorModel( + message=f"Plugin '{plugin_name}' hook '{hook}' must be async. " + f"Method '{func.__name__}' is not a coroutine function. " + f"Use 'async def {func.__name__}(...)' instead of 'def {func.__name__}(...)'.", + plugin_name=plugin_name, + ) + ) + + # ========== OPTIONAL: Type Hint Validation ========== + # Uncomment to enable strict type checking of payload and return types. + # This validates that type hints match the expected types from the hook registry. + # Pros: Catches type errors at plugin load time instead of runtime + # Cons: Requires all plugins to have type hints, adds validation overhead + # + # self._validate_type_hints(hook, func, params, plugin_name) + + def _validate_type_hints(self, hook: str, func: Callable, params: list, plugin_name: str) -> None: + """Validate that type hints match expected payload and result types. + + This is an optional validation that can be enabled to enforce type safety. + + Args: + hook: The hook type being validated + func: The hook method to validate + params: List of function parameters + plugin_name: Name of the plugin (for error messages) + + Raises: + PluginError: If type hints are missing or don't match expected types + """ + # Standard + from typing import get_type_hints + + # First-Party + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry + + # Get expected types from registry + registry = get_hook_registry() + expected_payload_type = registry.get_payload_type(hook) + expected_result_type = registry.get_result_type(hook) + + # If hook is not registered in global registry, we can't validate types + if not expected_payload_type or not expected_result_type: + return + + # Get type hints from the function + try: + hints = get_type_hints(func) + except Exception as e: + # Type hints might use forward references or unavailable types + # We'll skip validation rather than fail + # Standard + import logging + + logger = logging.getLogger(__name__) + logger.debug("Could not extract type hints for plugin '%s' hook '%s': %s", plugin_name, hook, e) + return + + # Validate payload parameter type (first parameter, since 'self' is not in params) + payload_param_name = params[0].name + if payload_param_name not in hints: + raise PluginError( + error=PluginErrorModel( + message=f"Plugin '{plugin_name}' hook '{hook}' missing type hint for parameter '{payload_param_name}'. " f"Expected: {payload_param_name}: {expected_payload_type.__name__}", + plugin_name=plugin_name, + ) + ) + + actual_payload_type = hints[payload_param_name] + + # Check if types match (exact match or subclass) + if actual_payload_type != expected_payload_type: + # Check for generic types or complex type hints + actual_type_str = str(actual_payload_type) + expected_type_str = expected_payload_type.__name__ + + # If the expected type name is in the string representation, it's probably OK + if expected_type_str not in actual_type_str: + raise PluginError( + error=PluginErrorModel( + message=f"Plugin '{plugin_name}' hook '{hook}' parameter '{payload_param_name}' " f"has incorrect type hint. Expected: {expected_type_str}, Got: {actual_type_str}", + plugin_name=plugin_name, + ) + ) + + # Validate return type + if "return" not in hints: + raise PluginError( + error=PluginErrorModel( + message=f"Plugin '{plugin_name}' hook '{hook}' missing return type hint. " f"Expected: -> {expected_result_type.__name__}", + plugin_name=plugin_name, + ) + ) + + actual_return_type = hints["return"] + return_type_str = str(actual_return_type) + expected_return_str = expected_result_type.__name__ + + # For async functions, the return type might be wrapped in Coroutine or Awaitable + # We just check if the expected type is mentioned in the return type + if expected_return_str not in return_type_str and actual_return_type != expected_result_type: + raise PluginError( + error=PluginErrorModel( + message=f"Plugin '{plugin_name}' hook '{hook}' has incorrect return type hint. " f"Expected: {expected_return_str}, Got: {return_type_str}", + plugin_name=plugin_name, + ) + ) + + @property + def plugin_ref(self) -> PluginRef: + """The reference to the plugin object. + + Returns: + A plugin reference. + """ + return self._plugin_ref + + @property + def name(self) -> str: + """The name of the hooking function. + + Returns: + A plugin name. + """ + return self._hook + + @property + def hook(self) -> Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] | None: + """The hooking function that can be invoked within the reference. + + Returns: + An awaitable hook function reference. + """ + return self._func diff --git a/mcpgateway/plugins/framework/constants.py b/mcpgateway/plugins/framework/constants.py index 155679c57..7c3d81e90 100644 --- a/mcpgateway/plugins/framework/constants.py +++ b/mcpgateway/plugins/framework/constants.py @@ -16,7 +16,6 @@ PYTHON_SUFFIX = ".py" URL = "url" SCRIPT = "script" -AFTER = "after" NAME = "name" PYTHON = "python" @@ -25,7 +24,6 @@ CONTEXT = "context" RESULT = "result" ERROR = "error" -GET_PLUGIN_CONFIG = "get_plugin_config" IGNORE_CONFIG_EXTERNAL = "ignore_config_external" # Global Context Metadata fields @@ -37,3 +35,6 @@ MCP_SERVER_NAME = "MCP Plugin Server" MCP_SERVER_INSTRUCTIONS = "External plugin server for MCP Gateway" GET_PLUGIN_CONFIGS = "get_plugin_configs" +GET_PLUGIN_CONFIG = "get_plugin_config" +HOOK_TYPE = "hook_type" +INVOKE_HOOK = "invoke_hook" diff --git a/mcpgateway/plugins/framework/decorator.py b/mcpgateway/plugins/framework/decorator.py new file mode 100644 index 000000000..2bd998618 --- /dev/null +++ b/mcpgateway/plugins/framework/decorator.py @@ -0,0 +1,174 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/framework/decorator.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Hook decorator for dynamically registering plugin hooks. + +This module provides decorators for marking plugin methods as hook handlers. +Plugins can use these decorators to: +1. Override the default hook naming convention +2. Register custom hooks not in the standard framework + +Examples: + Override hook method name:: + + class MyPlugin(Plugin): + @hook(ToolHookType.TOOL_PRE_INVOKE) + def custom_name_for_tool_hook(self, payload, context): + # This gets called for tool_pre_invoke even though + # the method name doesn't match + return ToolPreInvokeResult(continue_processing=True) + + Register a completely new hook type:: + + class MyPlugin(Plugin): + @hook("custom_pre_process", CustomPayload, CustomResult) + def my_custom_hook(self, payload, context): + # This registers a new hook type dynamically + return CustomResult(continue_processing=True) + + Use default convention (no decorator needed):: + + class MyPlugin(Plugin): + def tool_pre_invoke(self, payload, context): + # Automatically recognized by naming convention + return ToolPreInvokeResult(continue_processing=True) +""" + +# Standard +from typing import Callable, Optional, Type, TypeVar + +# Third-Party +from pydantic import BaseModel + +# First-Party +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + +# Attribute name for storing hook metadata on functions +_HOOK_METADATA_ATTR = "_plugin_hook_metadata" + +# Type vars for type hints +P = TypeVar("P", bound=PluginPayload) # Payload type +R = TypeVar("R", bound=PluginResult) # Result type + + +class HookMetadata: + """Metadata stored on decorated hook methods. + + Attributes: + hook_type: The hook type identifier (e.g., 'tool_pre_invoke') + payload_type: Optional payload class for hook registration + result_type: Optional result class for hook registration + """ + + def __init__( + self, + hook_type: str, + payload_type: Optional[Type[BaseModel]] = None, + result_type: Optional[Type[BaseModel]] = None, + ): + """Initialize hook metadata. + + Args: + hook_type: The hook type identifier + payload_type: Optional payload class for registering new hooks + result_type: Optional result class for registering new hooks + """ + self.hook_type = hook_type + self.payload_type = payload_type + self.result_type = result_type + + +def hook( + hook_type: str, + payload_type: Optional[Type[P]] = None, + result_type: Optional[Type[R]] = None, +) -> Callable[[Callable], Callable]: + """Decorator to mark a method as a plugin hook handler. + + This decorator attaches metadata to a method so the Plugin class can + discover it during initialization and register it with the appropriate + hook type. + + Args: + hook_type: The hook type identifier (e.g., 'tool_pre_invoke') + payload_type: Optional payload class for registering new hook types + result_type: Optional result class for registering new hook types + + Returns: + Decorator function that marks the method with hook metadata + + Examples: + Override method name:: + + @hook(ToolHookType.TOOL_PRE_INVOKE) + def my_custom_method_name(self, payload, context): + return ToolPreInvokeResult(continue_processing=True) + + Register new hook type:: + + @hook("email_pre_send", EmailPayload, EmailResult) + def handle_email(self, payload, context): + return EmailResult(continue_processing=True) + """ + + def decorator(func: Callable) -> Callable: + """Inner decorator that attaches metadata to the function. + + Args: + func: The function to decorate + + Returns: + The same function with metadata attached + """ + # Store metadata on the function object + metadata = HookMetadata(hook_type, payload_type, result_type) + setattr(func, _HOOK_METADATA_ATTR, metadata) + return func + + return decorator + + +def get_hook_metadata(func: Callable) -> Optional[HookMetadata]: + """Get hook metadata from a decorated function. + + Args: + func: The function to check + + Returns: + HookMetadata if the function is decorated, None otherwise + + Examples: + >>> @hook("test_hook") + ... def test_func(): + ... pass + >>> metadata = get_hook_metadata(test_func) + >>> metadata.hook_type + 'test_hook' + >>> get_hook_metadata(lambda: None) is None + True + """ + return getattr(func, _HOOK_METADATA_ATTR, None) + + +def has_hook_metadata(func: Callable) -> bool: + """Check if a function has hook metadata. + + Args: + func: The function to check + + Returns: + True if the function is decorated with @hook, False otherwise + + Examples: + >>> @hook("test_hook") + ... def decorated(): + ... pass + >>> has_hook_metadata(decorated) + True + >>> has_hook_metadata(lambda: None) + False + """ + return hasattr(func, _HOOK_METADATA_ATTR) diff --git a/mcpgateway/plugins/framework/external/mcp/client.py b/mcpgateway/plugins/framework/external/mcp/client.py index 1d8e60133..465a0b81a 100644 --- a/mcpgateway/plugins/framework/external/mcp/client.py +++ b/mcpgateway/plugins/framework/external/mcp/client.py @@ -11,45 +11,47 @@ # Standard import asyncio from contextlib import AsyncExitStack +from functools import partial import json import logging import os -from typing import Any, Optional, Type, TypeVar +from typing import Any, Awaitable, Callable, Optional # Third-Party import httpx from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client -from pydantic import BaseModel +from mcp.types import TextContent # First-Party -from mcpgateway.plugins.framework.base import Plugin -from mcpgateway.plugins.framework.constants import CONTEXT, ERROR, GET_PLUGIN_CONFIG, IGNORE_CONFIG_EXTERNAL, NAME, PAYLOAD, PLUGIN_NAME, PYTHON, PYTHON_SUFFIX, RESULT +from mcpgateway.common.models import TransportType +from mcpgateway.plugins.framework.base import HookRef, Plugin, PluginRef +from mcpgateway.plugins.framework.constants import ( + CONTEXT, + ERROR, + GET_PLUGIN_CONFIG, + HOOK_TYPE, + IGNORE_CONFIG_EXTERNAL, + INVOKE_HOOK, + NAME, + PAYLOAD, + PLUGIN_NAME, + PYTHON, + PYTHON_SUFFIX, + RESULT, +) from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError from mcpgateway.plugins.framework.external.mcp.tls_utils import create_ssl_context +from mcpgateway.plugins.framework.hooks.registry import get_hook_registry from mcpgateway.plugins.framework.models import ( - HookType, MCPClientTLSConfig, PluginConfig, PluginContext, PluginErrorModel, - PromptPosthookPayload, - PromptPosthookResult, - PromptPrehookPayload, - PromptPrehookResult, - ResourcePostFetchPayload, - ResourcePostFetchResult, - ResourcePreFetchPayload, - ResourcePreFetchResult, - ToolPostInvokePayload, - ToolPostInvokeResult, - ToolPreInvokePayload, - ToolPreInvokeResult, + PluginPayload, + PluginResult, ) -from mcpgateway.schemas import TransportType - -P = TypeVar("P", bound=BaseModel) logger = logging.getLogger(__name__) @@ -81,8 +83,12 @@ async def initialize(self) -> None: if not self._config.mcp: raise PluginError(error=PluginErrorModel(message="The mcp section must be defined for external plugin", plugin_name=self.name)) if self._config.mcp.proto == TransportType.STDIO: + if not self._config.mcp.script: + raise PluginError(error=PluginErrorModel(message="STDIO transport requires script", plugin_name=self.name)) await self.__connect_to_stdio_server(self._config.mcp.script) elif self._config.mcp.proto == TransportType.STREAMABLEHTTP: + if not self._config.mcp.url: + raise PluginError(error=PluginErrorModel(message="STREAMABLEHTTP transport requires url", plugin_name=self.name)) await self.__connect_to_http_server(self._config.mcp.url) try: @@ -146,9 +152,6 @@ async def __connect_to_http_server(self, uri: str) -> None: Raises: PluginError: if there is an external connection error after all retries. """ - max_retries = 3 - base_delay = 1.0 - plugin_tls = self._config.mcp.tls if self._config and self._config.mcp else None tls_config = plugin_tls or MCPClientTLSConfig.from_env() @@ -188,37 +191,37 @@ def _tls_httpx_client_factory( return httpx.AsyncClient(**kwargs) + max_retries = 3 + base_delay = 1.0 + for attempt in range(max_retries): - logger.info(f"Connecting to external plugin server: {uri} (attempt {attempt + 1}/{max_retries})") try: - # Create a fresh exit stack for each attempt + client_factory = _tls_httpx_client_factory if tls_config else None async with AsyncExitStack() as temp_stack: - client_factory = _tls_httpx_client_factory if tls_config else None streamable_client = streamablehttp_client(uri, httpx_client_factory=client_factory) if client_factory else streamablehttp_client(uri) http_transport = await temp_stack.enter_async_context(streamable_client) http_client, write_func, _ = http_transport session = await temp_stack.enter_async_context(ClientSession(http_client, write_func)) - await session.initialize() - # List available tools response = await session.list_tools() tools = response.tools - logger.info("Successfully connected to plugin MCP server with tools: %s", " ".join([tool.name for tool in tools])) + logger.info( + "Successfully connected to plugin MCP server with tools: %s", + " ".join([tool.name for tool in tools]), + ) - # Success! Now move to the main exit stack client_factory = _tls_httpx_client_factory if tls_config else None streamable_client = streamablehttp_client(uri, httpx_client_factory=client_factory) if client_factory else streamablehttp_client(uri) http_transport = await self._exit_stack.enter_async_context(streamable_client) self._http, self._write, _ = http_transport self._session = await self._exit_stack.enter_async_context(ClientSession(self._http, self._write)) + await self._session.initialize() return - except Exception as e: logger.warning(f"Connection attempt {attempt + 1}/{max_retries} failed: {e}") - if attempt == max_retries - 1: # Final attempt failed error_msg = f"External plugin '{self.name}' connection failed after {max_retries} attempts: {uri} is not reachable. Please ensure the MCP server is running." @@ -230,12 +233,11 @@ def _tls_httpx_client_factory( logger.info(f"Retrying in {delay}s...") await asyncio.sleep(delay) - async def __invoke_hook(self, payload_result_model: Type[P], hook_type: HookType, payload: BaseModel, context: PluginContext) -> P: + async def invoke_hook(self, hook_type: str, payload: PluginPayload, context: PluginContext) -> PluginResult: """Invoke an external plugin hook using the MCP protocol. Args: - payload_result_model: The type of result payload for the hook. - hook_type: The type of hook invoked (i.e., prompt_pre_hook) + hook_type: The type of hook invoked (i.e., prompt_pre_fetch) payload: The payload to be passed to the hook. context: The plugin context passed to the run. @@ -245,18 +247,31 @@ async def __invoke_hook(self, payload_result_model: Type[P], hook_type: HookType Returns: The resulting payload from the plugin. """ + # Get the result type from the global registry + registry = get_hook_registry() + result_type = registry.get_result_type(hook_type) + if not result_type: + raise PluginError(error=PluginErrorModel(message=f"Hook type '{hook_type}' not registered in hook registry", plugin_name=self.name)) + + if not self._session: + raise PluginError(error=PluginErrorModel(message="Plugin session not initialized", plugin_name=self.name)) try: - result = await self._session.call_tool(hook_type, {PLUGIN_NAME: self.name, PAYLOAD: payload, CONTEXT: context}) + result = await self._session.call_tool(INVOKE_HOOK, {HOOK_TYPE: hook_type, PLUGIN_NAME: self.name, PAYLOAD: payload, CONTEXT: context}) for content in result.content: - res = json.loads(content.text) + if not isinstance(content, TextContent): + continue + try: + res = json.loads(content.text) + except json.decoder.JSONDecodeError: + raise PluginError(error=PluginErrorModel(message=f"Error trying to decode json: {content.text}", code="JSON_DECODE_ERROR", plugin_name=self.name)) if CONTEXT in res: cxt = PluginContext.model_validate(res[CONTEXT]) context.state = cxt.state context.metadata = cxt.metadata context.global_context.state = cxt.global_context.state if RESULT in res: - return payload_result_model.model_validate(res[RESULT]) + return result_type.model_validate(res[RESULT]) if ERROR in res: error = PluginErrorModel.model_validate(res[ERROR]) raise PluginError(error) @@ -268,83 +283,6 @@ async def __invoke_hook(self, payload_result_model: Type[P], hook_type: HookType raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) raise PluginError(error=PluginErrorModel(message=f"Received invalid response. Result = {result}", plugin_name=self.name)) - async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: - """Plugin hook run before a prompt is retrieved and rendered. - - Args: - payload: The prompt payload to be analyzed. - context: contextual information about the hook call. Including why it was called. - - Returns: - The prompt prehook with name and arguments as modified or blocked by the plugin. - """ - - return await self.__invoke_hook(payload_result_model=PromptPrehookResult, hook_type=HookType.PROMPT_PRE_FETCH, payload=payload, context=context) - - async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: - """Plugin hook run after a prompt is rendered. - - Args: - payload: The prompt payload to be analyzed. - context: Contextual information about the hook call. - - Returns: - A set of prompt messages as modified or blocked by the plugin. - """ - return await self.__invoke_hook(payload_result_model=PromptPosthookResult, hook_type=HookType.PROMPT_POST_FETCH, payload=payload, context=context) - - async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: - """Plugin hook run before a tool is invoked. - - Args: - payload: The tool payload to be analyzed. - context: contextual information about the hook call. Including why it was called. - - Returns: - The tool prehook with name and arguments as modified or blocked by the plugin. - """ - - return await self.__invoke_hook(payload_result_model=ToolPreInvokeResult, hook_type=HookType.TOOL_PRE_INVOKE, payload=payload, context=context) - - async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: - """Plugin hook run after a tool is invoked. - - Args: - payload: The tool payload to be analyzed. - context: contextual information about the hook call. Including why it was called. - - Returns: - The tool posthook with name and arguments as modified or blocked by the plugin. - """ - - return await self.__invoke_hook(payload_result_model=ToolPostInvokeResult, hook_type=HookType.TOOL_POST_INVOKE, payload=payload, context=context) - - async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, context: PluginContext) -> ResourcePreFetchResult: - """Plugin hook run before a resource is fetched. - - Args: - payload: The resource payload to be analyzed. - context: contextual information about the hook call. Including why it was called. - - Returns: - The resource prehook with name and arguments as modified or blocked by the plugin. - """ - - return await self.__invoke_hook(payload_result_model=ResourcePreFetchResult, hook_type=HookType.RESOURCE_PRE_FETCH, payload=payload, context=context) - - async def resource_post_fetch(self, payload: ResourcePostFetchPayload, context: PluginContext) -> ResourcePostFetchResult: - """Plugin hook run after a resource is fetched. - - Args: - payload: The resource payload to be analyzed. - context: contextual information about the hook call. Including why it was called. - - Returns: - The resource posthook with name and arguments as modified or blocked by the plugin. - """ - - return await self.__invoke_hook(payload_result_model=ResourcePostFetchResult, hook_type=HookType.RESOURCE_POST_FETCH, payload=payload, context=context) - async def __get_plugin_config(self) -> PluginConfig | None: """Retrieve plugin configuration for the current plugin on the remote MCP server. @@ -354,9 +292,13 @@ async def __get_plugin_config(self) -> PluginConfig | None: Returns: A plugin configuration for the current plugin from a remote MCP server. """ + if not self._session: + raise PluginError(error=PluginErrorModel(message="Plugin session not initialized", plugin_name=self.name)) try: configs = await self._session.call_tool(GET_PLUGIN_CONFIG, {NAME: self.name}) for content in configs.content: + if not isinstance(content, TextContent): + continue conf = json.loads(content.text) return PluginConfig.model_validate(conf) except Exception as e: @@ -369,3 +311,27 @@ async def shutdown(self) -> None: """Plugin cleanup code.""" if self._exit_stack: await self._exit_stack.aclose() + + +class ExternalHookRef(HookRef): + """A Hook reference point for external plugins.""" + + def __init__(self, hook: str, plugin_ref: PluginRef): # pylint: disable=super-init-not-called + """Initialize a hook reference point for an external plugin. + + Note: We intentionally don't call super().__init__() because external plugins + use invoke_hook() rather than direct method attributes. + + Args: + hook: name of the hook point. + plugin_ref: The reference to the plugin to hook. + + Raises: + PluginError: If the plugin is not an external plugin. + """ + self._plugin_ref = plugin_ref + self._hook = hook + if hasattr(plugin_ref.plugin, INVOKE_HOOK): + self._func: Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] = partial(plugin_ref.plugin.invoke_hook, hook) # type: ignore[attr-defined] + else: + raise PluginError(error=PluginErrorModel(message=f"Plugin: {plugin_ref.plugin.name} is not an external plugin", plugin_name=plugin_ref.plugin.name)) diff --git a/mcpgateway/plugins/framework/external/mcp/server/runtime.py b/mcpgateway/plugins/framework/external/mcp/server/runtime.py old mode 100755 new mode 100644 index 09b3a2ed1..fcf1e6507 --- a/mcpgateway/plugins/framework/external/mcp/server/runtime.py +++ b/mcpgateway/plugins/framework/external/mcp/server/runtime.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # -*- coding: utf-8 -*- """Location: ./mcpgateway/plugins/framework/external/mcp/server/runtime.py Copyright 2025 @@ -29,32 +28,19 @@ # First-Party from mcpgateway.plugins.framework import ( ExternalPluginServer, - Plugin, - PluginContext, - PromptPosthookPayload, - PromptPosthookResult, - PromptPrehookPayload, - PromptPrehookResult, - ResourcePostFetchPayload, - ResourcePostFetchResult, - ResourcePreFetchPayload, - ResourcePreFetchResult, - ToolPostInvokePayload, - ToolPostInvokeResult, - ToolPreInvokePayload, - ToolPreInvokeResult, + MCPServerConfig, ) from mcpgateway.plugins.framework.constants import ( GET_PLUGIN_CONFIG, GET_PLUGIN_CONFIGS, + INVOKE_HOOK, MCP_SERVER_INSTRUCTIONS, MCP_SERVER_NAME, ) -from mcpgateway.plugins.framework.models import HookType, MCPServerConfig logger = logging.getLogger(__name__) -SERVER: ExternalPluginServer = None +SERVER: ExternalPluginServer | None = None # Module-level tool functions (extracted for testability) @@ -65,7 +51,12 @@ async def get_plugin_configs() -> list[dict]: Returns: JSON string containing list of plugin configuration dictionaries. + + Raises: + RuntimeError: If plugin server not initialized. """ + if not SERVER: + raise RuntimeError("Plugin server not initialized") return await SERVER.get_plugin_configs() @@ -77,176 +68,36 @@ async def get_plugin_config(name: str) -> dict: Returns: JSON string containing plugin configuration dictionary. - """ - return await SERVER.get_plugin_config(name) - - -async def prompt_pre_fetch(plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: - """Execute prompt prefetch hook for a plugin. - - Args: - plugin_name: The name of the plugin to execute - payload: The prompt name and arguments to be analyzed - context: Contextual information required for execution - - Returns: - Result dictionary from the prompt prefetch hook. - """ - - def prompt_pre_fetch_func(plugin: Plugin, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: - """Wrapper function to invoke prompt prefetch on a plugin instance. - - Args: - plugin: The plugin instance to execute. - payload: The prompt prehook payload. - context: The plugin context. - - Returns: - Result from the plugin's prompt_pre_fetch method. - """ - return plugin.prompt_pre_fetch(payload, context) - - return await SERVER.invoke_hook(PromptPrehookPayload, prompt_pre_fetch_func, plugin_name, payload, context) - - -async def prompt_post_fetch(plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: - """Execute prompt postfetch hook for a plugin. - - Args: - plugin_name: The name of the plugin to execute - payload: The prompt payload to be analyzed - context: Contextual information - - Returns: - Result dictionary from the prompt postfetch hook. - """ - - def prompt_post_fetch_func(plugin: Plugin, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: - """Wrapper function to invoke prompt postfetch on a plugin instance. - - Args: - plugin: The plugin instance to execute. - payload: The prompt posthook payload. - context: The plugin context. - - Returns: - Result from the plugin's prompt_post_fetch method. - """ - return plugin.prompt_post_fetch(payload, context) - - return await SERVER.invoke_hook(PromptPosthookPayload, prompt_post_fetch_func, plugin_name, payload, context) - - -async def tool_pre_invoke(plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: - """Execute tool pre-invoke hook for a plugin. - - Args: - plugin_name: The name of the plugin to execute - payload: The tool name and arguments to be analyzed - context: Contextual information - - Returns: - Result dictionary from the tool pre-invoke hook. - """ - - def tool_pre_invoke_func(plugin: Plugin, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: - """Wrapper function to invoke tool pre-invoke on a plugin instance. - - Args: - plugin: The plugin instance to execute. - payload: The tool pre-invoke payload. - context: The plugin context. - Returns: - Result from the plugin's tool_pre_invoke method. - """ - return plugin.tool_pre_invoke(payload, context) - - return await SERVER.invoke_hook(ToolPreInvokePayload, tool_pre_invoke_func, plugin_name, payload, context) - - -async def tool_post_invoke(plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: - """Execute tool post-invoke hook for a plugin. - - Args: - plugin_name: The name of the plugin to execute - payload: The tool result to be analyzed - context: Contextual information - - Returns: - Result dictionary from the tool post-invoke hook. - """ - - def tool_post_invoke_func(plugin: Plugin, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: - """Wrapper function to invoke tool post-invoke on a plugin instance. - - Args: - plugin: The plugin instance to execute. - payload: The tool post-invoke payload. - context: The plugin context. - - Returns: - Result from the plugin's tool_post_invoke method. - """ - return plugin.tool_post_invoke(payload, context) - - return await SERVER.invoke_hook(ToolPostInvokePayload, tool_post_invoke_func, plugin_name, payload, context) - - -async def resource_pre_fetch(plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: - """Execute resource prefetch hook for a plugin. - - Args: - plugin_name: The name of the plugin to execute - payload: The resource name and arguments to be analyzed - context: Contextual information - - Returns: - Result dictionary from the resource prefetch hook. + Raises: + RuntimeError: If plugin server not initialized. """ - - def resource_pre_fetch_func(plugin: Plugin, payload: ResourcePreFetchPayload, context: PluginContext) -> ResourcePreFetchResult: - """Wrapper function to invoke resource prefetch on a plugin instance. - - Args: - plugin: The plugin instance to execute. - payload: The resource prefetch payload. - context: The plugin context. - - Returns: - Result from the plugin's resource_pre_fetch method. - """ - return plugin.resource_pre_fetch(payload, context) - - return await SERVER.invoke_hook(ResourcePreFetchPayload, resource_pre_fetch_func, plugin_name, payload, context) + if not SERVER: + raise RuntimeError("Plugin server not initialized") + result = await SERVER.get_plugin_config(name) + if result is None: + return {} + return result -async def resource_post_fetch(plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: - """Execute resource postfetch hook for a plugin. +async def invoke_hook(hook_type: str, plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: + """Execute a hook for a plugin. Args: + hook_type: The name or type of the hook. plugin_name: The name of the plugin to execute payload: The resource payload to be analyzed context: Contextual information Returns: - Result dictionary from the resource postfetch hook. - """ - - def resource_post_fetch_func(plugin: Plugin, payload: ResourcePostFetchPayload, context: PluginContext) -> ResourcePostFetchResult: - """Wrapper function to invoke resource postfetch on a plugin instance. - - Args: - plugin: The plugin instance to execute. - payload: The resource postfetch payload. - context: The plugin context. + Result dictionary with payload, context and any error information. - Returns: - Result from the plugin's resource_post_fetch method. - """ - return plugin.resource_post_fetch(payload, context) - - return await SERVER.invoke_hook(ResourcePostFetchPayload, resource_post_fetch_func, plugin_name, payload, context) + Raises: + RuntimeError: If plugin server not initialized. + """ + if not SERVER: + raise RuntimeError("Plugin server not initialized") + return await SERVER.invoke_hook(hook_type, plugin_name, payload, context) class SSLCapableFastMCP(FastMCP): @@ -288,7 +139,7 @@ def _get_ssl_config(self) -> dict: if tls.ca_bundle: ssl_config["ssl_ca_certs"] = tls.ca_bundle - ssl_config["ssl_cert_reqs"] = tls.ssl_cert_reqs + ssl_config["ssl_cert_reqs"] = str(tls.ssl_cert_reqs) if tls.keyfile_password: ssl_config["ssl_keyfile_password"] = tls.keyfile_password @@ -320,12 +171,9 @@ async def _start_health_check_server(self, health_port: int) -> None: from starlette.responses import JSONResponse # pylint: disable=import-outside-toplevel from starlette.routing import Route # pylint: disable=import-outside-toplevel - async def health_check(request: Request): # pylint: disable=unused-argument + async def health_check(_request: Request): """Health check endpoint for container orchestration. - Args: - request: the http request from which the health check occurs. - Returns: JSON response with health status. """ @@ -354,12 +202,9 @@ async def run_streamable_http_async(self) -> None: from starlette.responses import JSONResponse # pylint: disable=import-outside-toplevel from starlette.routing import Route # pylint: disable=import-outside-toplevel - async def health_check(request: Request): # pylint: disable=unused-argument + async def health_check(_request: Request): """Health check endpoint for container orchestration. - Args: - request: the http request from which the health check occurs. - Returns: JSON response with health status. """ @@ -379,7 +224,7 @@ async def health_check(request: Request): # pylint: disable=unused-argument config_kwargs.update(ssl_config) logger.info(f"Starting plugin server on {self.settings.host}:{self.settings.port}") - config = uvicorn.Config(**config_kwargs) + config = uvicorn.Config(**config_kwargs) # type: ignore[arg-type] server = uvicorn.Server(config) # If SSL is enabled, start a separate HTTP health check server @@ -393,7 +238,7 @@ async def health_check(request: Request): # pylint: disable=unused-argument await server.serve() -async def run(): +async def run() -> None: """Run the external plugin server with FastMCP. Supports both stdio and HTTP transports. Auto-detects transport based on stdin @@ -445,12 +290,7 @@ async def run(): # Register module-level tool functions with FastMCP mcp.tool(name=GET_PLUGIN_CONFIGS)(get_plugin_configs) mcp.tool(name=GET_PLUGIN_CONFIG)(get_plugin_config) - mcp.tool(name=HookType.PROMPT_PRE_FETCH.value)(prompt_pre_fetch) - mcp.tool(name=HookType.PROMPT_POST_FETCH.value)(prompt_post_fetch) - mcp.tool(name=HookType.TOOL_PRE_INVOKE.value)(tool_pre_invoke) - mcp.tool(name=HookType.TOOL_POST_INVOKE.value)(tool_post_invoke) - mcp.tool(name=HookType.RESOURCE_PRE_FETCH.value)(resource_pre_fetch) - mcp.tool(name=HookType.RESOURCE_POST_FETCH.value)(resource_post_fetch) + mcp.tool(name=INVOKE_HOOK)(invoke_hook) # Run with stdio transport logger.info("Starting MCP plugin server with FastMCP (stdio transport)") @@ -467,12 +307,7 @@ async def run(): # Register module-level tool functions with FastMCP mcp.tool(name=GET_PLUGIN_CONFIGS)(get_plugin_configs) mcp.tool(name=GET_PLUGIN_CONFIG)(get_plugin_config) - mcp.tool(name=HookType.PROMPT_PRE_FETCH.value)(prompt_pre_fetch) - mcp.tool(name=HookType.PROMPT_POST_FETCH.value)(prompt_post_fetch) - mcp.tool(name=HookType.TOOL_PRE_INVOKE.value)(tool_pre_invoke) - mcp.tool(name=HookType.TOOL_POST_INVOKE.value)(tool_post_invoke) - mcp.tool(name=HookType.RESOURCE_PRE_FETCH.value)(resource_pre_fetch) - mcp.tool(name=HookType.RESOURCE_POST_FETCH.value)(resource_post_fetch) + mcp.tool(name=INVOKE_HOOK)(invoke_hook) # Run with streamable-http transport logger.info("Starting MCP plugin server with FastMCP (HTTP transport)") diff --git a/mcpgateway/plugins/framework/external/mcp/server/server.py b/mcpgateway/plugins/framework/external/mcp/server/server.py index 78dba8ce9..adf8036fe 100644 --- a/mcpgateway/plugins/framework/external/mcp/server/server.py +++ b/mcpgateway/plugins/framework/external/mcp/server/server.py @@ -2,34 +2,27 @@ """Location: ./mcpgateway/plugins/framework/external/mcp/server/server.py Copyright 2025 SPDX-License-Identifier: Apache-2.0 -Authors: Teryl Taylor - -Plugin MCP Server. - Fred Araujo +Authors: Fred Araujo, Teryl Taylor Module that contains plugin MCP server code to serve external plugins. """ # Standard -import asyncio import logging import os -from typing import Any, Callable, Dict, Type, TypeVar +from typing import Any, Dict, TypeVar # Third-Party from pydantic import BaseModel # First-Party -from mcpgateway.plugins.framework.base import Plugin from mcpgateway.plugins.framework.constants import CONTEXT, ERROR, PLUGIN_NAME, RESULT -from mcpgateway.plugins.framework.errors import convert_exception_to_error +from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError from mcpgateway.plugins.framework.loader.config import ConfigLoader -from mcpgateway.plugins.framework.manager import DEFAULT_PLUGIN_TIMEOUT, PluginManager +from mcpgateway.plugins.framework.manager import PluginManager from mcpgateway.plugins.framework.models import ( MCPServerConfig, PluginContext, - PluginErrorModel, - PluginResult, ) P = TypeVar("P", bound=BaseModel) @@ -70,18 +63,19 @@ async def get_plugin_configs(self) -> list[dict]: True """ plugins: list[dict] = [] - for plug in self._config.plugins: - plugins.append(plug.model_dump()) + if self._config.plugins: + for plug in self._config.plugins: + plugins.append(plug.model_dump()) return plugins - async def get_plugin_config(self, name: str) -> dict: + async def get_plugin_config(self, name: str) -> dict | None: """Return a plugin configuration give a plugin name. Args: name: The name of the plugin of which to return the plugin configuration. Returns: - A list of plugin configurations. + A plugin configuration dict, or None if not found. Examples: >>> import asyncio @@ -92,19 +86,17 @@ async def get_plugin_config(self, name: str) -> dict: >>> c["name"] == "DenyListPlugin" True """ - for plug in self._config.plugins: - if plug.name.lower() == name.lower(): - return plug.model_dump() + if self._config.plugins: + for plug in self._config.plugins: + if plug.name.lower() == name.lower(): + return plug.model_dump() return None - async def invoke_hook( - self, payload_model: Type[P], hook_function: Callable[[Plugin], Callable[[P, PluginContext], PluginResult]], plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any] - ) -> dict: + async def invoke_hook(self, hook_type: str, plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: """Invoke a plugin hook. Args: - payload_model: The type of the payload accepted for the hook. - hook_function: The hook function to be invoked. + hook_type: The type of hook function to be invoked. plugin_name: The name of the plugin to execute. payload: The prompt name and arguments to be analyzed. context: The contextual and state information required for the execution of the hook. @@ -119,37 +111,32 @@ async def invoke_hook( >>> import asyncio >>> import os >>> os.environ["PYTHONPATH"] = "." - >>> from mcpgateway.plugins.framework import GlobalContext, PromptPrehookPayload, PluginContext, PromptPrehookResult + >>> from mcpgateway.plugins.framework import GlobalContext, Plugin, PromptHookType, PromptPrehookPayload, PluginContext, PromptPrehookResult >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") - >>> def prompt_pre_fetch_func(plugin: Plugin, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: - ... return plugin.prompt_pre_fetch(payload, context) - >>> payload = PromptPrehookPayload(prompt_id="test_id", args={"user": "This is so innovative"}) + >>> payload = PromptPrehookPayload(prompt_id="123", name="test_prompt", args={"user": "This is so innovative"}) >>> context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) >>> initialized = asyncio.run(server.initialize()) >>> initialized True - >>> result = asyncio.run(server.invoke_hook(PromptPrehookPayload, prompt_pre_fetch_func, "DenyListPlugin", payload.model_dump(), context.model_dump())) + >>> result = asyncio.run(server.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, "DenyListPlugin", payload.model_dump(), context.model_dump())) >>> result is not None True >>> result["result"]["continue_processing"] False """ global_plugin_manager = PluginManager() - plugin_timeout = global_plugin_manager.config.plugin_settings.plugin_timeout if global_plugin_manager.config else DEFAULT_PLUGIN_TIMEOUT - plugin = global_plugin_manager.get_plugin(plugin_name) result_payload: dict[str, Any] = {PLUGIN_NAME: plugin_name} try: - if plugin: - _payload = payload_model.model_validate(payload) - _context = PluginContext.model_validate(context) - result = await asyncio.wait_for(hook_function(plugin, _payload, _context), plugin_timeout) - result_payload[RESULT] = result.model_dump() - if not _context.is_empty(): - result_payload[CONTEXT] = _context.model_dump() - return result_payload - raise ValueError(f"Unable to retrieve plugin {plugin_name} to execute.") - except asyncio.TimeoutError: - result_payload[ERROR] = PluginErrorModel(message=f"Plugin {plugin_name} timed out from execution after {plugin_timeout} seconds.", plugin_name=plugin_name).model_dump() + _context = PluginContext.model_validate(context) + + result = await global_plugin_manager.invoke_hook_for_plugin(plugin_name, hook_type, payload, _context, payload_as_json=True) + + result_payload[RESULT] = result.model_dump() + if not _context.is_empty(): + result_payload[CONTEXT] = _context.model_dump() + return result_payload + except PluginError as pe: + result_payload[ERROR] = pe.error return result_payload except Exception as ex: logger.exception(ex) diff --git a/mcpgateway/plugins/framework/hooks/__init__.py b/mcpgateway/plugins/framework/hooks/__init__.py new file mode 100644 index 000000000..31153c3b7 --- /dev/null +++ b/mcpgateway/plugins/framework/hooks/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/framework/hooks/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Plugins hooks package. +Exposes predefined hooks for plugins +""" diff --git a/mcpgateway/plugins/framework/hooks/agents.py b/mcpgateway/plugins/framework/hooks/agents.py new file mode 100644 index 000000000..eea547c9a --- /dev/null +++ b/mcpgateway/plugins/framework/hooks/agents.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/models/agents.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Pydantic models for agent plugins. +This module implements the pydantic models associated with +the base plugin layer including configurations, and contexts. +""" + +# Standard +from enum import Enum +from typing import Any, Dict, List, Optional + +# Third-Party +from pydantic import Field + +# First-Party +from mcpgateway.common.models import Message +from mcpgateway.plugins.framework.hooks.http import HttpHeaderPayload +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + + +class AgentHookType(str, Enum): + """Agent hook points. + + Attributes: + AGENT_PRE_INVOKE: Before agent invocation. + AGENT_POST_INVOKE: After agent responds. + + Examples: + >>> AgentHookType.AGENT_PRE_INVOKE + + >>> AgentHookType.AGENT_PRE_INVOKE.value + 'agent_pre_invoke' + >>> AgentHookType('agent_post_invoke') + + >>> list(AgentHookType) + [, ] + """ + + AGENT_PRE_INVOKE = "agent_pre_invoke" + AGENT_POST_INVOKE = "agent_post_invoke" + + +class AgentPreInvokePayload(PluginPayload): + """Agent payload for pre-invoke hook. + + Attributes: + agent_id: The agent identifier (can be modified for routing). + messages: Conversation messages (can be filtered/transformed). + tools: Optional list of tools available to agent. + headers: Optional HTTP headers. + model: Optional model override. + system_prompt: Optional system instructions. + parameters: Optional LLM parameters (temperature, max_tokens, etc.). + + Examples: + >>> payload = AgentPreInvokePayload(agent_id="agent-123", messages=[]) + >>> payload.agent_id + 'agent-123' + >>> payload.messages + [] + >>> payload.tools is None + True + >>> from mcpgateway.common.models import Message, Role, TextContent + >>> msg = Message(role=Role.USER, content=TextContent(type="text", text="Hello")) + >>> payload = AgentPreInvokePayload( + ... agent_id="agent-456", + ... messages=[msg], + ... tools=["search", "calculator"], + ... model="claude-3-5-sonnet-20241022" + ... ) + >>> payload.tools + ['search', 'calculator'] + >>> payload.model + 'claude-3-5-sonnet-20241022' + """ + + agent_id: str + messages: List[Message] + tools: Optional[List[str]] = None + headers: Optional[HttpHeaderPayload] = None + model: Optional[str] = None + system_prompt: Optional[str] = None + parameters: Optional[Dict[str, Any]] = Field(default_factory=dict) + + +class AgentPostInvokePayload(PluginPayload): + """Agent payload for post-invoke hook. + + Attributes: + agent_id: The agent identifier. + messages: Response messages from agent (can be filtered/transformed). + tool_calls: Optional tool invocations made by agent. + + Examples: + >>> payload = AgentPostInvokePayload(agent_id="agent-123", messages=[]) + >>> payload.agent_id + 'agent-123' + >>> payload.messages + [] + >>> payload.tool_calls is None + True + >>> from mcpgateway.common.models import Message, Role, TextContent + >>> msg = Message(role=Role.ASSISTANT, content=TextContent(type="text", text="Response")) + >>> payload = AgentPostInvokePayload( + ... agent_id="agent-456", + ... messages=[msg], + ... tool_calls=[{"name": "search", "arguments": {"query": "test"}}] + ... ) + >>> payload.tool_calls + [{'name': 'search', 'arguments': {'query': 'test'}}] + """ + + agent_id: str + messages: List[Message] + tool_calls: Optional[List[Dict[str, Any]]] = None + + +AgentPreInvokeResult = PluginResult[AgentPreInvokePayload] +AgentPostInvokeResult = PluginResult[AgentPostInvokePayload] + + +def _register_agent_hooks() -> None: + """Register agent hooks in the global registry. + + This is called lazily to avoid circular import issues. + """ + # Import here to avoid circular dependency at module load time + # First-Party + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry # pylint: disable=import-outside-toplevel + + registry = get_hook_registry() + + # Only register if not already registered (idempotent) + if not registry.is_registered(AgentHookType.AGENT_PRE_INVOKE): + registry.register_hook(AgentHookType.AGENT_PRE_INVOKE, AgentPreInvokePayload, AgentPreInvokeResult) + registry.register_hook(AgentHookType.AGENT_POST_INVOKE, AgentPostInvokePayload, AgentPostInvokeResult) + + +_register_agent_hooks() diff --git a/mcpgateway/plugins/framework/hooks/http.py b/mcpgateway/plugins/framework/hooks/http.py new file mode 100644 index 000000000..cd8c4e120 --- /dev/null +++ b/mcpgateway/plugins/framework/hooks/http.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/framework/models/http.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Pydantic models for http hooks and payloads. +""" + +# Third-Party +from pydantic import RootModel + +# First-Party +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + + +class HttpHeaderPayload(RootModel[dict[str, str]], PluginPayload): + """An HTTP dictionary of headers used in the pre/post HTTP forwarding hooks.""" + + def __iter__(self): # type: ignore[no-untyped-def] + """Custom iterator function to override root attribute. + + Returns: + A custom iterator for header dictionary. + """ + return iter(self.root) + + def __getitem__(self, item: str) -> str: + """Custom getitem function to override root attribute. + + Args: + item: The http header key. + + Returns: + A custom accesser for the header dictionary. + """ + return self.root[item] + + def __setitem__(self, key: str, value: str) -> None: + """Custom setitem function to override root attribute. + + Args: + key: The http header key. + value: The http header value to be set. + """ + self.root[key] = value + + def __len__(self) -> int: + """Custom len function to override root attribute. + + Returns: + The len of the header dictionary. + """ + return len(self.root) + + +HttpHeaderPayloadResult = PluginResult[HttpHeaderPayload] diff --git a/mcpgateway/plugins/framework/hooks/prompts.py b/mcpgateway/plugins/framework/hooks/prompts.py new file mode 100644 index 000000000..d57e6bf34 --- /dev/null +++ b/mcpgateway/plugins/framework/hooks/prompts.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/hooks/prompts.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Pydantic models for prompt plugins. +This module implements the pydantic models associated with +the base plugin layer including configurations, and contexts. +""" + +# Standard +from enum import Enum +from typing import Optional + +# Third-Party +from pydantic import Field + +# First-Party +from mcpgateway.common.models import PromptResult +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + + +class PromptHookType(str, Enum): + """MCP Forge Gateway hook points. + + Attributes: + prompt_pre_fetch: The prompt pre hook. + prompt_post_fetch: The prompt post hook. + tool_pre_invoke: The tool pre invoke hook. + tool_post_invoke: The tool post invoke hook. + resource_pre_fetch: The resource pre fetch hook. + resource_post_fetch: The resource post fetch hook. + + Examples: + >>> PromptHookType.PROMPT_PRE_FETCH + + >>> PromptHookType.PROMPT_PRE_FETCH.value + 'prompt_pre_fetch' + >>> PromptHookType('prompt_post_fetch') + + >>> list(PromptHookType) + [, ] + """ + + PROMPT_PRE_FETCH = "prompt_pre_fetch" + PROMPT_POST_FETCH = "prompt_post_fetch" + + +class PromptPrehookPayload(PluginPayload): + """A prompt payload for a prompt prehook. + + Attributes: + prompt_id (str): The ID of the prompt template. + args (dic[str,str]): The prompt template arguments. + + Examples: + >>> payload = PromptPrehookPayload(prompt_id="123", args={"user": "alice"}) + >>> payload.prompt_id + '123' + >>> payload.args + {'user': 'alice'} + >>> payload2 = PromptPrehookPayload(prompt_id="empty") + >>> payload2.args + {} + >>> p = PromptPrehookPayload(prompt_id="123", args={"name": "Bob", "time": "morning"}) + >>> p.prompt_id + '123' + >>> p.args["name"] + 'Bob' + """ + + prompt_id: str + args: Optional[dict[str, str]] = Field(default_factory=dict) + + +class PromptPosthookPayload(PluginPayload): + """A prompt payload for a prompt posthook. + + Attributes: + prompt_id (str): The prompt ID. + result (PromptResult): The prompt after its template is rendered. + + Examples: + >>> from mcpgateway.common.models import PromptResult, Message, TextContent + >>> msg = Message(role="user", content=TextContent(type="text", text="Hello World")) + >>> result = PromptResult(messages=[msg]) + >>> payload = PromptPosthookPayload(prompt_id="123", result=result) + >>> payload.prompt_id + '123' + >>> payload.result.messages[0].content.text + 'Hello World' + >>> from mcpgateway.common.models import PromptResult, Message, TextContent + >>> msg = Message(role="assistant", content=TextContent(type="text", text="Test output")) + >>> r = PromptResult(messages=[msg]) + >>> p = PromptPosthookPayload(prompt_id="123", result=r) + >>> p.prompt_id + '123' + """ + + prompt_id: str + result: PromptResult + + +PromptPrehookResult = PluginResult[PromptPrehookPayload] +PromptPosthookResult = PluginResult[PromptPosthookPayload] + + +def _register_prompt_hooks() -> None: + """Register prompt hooks in the global registry. + + This is called lazily to avoid circular import issues. + """ + # Import here to avoid circular dependency at module load time + # First-Party + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry # pylint: disable=import-outside-toplevel + + registry = get_hook_registry() + + # Only register if not already registered (idempotent) + if not registry.is_registered(PromptHookType.PROMPT_PRE_FETCH): + registry.register_hook(PromptHookType.PROMPT_PRE_FETCH, PromptPrehookPayload, PromptPrehookResult) + registry.register_hook(PromptHookType.PROMPT_POST_FETCH, PromptPosthookPayload, PromptPosthookResult) + + +_register_prompt_hooks() diff --git a/mcpgateway/plugins/framework/hooks/registry.py b/mcpgateway/plugins/framework/hooks/registry.py new file mode 100644 index 000000000..177175471 --- /dev/null +++ b/mcpgateway/plugins/framework/hooks/registry.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/framework/hook_registry.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Hook Registry. +This module provides a global registry for mapping hook types to their +corresponding payload and result Pydantic models. This enables external +plugins to properly serialize/deserialize payloads without needing direct +access to the specific plugin implementations. +""" + +# Standard +from typing import Dict, Optional, Type, Union + +# First-Party +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + + +class HookRegistry: + """Global registry for hook type metadata. + + This singleton registry maintains mappings between hook type names and their + associated Pydantic models for payloads and results. It enables dynamic + serialization/deserialization for external plugins. + + Examples: + >>> from mcpgateway.plugins.framework import PluginPayload, PluginResult + >>> registry = HookRegistry() + >>> registry.register_hook("test_hook", PluginPayload, PluginResult) + >>> registry.get_payload_type("test_hook") + + >>> registry.get_result_type("test_hook") + + """ + + _instance: Optional["HookRegistry"] = None + _hook_payloads: Dict[str, Type[PluginPayload]] = {} + _hook_results: Dict[str, Type[PluginResult]] = {} + + def __new__(cls) -> "HookRegistry": + """Ensure singleton pattern for the registry. + + Returns: + The singleton HookRegistry instance. + """ + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def register_hook( + self, + hook_type: str, + payload_class: Type[PluginPayload], + result_class: Type[PluginResult], + ) -> None: + """Register a hook type with its payload and result classes. + + Args: + hook_type: The hook type identifier (e.g., "prompt_pre_fetch"). + payload_class: The Pydantic model class for the hook's payload. + result_class: The Pydantic model class for the hook's result. + + Examples: + >>> registry = HookRegistry() + >>> from mcpgateway.plugins.framework import PluginPayload, PluginResult + >>> registry.register_hook("custom_hook", PluginPayload, PluginResult) + """ + self._hook_payloads[hook_type] = payload_class + self._hook_results[hook_type] = result_class + + def get_payload_type(self, hook_type: str) -> Optional[Type[PluginPayload]]: + """Get the payload class for a hook type. + + Args: + hook_type: The hook type identifier. + + Returns: + The Pydantic payload class, or None if not registered. + + Examples: + >>> registry = HookRegistry() + >>> registry.get_payload_type("unknown_hook") + """ + return self._hook_payloads.get(hook_type) + + def get_result_type(self, hook_type: str) -> Optional[Type[PluginResult]]: + """Get the result class for a hook type. + + Args: + hook_type: The hook type identifier. + + Returns: + The Pydantic result class, or None if not registered. + + Examples: + >>> registry = HookRegistry() + >>> registry.get_result_type("unknown_hook") + """ + return self._hook_results.get(hook_type) + + def json_to_payload(self, hook_type: str, payload: Union[str, dict]) -> PluginPayload: + """Convert JSON to the appropriate payload Pydantic model. + + Args: + hook_type: The hook type identifier. + payload: The payload as JSON string or dictionary. + + Returns: + The deserialized Pydantic payload object. + + Raises: + ValueError: If the hook type is not registered. + + Examples: + >>> registry = HookRegistry() + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptPrehookPayload, PromptPrehookResult + >>> registry.register_hook("test", PromptPrehookPayload, PromptPrehookResult) + >>> payload = registry.json_to_payload("test", {"prompt_id": "123"}) + """ + payload_class = self.get_payload_type(hook_type) + if not payload_class: + raise ValueError(f"No payload type registered for hook: {hook_type}") + + if isinstance(payload, str): + return payload_class.model_validate_json(payload) + return payload_class.model_validate(payload) + + def json_to_result(self, hook_type: str, result: Union[str, dict]) -> PluginResult: + """Convert JSON to the appropriate result Pydantic model. + + Args: + hook_type: The hook type identifier. + result: The result as JSON string or dictionary. + + Returns: + The deserialized Pydantic result object. + + Raises: + ValueError: If the hook type is not registered. + + Examples: + >>> registry = HookRegistry() + >>> from mcpgateway.plugins.framework import PluginPayload, PluginResult + >>> registry.register_hook("test", PluginPayload, PluginResult) + >>> result = registry.json_to_result("test", '{"continue_processing": true}') + """ + result_class = self.get_result_type(hook_type) + if not result_class: + raise ValueError(f"No result type registered for hook: {hook_type}") + + if isinstance(result, str): + return result_class.model_validate_json(result) + return result_class.model_validate(result) + + def is_registered(self, hook_type: str) -> bool: + """Check if a hook type is registered. + + Args: + hook_type: The hook type identifier. + + Returns: + True if the hook is registered, False otherwise. + + Examples: + >>> registry = HookRegistry() + >>> registry.is_registered("unknown") + False + """ + return hook_type in self._hook_payloads and hook_type in self._hook_results + + def get_registered_hooks(self) -> list[str]: + """Get all registered hook types. + + Returns: + List of registered hook type identifiers. + + Examples: + >>> registry = HookRegistry() + >>> hooks = registry.get_registered_hooks() + >>> isinstance(hooks, list) + True + """ + return list(self._hook_payloads.keys()) + + +# Global singleton instance +_global_registry = HookRegistry() + + +def get_hook_registry() -> HookRegistry: + """Get the global hook registry instance. + + Returns: + The singleton HookRegistry instance. + + Examples: + >>> registry = get_hook_registry() + >>> isinstance(registry, HookRegistry) + True + """ + return _global_registry diff --git a/mcpgateway/plugins/framework/hooks/resources.py b/mcpgateway/plugins/framework/hooks/resources.py new file mode 100644 index 000000000..b31439130 --- /dev/null +++ b/mcpgateway/plugins/framework/hooks/resources.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/framework/hooks/resources.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Pydantic models for resource hooks. +""" + +# Standard +from enum import Enum +from typing import Any, Optional + +# Third-Party +from pydantic import Field + +# First-Party +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + + +class ResourceHookType(str, Enum): + """MCP Forge Gateway resource hook points. + + Attributes: + resource_pre_fetch: The resource pre fetch hook. + resource_post_fetch: The resource post fetch hook. + + Examples: + >>> ResourceHookType.RESOURCE_PRE_FETCH + + >>> ResourceHookType.RESOURCE_PRE_FETCH.value + 'resource_pre_fetch' + >>> ResourceHookType('resource_post_fetch') + + >>> list(ResourceHookType) + [, ] + """ + + RESOURCE_PRE_FETCH = "resource_pre_fetch" + RESOURCE_POST_FETCH = "resource_post_fetch" + + +class ResourcePreFetchPayload(PluginPayload): + """A resource payload for a resource pre-fetch hook. + + Attributes: + uri: The resource URI. + metadata: Optional metadata for the resource request. + + Examples: + >>> payload = ResourcePreFetchPayload(uri="file:///data.txt") + >>> payload.uri + 'file:///data.txt' + >>> payload2 = ResourcePreFetchPayload(uri="http://api/data", metadata={"Accept": "application/json"}) + >>> payload2.metadata + {'Accept': 'application/json'} + >>> p = ResourcePreFetchPayload(uri="file:///docs/readme.md", metadata={"version": "1.0"}) + >>> p.uri + 'file:///docs/readme.md' + >>> p.metadata["version"] + '1.0' + """ + + uri: str + metadata: Optional[dict[str, Any]] = Field(default_factory=dict) + + +class ResourcePostFetchPayload(PluginPayload): + """A resource payload for a resource post-fetch hook. + + Attributes: + uri: The resource URI. + content: The fetched resource content. + + Examples: + >>> from mcpgateway.common.models import ResourceContent + >>> content = ResourceContent(type="resource", id="res-1", uri="file:///data.txt", + ... text="Hello World") + >>> payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) + >>> payload.uri + 'file:///data.txt' + >>> payload.content.text + 'Hello World' + >>> from mcpgateway.common.models import ResourceContent + >>> resource_content = ResourceContent(type="resource", id="res-2", uri="test://resource", text="Test data") + >>> p = ResourcePostFetchPayload(uri="test://resource", content=resource_content) + >>> p.uri + 'test://resource' + """ + + uri: str + content: Any + + +ResourcePreFetchResult = PluginResult[ResourcePreFetchPayload] +ResourcePostFetchResult = PluginResult[ResourcePostFetchPayload] + + +def _register_resource_hooks() -> None: + """Register resource hooks in the global registry. + + This is called lazily to avoid circular import issues. + """ + # Import here to avoid circular dependency at module load time + # First-Party + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry # pylint: disable=import-outside-toplevel + + registry = get_hook_registry() + + # Only register if not already registered (idempotent) + if not registry.is_registered(ResourceHookType.RESOURCE_PRE_FETCH): + registry.register_hook(ResourceHookType.RESOURCE_PRE_FETCH, ResourcePreFetchPayload, ResourcePreFetchResult) + registry.register_hook(ResourceHookType.RESOURCE_POST_FETCH, ResourcePostFetchPayload, ResourcePostFetchResult) + + +_register_resource_hooks() diff --git a/mcpgateway/plugins/framework/hooks/tools.py b/mcpgateway/plugins/framework/hooks/tools.py new file mode 100644 index 000000000..7560d05b0 --- /dev/null +++ b/mcpgateway/plugins/framework/hooks/tools.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/framework/hooks/tools.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Pydantic models for tool hooks. +""" + +# Standard +from enum import Enum +from typing import Any, Optional + +# Third-Party +from pydantic import Field + +# First-Party +from mcpgateway.plugins.framework.hooks.http import HttpHeaderPayload +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + + +class ToolHookType(str, Enum): + """MCP Forge Gateway hook points. + + Attributes: + tool_pre_invoke: The tool pre invoke hook. + tool_post_invoke: The tool post invoke hook. + + Examples: + >>> ToolHookType.TOOL_PRE_INVOKE + + >>> ToolHookType.TOOL_PRE_INVOKE.value + 'tool_pre_invoke' + >>> ToolHookType('tool_post_invoke') + + >>> list(ToolHookType) + [, ] + """ + + TOOL_PRE_INVOKE = "tool_pre_invoke" + TOOL_POST_INVOKE = "tool_post_invoke" + + +class ToolPreInvokePayload(PluginPayload): + """A tool payload for a tool pre-invoke hook. + + Args: + name: The tool name. + args: The tool arguments for invocation. + headers: The http pass through headers. + + Examples: + >>> payload = ToolPreInvokePayload(name="test_tool", args={"input": "data"}) + >>> payload.name + 'test_tool' + >>> payload.args + {'input': 'data'} + >>> payload2 = ToolPreInvokePayload(name="empty") + >>> payload2.args + {} + >>> p = ToolPreInvokePayload(name="calculator", args={"operation": "add", "a": 5, "b": 3}) + >>> p.name + 'calculator' + >>> p.args["operation"] + 'add' + + """ + + name: str + args: Optional[dict[str, Any]] = Field(default_factory=dict) + headers: Optional[HttpHeaderPayload] = None + + +class ToolPostInvokePayload(PluginPayload): + """A tool payload for a tool post-invoke hook. + + Args: + name: The tool name. + result: The tool invocation result. + + Examples: + >>> payload = ToolPostInvokePayload(name="calculator", result={"result": 8, "status": "success"}) + >>> payload.name + 'calculator' + >>> payload.result + {'result': 8, 'status': 'success'} + >>> p = ToolPostInvokePayload(name="analyzer", result={"confidence": 0.95, "sentiment": "positive"}) + >>> p.name + 'analyzer' + >>> p.result["confidence"] + 0.95 + """ + + name: str + result: Any + + +ToolPreInvokeResult = PluginResult[ToolPreInvokePayload] +ToolPostInvokeResult = PluginResult[ToolPostInvokePayload] + + +def _register_tool_hooks() -> None: + """Register Tool hooks in the global registry. + + This is called lazily to avoid circular import issues. + """ + # Import here to avoid circular dependency at module load time + # First-Party + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry # pylint: disable=import-outside-toplevel + + registry = get_hook_registry() + + # Only register if not already registered (idempotent) + if not registry.is_registered(ToolHookType.TOOL_PRE_INVOKE): + registry.register_hook(ToolHookType.TOOL_PRE_INVOKE, ToolPreInvokePayload, ToolPreInvokeResult) + registry.register_hook(ToolHookType.TOOL_POST_INVOKE, ToolPostInvokePayload, ToolPostInvokeResult) + + +_register_tool_hooks() diff --git a/mcpgateway/plugins/framework/loader/plugin.py b/mcpgateway/plugins/framework/loader/plugin.py index c1dbdc170..1fd9bd9c0 100644 --- a/mcpgateway/plugins/framework/loader/plugin.py +++ b/mcpgateway/plugins/framework/loader/plugin.py @@ -72,6 +72,7 @@ def __register_plugin_type(self, kind: str) -> None: kind: The fully-qualified type of the plugin to be registered. """ if kind not in self._plugin_types: + plugin_type: Type[Plugin] if kind == EXTERNAL_PLUGIN_TYPE: plugin_type = ExternalPlugin else: diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index ca41127b9..09d05dcdc 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -20,8 +20,9 @@ >>> # await manager.initialize() # Called in async context >>> # Create test payload and context - >>> from mcpgateway.plugins.framework.models import PromptPrehookPayload, GlobalContext - >>> payload = PromptPrehookPayload(prompt_id="test", name="test", args={"user": "input"}) + >>> from mcpgateway.plugins.framework.models import GlobalContext + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptPrehookPayload + >>> payload = PromptPrehookPayload(prompt_id="123", name="test", args={"user": "input"}) >>> context = GlobalContext(request_id="123") >>> # result, contexts = await manager.prompt_pre_fetch(payload, context) # Called in async context """ @@ -30,61 +31,29 @@ import asyncio from copy import deepcopy import logging -import time -from typing import Any, Callable, Coroutine, Dict, Generic, Optional, Tuple, TypeVar +from typing import Any, Optional, Union # First-Party -from mcpgateway.plugins.framework.base import Plugin, PluginRef +from mcpgateway.plugins.framework.base import HookRef, Plugin from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError, PluginViolationError from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader from mcpgateway.plugins.framework.models import ( Config, GlobalContext, - HookType, - PluginCondition, PluginContext, PluginContextTable, PluginErrorModel, PluginMode, + PluginPayload, PluginResult, - PromptPosthookPayload, - PromptPosthookResult, - PromptPrehookPayload, - PromptPrehookResult, - ResourcePostFetchPayload, - ResourcePostFetchResult, - ResourcePreFetchPayload, - ResourcePreFetchResult, - ToolPostInvokePayload, - ToolPostInvokeResult, - ToolPreInvokePayload, - ToolPreInvokeResult, ) from mcpgateway.plugins.framework.registry import PluginInstanceRegistry -from mcpgateway.plugins.framework.utils import ( - post_prompt_matches, - post_resource_matches, - post_tool_matches, - pre_prompt_matches, - pre_resource_matches, - pre_tool_matches, -) +from mcpgateway.plugins.framework.utils import payload_matches # Use standard logging to avoid circular imports (plugins -> services -> plugins) logger = logging.getLogger(__name__) -T = TypeVar( - "T", - PromptPosthookPayload, - PromptPrehookPayload, - ResourcePostFetchPayload, - ResourcePreFetchPayload, - ToolPostInvokePayload, - ToolPreInvokePayload, -) - - # Configuration constants DEFAULT_PLUGIN_TIMEOUT = 30 # seconds MAX_PAYLOAD_SIZE = 1_000_000 # 1MB @@ -100,7 +69,7 @@ class PayloadSizeError(ValueError): """Raised when a payload exceeds the maximum allowed size.""" -class PluginExecutor(Generic[T]): +class PluginExecutor: """Executes a list of plugins with timeout protection and error handling. This class manages the execution of plugins in priority order, handling: @@ -110,8 +79,7 @@ class PluginExecutor(Generic[T]): - Metadata aggregation from multiple plugins Examples: - >>> from mcpgateway.plugins.framework import PromptPrehookPayload - >>> executor = PluginExecutor[PromptPrehookPayload]() + >>> executor = PluginExecutor() >>> # In async context: >>> # result, contexts = await executor.execute( >>> # plugins=[plugin1, plugin2], @@ -134,22 +102,20 @@ def __init__(self, config: Optional[Config] = None, timeout: int = DEFAULT_PLUGI async def execute( self, - plugins: list[PluginRef], - payload: T, + hook_refs: list[HookRef], + payload: PluginPayload, global_context: GlobalContext, - plugin_run: Callable[[PluginRef, T, PluginContext], Coroutine[Any, Any, PluginResult[T]]], - compare: Callable[[T, list[PluginCondition], GlobalContext], bool], + hook_type: str, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False, - ) -> tuple[PluginResult[T], PluginContextTable | None]: + ) -> tuple[PluginResult, PluginContextTable | None]: """Execute plugins in priority order with timeout protection. Args: - plugins: List of plugins to execute, sorted by priority. + hook_refs: List of hook references to execute, sorted by priority. payload: The payload to be processed by plugins. global_context: Shared context for all plugins containing request metadata. - plugin_run: Async function to execute a specific plugin hook. - compare: Function to check if plugin conditions match the current context. + hook_type: The hook type identifier (e.g., "tool_pre_invoke"). local_contexts: Optional existing contexts from previous hook executions. violations_as_exceptions: Raise violations as exceptions rather than as returns. @@ -165,38 +131,37 @@ async def execute( Examples: >>> # Execute plugins with timeout protection - >>> from mcpgateway.plugins.framework import HookType + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptHookType >>> executor = PluginExecutor(timeout=30) >>> # Assuming you have a registry instance: - >>> # plugins = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) + >>> # plugins = registry.get_plugins_for_hook(PromptHookType.PROMPT_PRE_FETCH) >>> # In async context: >>> # result, contexts = await executor.execute( >>> # plugins=plugins, - >>> # payload=PromptPrehookPayload(prompt_id="123", args={}), + >>> # payload=PromptPrehookPayload(prompt_id="123", name="test", args={}), >>> # global_context=GlobalContext(request_id="123"), >>> # plugin_run=pre_prompt_fetch, >>> # compare=pre_prompt_matches >>> # ) """ - if not plugins: - return (PluginResult[T](modified_payload=None), None) + if not hook_refs: + return (PluginResult(modified_payload=None), None) # Validate payload size self._validate_payload_size(payload) res_local_contexts = {} - combined_metadata = {} - current_payload: T | None = None + combined_metadata: dict[str, Any] = {} + current_payload: PluginPayload | None = None - for pluginref in plugins: + for hook_ref in hook_refs: # Skip disabled plugins - if pluginref.mode == PluginMode.DISABLED: - logger.debug(f"Skipping disabled plugin {pluginref.name}") + if hook_ref.plugin_ref.mode == PluginMode.DISABLED: continue # Check if plugin conditions match current context - if pluginref.conditions and not compare(payload, pluginref.conditions, global_context): - logger.debug(f"Skipping plugin {pluginref.name} - conditions not met") + if hook_ref.plugin_ref.conditions and not payload_matches(payload, hook_type, hook_ref.plugin_ref.conditions, global_context): + logger.debug("Skipping plugin %s - conditions not met", hook_ref.plugin_ref.name) continue tmp_global_context = GlobalContext( @@ -208,7 +173,7 @@ async def execute( metadata={} if not global_context.metadata else deepcopy(global_context.metadata), ) # Get or create local context for this plugin - local_context_key = global_context.request_id + pluginref.uuid + local_context_key = global_context.request_id + hook_ref.plugin_ref.uuid if local_contexts and local_context_key in local_contexts: local_context = local_contexts[local_context_key] local_context.global_context = tmp_global_context @@ -216,69 +181,130 @@ async def execute( local_context = PluginContext(global_context=tmp_global_context) res_local_contexts[local_context_key] = local_context - try: - # Execute plugin with timeout protection - result = await self._execute_with_timeout(pluginref, plugin_run, current_payload or payload, local_context) - if local_context.global_context: - global_context.state.update(local_context.global_context.state) - global_context.metadata.update(local_context.global_context.metadata) - # Aggregate metadata from all plugins - if result.metadata: - combined_metadata.update(result.metadata) - - # Track payload modifications - if result.modified_payload is not None: - current_payload = result.modified_payload - - # Set plugin name in violation if present - if result.violation: - result.violation.plugin_name = pluginref.plugin.name - - # Handle plugin blocking the request - if not result.continue_processing: - if pluginref.plugin.mode == PluginMode.ENFORCE: - logger.warning(f"Plugin {pluginref.plugin.name} blocked request in enforce mode") - if violations_as_exceptions: - if result.violation: - plugin_name = result.violation.plugin_name - violation_reason = result.violation.reason - violation_desc = result.violation.description - violation_code = result.violation.code - raise PluginViolationError( - f"{plugin_run.__name__} blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})", violation=result.violation - ) - raise PluginViolationError(f"{plugin_run.__name__} blocked by plugin") - return (PluginResult[T](continue_processing=False, modified_payload=current_payload, violation=result.violation, metadata=combined_metadata), res_local_contexts) - if pluginref.plugin.mode == PluginMode.PERMISSIVE: - logger.warning(f"Plugin {pluginref.plugin.name} would block (permissive mode): {result.violation.description if result.violation else 'No description'}") - - except asyncio.TimeoutError: - logger.error(f"Plugin {pluginref.name} timed out after {self.timeout}s") - if self.config.plugin_settings.fail_on_plugin_error or pluginref.plugin.mode == PluginMode.ENFORCE: - raise PluginError(error=PluginErrorModel(message=f"Plugin {pluginref.name} exceeded {self.timeout}s timeout", plugin_name=pluginref.name)) - # In permissive or enforce_ignore_error mode, continue with next plugin - continue - except PluginViolationError: - raise - except PluginError as pe: - logger.error(f"Plugin {pluginref.name} failed with error: {str(pe)}", exc_info=True) - if self.config.plugin_settings.fail_on_plugin_error or pluginref.plugin.mode == PluginMode.ENFORCE: - raise - except Exception as e: - logger.error(f"Plugin {pluginref.name} failed with error: {str(e)}", exc_info=True) - if self.config.plugin_settings.fail_on_plugin_error or pluginref.plugin.mode == PluginMode.ENFORCE: - raise PluginError(error=convert_exception_to_error(e, pluginref.name)) - # In permissive or enforce_ignore_error mode, continue with next plugin - continue + # Execute plugin with timeout protection + result = await self.execute_plugin( + hook_ref, + current_payload or payload, + local_context, + violations_as_exceptions, + global_context, + combined_metadata, + ) + # Track payload modifications + if result.modified_payload is not None: + current_payload = result.modified_payload + if not result.continue_processing and hook_ref.plugin_ref.plugin.mode == PluginMode.ENFORCE: + return (result, res_local_contexts) + + return ( + PluginResult(continue_processing=True, modified_payload=current_payload, violation=None, metadata=combined_metadata), + res_local_contexts, + ) + + async def execute_plugin( + self, + hook_ref: HookRef, + payload: PluginPayload, + local_context: PluginContext, + violations_as_exceptions: bool, + global_context: Optional[GlobalContext] = None, + combined_metadata: Optional[dict[str, Any]] = None, + ) -> PluginResult: + """Execute a single plugin with timeout protection. + + Args: + hook_ref: Hooking structure that contains the plugin and hook. + payload: The payload to be processed by plugins. + local_context: local context. + violations_as_exceptions: Raise violations as exceptions rather than as returns. + global_context: Shared context for all plugins containing request metadata. + combined_metadata: combination of the metadata of all plugins. - return (PluginResult[T](continue_processing=True, modified_payload=current_payload, violation=None, metadata=combined_metadata), res_local_contexts) + Returns: + A tuple containing: + - PluginResult with processing status, modified payload, and metadata + - PluginContextTable with updated local contexts for each plugin - async def _execute_with_timeout(self, pluginref: PluginRef, plugin_run: Callable, payload: T, context: PluginContext) -> PluginResult[T]: + Raises: + PayloadSizeError: If the payload exceeds MAX_PAYLOAD_SIZE. + PluginError: If there is an error inside a plugin. + PluginViolationError: If a violation occurs and violation_as_exceptions is set. + """ + try: + # Execute plugin with timeout protection + result = await self._execute_with_timeout(hook_ref, payload, local_context) + if local_context.global_context and global_context: + global_context.state.update(local_context.global_context.state) + global_context.metadata.update(local_context.global_context.metadata) + # Aggregate metadata from all plugins + if result.metadata and combined_metadata is not None: + combined_metadata.update(result.metadata) + + # Track payload modifications + # if result.modified_payload is not None: + # current_payload = result.modified_payload + + # Set plugin name in violation if present + if result.violation: + result.violation.plugin_name = hook_ref.plugin_ref.plugin.name + + # Handle plugin blocking the request + if not result.continue_processing: + if hook_ref.plugin_ref.plugin.mode == PluginMode.ENFORCE: + logger.warning("Plugin %s blocked request in enforce mode", hook_ref.plugin_ref.plugin.name) + if violations_as_exceptions: + if result.violation: + plugin_name = result.violation.plugin_name + violation_reason = result.violation.reason + violation_desc = result.violation.description + violation_code = result.violation.code + raise PluginViolationError( + f"{hook_ref.name} blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})", + violation=result.violation, + ) + raise PluginViolationError(f"{hook_ref.name} blocked by plugin") + return PluginResult( + continue_processing=False, + modified_payload=payload, + violation=result.violation, + metadata=combined_metadata, + ) + if hook_ref.plugin_ref.plugin.mode == PluginMode.PERMISSIVE: + logger.warning( + "Plugin %s would block (permissive mode): %s", + hook_ref.plugin_ref.plugin.name, + result.violation.description if result.violation else "No description", + ) + return result + except asyncio.TimeoutError as exc: + logger.error("Plugin %s timed out after %ds", hook_ref.plugin_ref.name, self.timeout) + if (self.config and self.config.plugin_settings.fail_on_plugin_error) or hook_ref.plugin_ref.plugin.mode == PluginMode.ENFORCE: + raise PluginError( + error=PluginErrorModel( + message=f"Plugin {hook_ref.plugin_ref.name} exceeded {self.timeout}s timeout", + plugin_name=hook_ref.plugin_ref.name, + ) + ) from exc + # In permissive or enforce_ignore_error mode, continue with next plugin + except PluginViolationError: + raise + except PluginError as pe: + logger.error("Plugin %s failed with error: %s", hook_ref.plugin_ref.name, str(pe), exc_info=True) + if (self.config and self.config.plugin_settings.fail_on_plugin_error) or hook_ref.plugin_ref.plugin.mode == PluginMode.ENFORCE: + raise + except Exception as e: + logger.error("Plugin %s failed with error: %s", hook_ref.plugin_ref.name, str(e), exc_info=True) + if (self.config and self.config.plugin_settings.fail_on_plugin_error) or hook_ref.plugin_ref.plugin.mode == PluginMode.ENFORCE: + raise PluginError(error=convert_exception_to_error(e, hook_ref.plugin_ref.name)) from e + # In permissive or enforce_ignore_error mode, continue with next plugin + # Return a result indicating processing should continue despite the error + return PluginResult(continue_processing=True) + + async def _execute_with_timeout(self, hook_ref: HookRef, payload: PluginPayload, context: PluginContext) -> PluginResult: """Execute a plugin with timeout protection. Args: - pluginref: Reference to the plugin to execute. - plugin_run: Function to execute the plugin. + hook_ref: Reference to the hook and plugin to execute. payload: Payload to process. context: Plugin execution context. @@ -305,21 +331,21 @@ async def _execute_with_timeout(self, pluginref: PluginRef, plugin_run: Callable span_id = service.start_span( db=db, trace_id=trace_id, - name=f"plugin.execute.{pluginref.name}", + name=f"plugin.execute.{hook_ref.plugin_ref.name}", kind="internal", resource_type="plugin", - resource_name=pluginref.name, + resource_name=hook_ref.plugin_ref.name, attributes={ - "plugin.name": pluginref.name, - "plugin.uuid": pluginref.uuid, - "plugin.mode": pluginref.mode.value if hasattr(pluginref.mode, "value") else str(pluginref.mode), - "plugin.priority": pluginref.priority, + "plugin.name": hook_ref.plugin_ref.name, + "plugin.uuid": hook_ref.plugin_ref.uuid, + "plugin.mode": hook_ref.plugin_ref.mode.value if hasattr(hook_ref.plugin_ref.mode, "value") else str(hook_ref.plugin_ref.mode), + "plugin.priority": hook_ref.plugin_ref.priority, "plugin.timeout": self.timeout, }, ) # Execute plugin - result = await asyncio.wait_for(plugin_run(pluginref, payload, context), timeout=self.timeout) + result = await asyncio.wait_for(hook_ref.hook(payload, context), timeout=self.timeout) # End span with success service.end_span( @@ -336,12 +362,12 @@ async def _execute_with_timeout(self, pluginref: PluginRef, plugin_run: Callable db.close() else: # No active trace, execute without instrumentation - return await asyncio.wait_for(plugin_run(pluginref, payload, context), timeout=self.timeout) + return await asyncio.wait_for(hook_ref.hook(payload, context), timeout=self.timeout) except Exception as e: # If observability setup fails, continue without instrumentation logger.debug(f"Plugin observability setup failed: {e}") - return await asyncio.wait_for(plugin_run(pluginref, payload, context), timeout=self.timeout) + return await asyncio.wait_for(hook_ref.hook(payload, context), timeout=self.timeout) def _validate_payload_size(self, payload: Any) -> None: """Validate that payload doesn't exceed size limits. @@ -365,154 +391,6 @@ def _validate_payload_size(self, payload: Any) -> None: raise PayloadSizeError(f"Result size {total_size} exceeds limit of {MAX_PAYLOAD_SIZE} bytes") -async def pre_prompt_fetch(plugin: PluginRef, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: - """Call plugin's prompt pre-fetch hook. - - Args: - plugin: The plugin to execute. - payload: The prompt payload to be analyzed. - context: Contextual information about the hook call. - - Returns: - The result of the plugin execution. - - Examples: - >>> from mcpgateway.plugins.framework.base import PluginRef - >>> from mcpgateway.plugins.framework import GlobalContext, Plugin, PromptPrehookPayload, PluginContext, GlobalContext - >>> # Assuming you have a plugin instance: - >>> # plugin_ref = PluginRef(my_plugin) - >>> payload = PromptPrehookPayload(prompt_id="123", args={"key": "value"}) - >>> context = PluginContext(global_context=GlobalContext(request_id="123")) - >>> # In async context: - >>> # result = await pre_prompt_fetch(plugin_ref, payload, context) - """ - return await plugin.plugin.prompt_pre_fetch(payload, context) - - -async def post_prompt_fetch(plugin: PluginRef, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: - """Call plugin's prompt post-fetch hook. - - Args: - plugin: The plugin to execute. - payload: The prompt payload to be analyzed. - context: Contextual information about the hook call. - - Returns: - The result of the plugin execution. - - Examples: - >>> from mcpgateway.plugins.framework.base import PluginRef - >>> from mcpgateway.plugins.framework import GlobalContext, Plugin, PromptPosthookPayload, PluginContext, GlobalContext - >>> from mcpgateway.models import PromptResult - >>> # Assuming you have a plugin instance: - >>> # plugin_ref = PluginRef(my_plugin) - >>> result = PromptResult(messages=[]) - >>> payload = PromptPosthookPayload(prompt_id="123", result=result) - >>> context = PluginContext(global_context=GlobalContext(request_id="123")) - >>> # In async context: - >>> # result = await post_prompt_fetch(plugin_ref, payload, context) - """ - return await plugin.plugin.prompt_post_fetch(payload, context) - - -async def pre_tool_invoke(plugin: PluginRef, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: - """Call plugin's tool pre-invoke hook. - - Args: - plugin: The plugin to execute. - payload: The tool payload to be analyzed. - context: Contextual information about the hook call. - - Returns: - The result of the plugin execution. - - Examples: - >>> from mcpgateway.plugins.framework.base import PluginRef - >>> from mcpgateway.plugins.framework import GlobalContext, Plugin, ToolPreInvokePayload, PluginContext, GlobalContext - >>> # Assuming you have a plugin instance: - >>> # plugin_ref = PluginRef(my_plugin) - >>> payload = ToolPreInvokePayload(name="calculator", args={"operation": "add", "a": 5, "b": 3}) - >>> context = PluginContext(global_context=GlobalContext(request_id="123")) - >>> # In async context: - >>> # result = await pre_tool_invoke(plugin_ref, payload, context) - """ - return await plugin.plugin.tool_pre_invoke(payload, context) - - -async def post_tool_invoke(plugin: PluginRef, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: - """Call plugin's tool post-invoke hook. - - Args: - plugin: The plugin to execute. - payload: The tool result payload to be analyzed. - context: Contextual information about the hook call. - - Returns: - The result of the plugin execution. - - Examples: - >>> from mcpgateway.plugins.framework.base import PluginRef - >>> from mcpgateway.plugins.framework import GlobalContext, Plugin, ToolPostInvokePayload, PluginContext, GlobalContext - >>> # Assuming you have a plugin instance: - >>> # plugin_ref = PluginRef(my_plugin) - >>> payload = ToolPostInvokePayload(name="calculator", result={"result": 8, "status": "success"}) - >>> context = PluginContext(global_context=GlobalContext(request_id="123")) - >>> # In async context: - >>> # result = await post_tool_invoke(plugin_ref, payload, context) - """ - return await plugin.plugin.tool_post_invoke(payload, context) - - -async def pre_resource_fetch(plugin: PluginRef, payload: ResourcePreFetchPayload, context: PluginContext) -> ResourcePreFetchResult: - """Call plugin's resource pre-fetch hook. - - Args: - plugin: The plugin to execute. - payload: The resource payload to be analyzed. - context: The plugin context. - - Returns: - ResourcePreFetchResult with processing status. - - Examples: - >>> from mcpgateway.plugins.framework.base import PluginRef - >>> from mcpgateway.plugins.framework import GlobalContext, Plugin, ResourcePreFetchPayload, PluginContext, GlobalContext - >>> # Assuming you have a plugin instance: - >>> # plugin_ref = PluginRef(my_plugin) - >>> payload = ResourcePreFetchPayload(uri="file:///data.txt", metadata={"cache": True}) - >>> context = PluginContext(global_context=GlobalContext(request_id="123")) - >>> # In async context: - >>> # result = await pre_resource_fetch(plugin_ref, payload, context) - """ - return await plugin.plugin.resource_pre_fetch(payload, context) - - -async def post_resource_fetch(plugin: PluginRef, payload: ResourcePostFetchPayload, context: PluginContext) -> ResourcePostFetchResult: - """Call plugin's resource post-fetch hook. - - Args: - plugin: The plugin to execute. - payload: The resource content payload to be analyzed. - context: The plugin context. - - Returns: - ResourcePostFetchResult with processing status. - - Examples: - >>> from mcpgateway.plugins.framework.base import PluginRef - >>> from mcpgateway.plugins.framework import GlobalContext, Plugin, ResourcePostFetchPayload, PluginContext, GlobalContext - >>> from mcpgateway.models import ResourceContent - >>> # Assuming you have a plugin instance: - >>> # plugin_ref = PluginRef(my_plugin) - >>> content = ResourceContent(type="resource", id="res-1", uri="file:///data.txt", text="Data") - >>> payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) - >>> context = PluginContext(global_context=GlobalContext(request_id="123")) - >>> # In async context: - >>> # result = await post_resource_fetch(plugin_ref, payload, context) - """ - return await plugin.plugin.resource_post_fetch(payload, context) - - class PluginManager: """Plugin manager for managing the plugin lifecycle. @@ -536,8 +414,9 @@ class PluginManager: >>> # print(f"Loaded {manager.plugin_count} plugins") >>> >>> # Execute prompt hooks - >>> from mcpgateway.plugins.framework import PromptPrehookPayload, GlobalContext - >>> payload = PromptPrehookPayload(prompt_id="123", args={}) + >>> from mcpgateway.plugins.framework.models import GlobalContext + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptPrehookPayload + >>> payload = PromptPrehookPayload(prompt_id="123", name="test", args={}) >>> context = GlobalContext(request_id="req-123") >>> # In async context: >>> # result, contexts = await manager.prompt_pre_fetch(payload, context) @@ -551,16 +430,7 @@ class PluginManager: _initialized: bool = False _registry: PluginInstanceRegistry = PluginInstanceRegistry() _config: Config | None = None - _pre_prompt_executor: PluginExecutor[PromptPrehookPayload] = PluginExecutor[PromptPrehookPayload]() - _post_prompt_executor: PluginExecutor[PromptPosthookPayload] = PluginExecutor[PromptPosthookPayload]() - _pre_tool_executor: PluginExecutor[ToolPreInvokePayload] = PluginExecutor[ToolPreInvokePayload]() - _post_tool_executor: PluginExecutor[ToolPostInvokePayload] = PluginExecutor[ToolPostInvokePayload]() - _resource_pre_executor: PluginExecutor[ResourcePreFetchPayload] = PluginExecutor[ResourcePreFetchPayload]() - _resource_post_executor: PluginExecutor[ResourcePostFetchPayload] = PluginExecutor[ResourcePostFetchPayload]() - - # Context cleanup tracking - _context_store: Dict[str, Tuple[PluginContextTable, float]] = {} - _last_cleanup: float = 0 + _executor: PluginExecutor = PluginExecutor() def __init__(self, config: str = "", timeout: int = DEFAULT_PLUGIN_TIMEOUT): """Initialize plugin manager. @@ -581,23 +451,8 @@ def __init__(self, config: str = "", timeout: int = DEFAULT_PLUGIN_TIMEOUT): self._config = ConfigLoader.load_config(config) # Update executor timeouts - self._pre_prompt_executor.timeout = timeout - self._post_prompt_executor.timeout = timeout - self._pre_tool_executor.timeout = timeout - self._post_tool_executor.timeout = timeout - self._resource_pre_executor.timeout = timeout - self._resource_post_executor.timeout = timeout - self._pre_prompt_executor.config = self._config - self._post_prompt_executor.config = self._config - self._pre_tool_executor.config = self._config - self._post_tool_executor.config = self._config - self._resource_pre_executor.config = self._config - self._resource_post_executor.config = self._config - - # Initialize context tracking if not already done - if not hasattr(self, "_context_store"): - self._context_store = {} - self._last_cleanup = time.time() + self._executor.config = self._config + self._executor.timeout = timeout @property def config(self) -> Config | None: @@ -673,20 +528,20 @@ async def initialize(self) -> None: if plugin: self._registry.register(plugin) loaded_count += 1 - logger.info(f"Loaded plugin: {plugin_config.name} (mode: {plugin_config.mode})") + logger.info("Loaded plugin: %s (mode: %s)", plugin_config.name, plugin_config.mode) else: raise ValueError(f"Unable to instantiate plugin: {plugin_config.name}") else: - logger.info(f"Plugin: {plugin_config.name} is disabled. Ignoring.") + logger.info("Plugin: %s is disabled. Ignoring.", plugin_config.name) except Exception as e: # Clean error message without stack trace spam - logger.error(f"Failed to load plugin '{plugin_config.name}': {str(e)}") + logger.error("Failed to load plugin %s: {%s}", plugin_config.name, str(e)) # Let it crash gracefully with a clean error - raise RuntimeError(f"Plugin initialization failed: {plugin_config.name} - {str(e)}") + raise RuntimeError(f"Plugin initialization failed: {plugin_config.name} - {str(e)}") from e self._initialized = True - logger.info(f"Plugin manager initialized with {loaded_count} plugins") + logger.info("Plugin manager initialized with %s plugins", loaded_count) async def shutdown(self) -> None: """Shutdown all plugins and cleanup resources. @@ -710,275 +565,31 @@ async def shutdown(self) -> None: await self._registry.shutdown() # Clear context store - self._context_store.clear() # Reset state self._initialized = False logger.info("Plugin manager shutdown complete") - async def _cleanup_old_contexts(self) -> None: - """Remove contexts older than CONTEXT_MAX_AGE to prevent memory leaks. - - This method is called periodically during hook execution to clean up - stale contexts that are no longer needed. - """ - current_time = time.time() - - # Only cleanup every CONTEXT_CLEANUP_INTERVAL seconds - if current_time - self._last_cleanup < CONTEXT_CLEANUP_INTERVAL: - return - - # Find expired contexts - expired_keys = [key for key, (_, timestamp) in self._context_store.items() if current_time - timestamp > CONTEXT_MAX_AGE] - - # Remove expired contexts - for key in expired_keys: - del self._context_store[key] - - if expired_keys: - logger.info(f"Cleaned up {len(expired_keys)} expired plugin contexts") - - self._last_cleanup = current_time - - async def prompt_pre_fetch( - self, payload: PromptPrehookPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False - ) -> tuple[PromptPrehookResult, PluginContextTable | None]: - """Execute pre-fetch hooks before a prompt is retrieved and rendered. - - Args: - payload: The prompt payload containing name and arguments. - global_context: Shared context for all plugins with request metadata. - local_contexts: Optional existing contexts from previous executions. - violations_as_exceptions: Raise violations as exceptions rather than as returns. - - Returns: - A tuple containing: - - PromptPrehookResult with processing status and modified payload - - PluginContextTable with updated contexts for post-fetch hook - - Raises: - PayloadSizeError: If payload exceeds size limits. - - Examples: - >>> manager = PluginManager("plugins/config.yaml") - >>> # In async context: - >>> # await manager.initialize() - >>> - >>> from mcpgateway.plugins.framework import PromptPrehookPayload, GlobalContext - >>> payload = PromptPrehookPayload( - ... prompt_id="123", - ... name="greeting", - ... args={"user": "Alice"} - ... ) - >>> context = GlobalContext( - ... request_id="req-123", - ... user="alice@example.com" - ... ) - >>> - >>> # In async context: - >>> # result, contexts = await manager.prompt_pre_fetch(payload, context) - >>> # if result.continue_processing: - >>> # # Proceed with prompt processing - >>> # modified_payload = result.modified_payload or payload - """ - # Cleanup old contexts periodically - await self._cleanup_old_contexts() - - # Get plugins configured for this hook - plugins = self._registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) - - # Execute plugins - result = await self._pre_prompt_executor.execute(plugins, payload, global_context, pre_prompt_fetch, pre_prompt_matches, local_contexts, violations_as_exceptions) - - # Store contexts for potential reuse - if result[1]: - self._context_store[global_context.request_id] = (result[1], time.time()) - - return result - - async def prompt_post_fetch( - self, payload: PromptPosthookPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False - ) -> tuple[PromptPosthookResult, PluginContextTable | None]: - """Execute post-fetch hooks after a prompt is rendered. - - Args: - payload: The prompt result payload containing rendered messages. - global_context: Shared context for all plugins with request metadata. - local_contexts: Optional contexts from pre-fetch hook execution. - violations_as_exceptions: Raise violations as exceptions rather than as returns. - - Returns: - A tuple containing: - - PromptPosthookResult with processing status and modified result - - PluginContextTable with final contexts - - Raises: - PayloadSizeError: If payload exceeds size limits. - - Examples: - >>> # Continuing from prompt_pre_fetch example - >>> from mcpgateway.models import PromptResult, Message, TextContent, Role - >>> from mcpgateway.plugins.framework import PromptPosthookPayload, GlobalContext - >>> - >>> # Create a proper Message with TextContent - >>> message = Message( - ... role=Role.USER, - ... content=TextContent(type="text", text="Hello") - ... ) - >>> prompt_result = PromptResult(messages=[message]) - >>> - >>> post_payload = PromptPosthookPayload( - ... prompt_id="123", - ... result=prompt_result - ... ) - >>> - >>> manager = PluginManager("plugins/config.yaml") - >>> context = GlobalContext(request_id="req-123") - >>> - >>> # In async context: - >>> # result, _ = await manager.prompt_post_fetch( - >>> # post_payload, - >>> # context, - >>> # contexts # From pre_fetch - >>> # ) - >>> # if result.modified_payload: - >>> # # Use modified result - >>> # final_result = result.modified_payload.result - """ - # Get plugins configured for this hook - plugins = self._registry.get_plugins_for_hook(HookType.PROMPT_POST_FETCH) - - # Execute plugins - result = await self._post_prompt_executor.execute(plugins, payload, global_context, post_prompt_fetch, post_prompt_matches, local_contexts, violations_as_exceptions) - - # Clean up stored context after post-fetch - if global_context.request_id in self._context_store: - del self._context_store[global_context.request_id] - - return result - - async def tool_pre_invoke( - self, payload: ToolPreInvokePayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False - ) -> tuple[ToolPreInvokeResult, PluginContextTable | None]: - """Execute pre-invoke hooks before a tool is invoked. - - Args: - payload: The tool payload containing name and arguments. - global_context: Shared context for all plugins with request metadata. - local_contexts: Optional existing contexts from previous executions. - violations_as_exceptions: Raise violations as exceptions rather than as returns. - - Returns: - A tuple containing: - - ToolPreInvokeResult with processing status and modified payload - - PluginContextTable with updated contexts for post-invoke hook - - Raises: - PayloadSizeError: If payload exceeds size limits. - - Examples: - >>> manager = PluginManager("plugins/config.yaml") - >>> # In async context: - >>> # await manager.initialize() - >>> - >>> from mcpgateway.plugins.framework import ToolPreInvokePayload, GlobalContext - >>> payload = ToolPreInvokePayload( - ... name="calculator", - ... args={"operation": "add", "a": 5, "b": 3} - ... ) - >>> context = GlobalContext( - ... request_id="req-123", - ... user="alice@example.com" - ... ) - >>> - >>> # In async context: - >>> # result, contexts = await manager.tool_pre_invoke(payload, context) - >>> # if result.continue_processing: - >>> # # Proceed with tool invocation - >>> # modified_payload = result.modified_payload or payload - """ - # Cleanup old contexts periodically - await self._cleanup_old_contexts() - - # Get plugins configured for this hook - plugins = self._registry.get_plugins_for_hook(HookType.TOOL_PRE_INVOKE) - - # Execute plugins - result = await self._pre_tool_executor.execute(plugins, payload, global_context, pre_tool_invoke, pre_tool_matches, local_contexts, violations_as_exceptions) - - # Store contexts for potential reuse - if result[1]: - self._context_store[global_context.request_id] = (result[1], time.time()) - - return result - - async def tool_post_invoke( - self, payload: ToolPostInvokePayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False - ) -> tuple[ToolPostInvokeResult, PluginContextTable | None]: - """Execute post-invoke hooks after a tool is invoked. - - Args: - payload: The tool result payload containing invocation results. - global_context: Shared context for all plugins with request metadata. - local_contexts: Optional contexts from pre-invoke hook execution. - violations_as_exceptions: Raise violations as exceptions rather than as returns. - - Returns: - A tuple containing: - - ToolPostInvokeResult with processing status and modified result - - PluginContextTable with final contexts - - Raises: - PayloadSizeError: If payload exceeds size limits. - - Examples: - >>> # Continuing from tool_pre_invoke example - >>> from mcpgateway.plugins.framework import ToolPostInvokePayload, GlobalContext - >>> - >>> post_payload = ToolPostInvokePayload( - ... name="calculator", - ... result={"result": 8, "status": "success"} - ... ) - >>> - >>> manager = PluginManager("plugins/config.yaml") - >>> context = GlobalContext(request_id="req-123") - >>> - >>> # In async context: - >>> # result, _ = await manager.tool_post_invoke( - >>> # post_payload, - >>> # context, - >>> # contexts # From pre_invoke - >>> # ) - >>> # if result.modified_payload: - >>> # # Use modified result - >>> # final_result = result.modified_payload.result - """ - # Get plugins configured for this hook - plugins = self._registry.get_plugins_for_hook(HookType.TOOL_POST_INVOKE) - - # Execute plugins - result = await self._post_tool_executor.execute(plugins, payload, global_context, post_tool_invoke, post_tool_matches, local_contexts, violations_as_exceptions) - - # Clean up stored context after post-invoke - if global_context.request_id in self._context_store: - del self._context_store[global_context.request_id] - - return result - - async def resource_pre_fetch( - self, payload: ResourcePreFetchPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False - ) -> tuple[ResourcePreFetchResult, PluginContextTable | None]: - """Execute pre-fetch hooks before a resource is fetched. + async def invoke_hook( + self, + hook_type: str, + payload: PluginPayload, + global_context: GlobalContext, + local_contexts: Optional[PluginContextTable] = None, + violations_as_exceptions: bool = False, + ) -> tuple[PluginResult, PluginContextTable | None]: + """Invoke a set of plugins configured for the hook point in priority order. Args: - payload: The resource payload containing URI and metadata. + hook_type: The type of hook to execute. + payload: The plugin payload for which the plugins will analyze and modify. global_context: Shared context for all plugins with request metadata. local_contexts: Optional existing contexts from previous hook executions. violations_as_exceptions: Raise violations as exceptions rather than as returns. Returns: A tuple containing: - - ResourcePreFetchResult with processing status and modified payload + - PluginResult with processing status and modified payload - PluginContextTable with plugin contexts for state management Examples: @@ -993,58 +604,72 @@ async def resource_pre_fetch( >>> # uri = result.modified_payload.uri """ # Get plugins configured for this hook - plugins = self._registry.get_plugins_for_hook(HookType.RESOURCE_PRE_FETCH) + hook_refs = self._registry.get_hook_refs_for_hook(hook_type=hook_type) # Execute plugins - result = await self._resource_pre_executor.execute(plugins, payload, global_context, pre_resource_fetch, pre_resource_matches, local_contexts, violations_as_exceptions) - - # Store context for potential post-fetch - if result[1]: - self._context_store[global_context.request_id] = (result[1], time.time()) - - # Periodic cleanup - await self._cleanup_old_contexts() + result = await self._executor.execute(hook_refs, payload, global_context, hook_type, local_contexts, violations_as_exceptions) return result - async def resource_post_fetch( - self, payload: ResourcePostFetchPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False - ) -> tuple[ResourcePostFetchResult, PluginContextTable | None]: - """Execute post-fetch hooks after a resource is fetched. + async def invoke_hook_for_plugin( + self, + name: str, + hook_type: str, + payload: Union[PluginPayload, dict[str, Any], str], + context: PluginContext, + violations_as_exceptions: bool = False, + payload_as_json: bool = False, + ) -> PluginResult: + """Invoke a specific hook for a single named plugin. + + This method allows direct invocation of a particular plugin's hook by name, + bypassing the normal priority-ordered execution. Useful for testing individual + plugins or when specific plugin behavior needs to be triggered independently. Args: - payload: The resource content payload containing fetched data. - global_context: Shared context for all plugins with request metadata. - local_contexts: Optional contexts from pre-fetch hook execution. - violations_as_exceptions: Raise violations as exceptions rather than as returns. + name: The name of the plugin to invoke. + hook_type: The type of hook to execute (e.g., "prompt_pre_fetch"). + payload: The plugin payload to be processed by the hook. + context: Plugin execution context with local and global state. + violations_as_exceptions: Raise violations as exceptions rather than returns. + payload_as_json: payload passed in as json rather than pydantic. Returns: - A tuple containing: - - ResourcePostFetchResult with processing status and modified content - - PluginContextTable with updated plugin contexts + PluginResult with processing status, modified payload, and metadata. + + Raises: + PluginError: If the plugin or hook type cannot be found in the registry. + ValueError: If payload type does not match payload_as_json setting. Examples: >>> manager = PluginManager("plugins/config.yaml") >>> # In async context: >>> # await manager.initialize() - >>> # from mcpgateway.models import ResourceContent - >>> # content = ResourceContent(type="resource",id="res-1", uri="file:///data.txt", text="Data") - >>> # payload = ResourcePostFetchPayload("file:///data.txt", content) - >>> # context = GlobalContext(request_id="123", server_id="srv1") - >>> # contexts = self._context_store.get("123") # From pre-fetch - >>> # result, _ = await manager.resource_post_fetch(payload, context, contexts) - >>> # if result.continue_processing: - >>> # # Use modified result - >>> # final_content = result.modified_payload.content + >>> # payload = PromptPrehookPayload(name="test", args={}) + >>> # context = PluginContext(global_context=GlobalContext(request_id="123")) + >>> # result = await manager.invoke_hook_for_plugin( + >>> # name="auth_plugin", + >>> # hook_type="prompt_pre_fetch", + >>> # payload=payload, + >>> # context=context + >>> # ) """ - # Get plugins configured for this hook - plugins = self._registry.get_plugins_for_hook(HookType.RESOURCE_POST_FETCH) - - # Execute plugins - result = await self._resource_post_executor.execute(plugins, payload, global_context, post_resource_fetch, post_resource_matches, local_contexts, violations_as_exceptions) - - # Clean up stored context after post-fetch - if global_context.request_id in self._context_store: - del self._context_store[global_context.request_id] - - return result + hook_ref = self._registry.get_plugin_hook_by_name(name, hook_type) + if not hook_ref: + raise PluginError( + error=PluginErrorModel( + message=f"Unable to find {hook_type} for plugin {name}. Make sure the plugin is registered.", + plugin_name=name, + ) + ) + if payload_as_json: + plugin = hook_ref.plugin_ref.plugin + # When payload_as_json=True, payload should be str or dict + if isinstance(payload, (str, dict)): + pydantic_payload = plugin.json_to_payload(hook_type, payload) + return await self._executor.execute_plugin(hook_ref, pydantic_payload, context, violations_as_exceptions) + raise ValueError(f"When payload_as_json=True, payload must be str or dict, got {type(payload)}") + # When payload_as_json=False, payload should already be a PluginPayload + if not isinstance(payload, PluginPayload): + raise ValueError(f"When payload_as_json=False, payload must be a PluginPayload, got {type(payload)}") + return await self._executor.execute_plugin(hook_ref, payload, context, violations_as_exceptions) diff --git a/mcpgateway/plugins/framework/models.py b/mcpgateway/plugins/framework/models.py index a8a3e0c32..d6644abc3 100644 --- a/mcpgateway/plugins/framework/models.py +++ b/mcpgateway/plugins/framework/models.py @@ -13,50 +13,33 @@ from enum import Enum import os from pathlib import Path -from typing import Any, Generic, Optional, Self, TypeVar +from typing import Any, Generic, Optional, Self, TypeAlias, TypeVar # Third-Party -from pydantic import BaseModel, Field, field_serializer, field_validator, model_validator, PrivateAttr, RootModel, ValidationInfo +from pydantic import ( + BaseModel, + Field, + field_serializer, + field_validator, + model_validator, + PrivateAttr, + ValidationInfo, +) # First-Party -from mcpgateway.models import PromptResult -from mcpgateway.plugins.framework.constants import AFTER, EXTERNAL_PLUGIN_TYPE, IGNORE_CONFIG_EXTERNAL, PYTHON_SUFFIX, SCRIPT, URL -from mcpgateway.schemas import TransportType -from mcpgateway.validators import SecurityValidator +from mcpgateway.common.models import TransportType +from mcpgateway.common.validators import SecurityValidator +from mcpgateway.plugins.framework.constants import ( + EXTERNAL_PLUGIN_TYPE, + IGNORE_CONFIG_EXTERNAL, + PYTHON_SUFFIX, + SCRIPT, + URL, +) T = TypeVar("T") -class HookType(str, Enum): - """MCP Forge Gateway hook points. - - Attributes: - prompt_pre_fetch: The prompt pre hook. - prompt_post_fetch: The prompt post hook. - tool_pre_invoke: The tool pre invoke hook. - tool_post_invoke: The tool post invoke hook. - resource_pre_fetch: The resource pre fetch hook. - resource_post_fetch: The resource post fetch hook. - - Examples: - >>> HookType.PROMPT_PRE_FETCH - - >>> HookType.PROMPT_PRE_FETCH.value - 'prompt_pre_fetch' - >>> HookType('prompt_post_fetch') - - >>> list(HookType) # doctest: +ELLIPSIS - [, , , , ...] - """ - - PROMPT_PRE_FETCH = "prompt_pre_fetch" - PROMPT_POST_FETCH = "prompt_post_fetch" - TOOL_PRE_INVOKE = "tool_pre_invoke" - TOOL_POST_INVOKE = "tool_post_invoke" - RESOURCE_PRE_FETCH = "resource_pre_fetch" - RESOURCE_POST_FETCH = "resource_post_fetch" - - class PluginMode(str, Enum): """Plugin modes of operation. @@ -190,6 +173,7 @@ class PluginCondition(BaseModel): tools (Optional[set[str]]): set of tool names. prompts (Optional[set[str]]): set of prompt names. resources (Optional[set[str]]): set of resource URIs. + agents (Optional[set[str]]): set of agent IDs. user_pattern (Optional[list[str]]): list of user patterns. content_types (Optional[list[str]]): list of content types. @@ -210,10 +194,11 @@ class PluginCondition(BaseModel): tools: Optional[set[str]] = None prompts: Optional[set[str]] = None resources: Optional[set[str]] = None + agents: Optional[set[str]] = None user_patterns: Optional[list[str]] = None content_types: Optional[list[str]] = None - @field_serializer("server_ids", "tenant_ids", "tools", "prompts") + @field_serializer("server_ids", "tenant_ids", "tools", "prompts", "resources", "agents") def serialize_set(self, value: set[str] | None) -> list[str] | None: """Serialize set objects in PluginCondition for MCP. @@ -262,7 +247,7 @@ class MCPTransportTLSConfigBase(BaseModel): ca_bundle: Optional[str] = Field(default=None, description="Path to CA bundle for verification") keyfile_password: Optional[str] = Field(default=None, description="Password for encrypted private key") - @field_validator("ca_bundle", "certfile", "keyfile", mode=AFTER) + @field_validator("ca_bundle", "certfile", "keyfile", mode="after") @classmethod def validate_path(cls, value: Optional[str]) -> Optional[str]: """Expand and validate file paths supplied in TLS configuration. @@ -284,7 +269,7 @@ def validate_path(cls, value: Optional[str]) -> Optional[str]: raise ValueError(f"TLS file path does not exist: {value}") return str(expanded) - @model_validator(mode=AFTER) + @model_validator(mode="after") def validate_cert_key(self) -> Self: # pylint: disable=bad-classmethod-argument """Ensure certificate and key options are consistent. @@ -421,7 +406,7 @@ class MCPServerConfig(BaseModel): tls (Optional[MCPServerTLSConfig]): Server-side TLS configuration. """ - host: str = Field(default="0.0.0.0", description="Server host to bind to") # nosec B104 + host: str = Field(default="127.0.0.1", description="Server host to bind to") port: int = Field(default=8000, description="Server port to bind to") tls: Optional[MCPServerTLSConfig] = Field(default=None, description="Server-side TLS configuration") @@ -499,7 +484,7 @@ class MCPClientConfig(BaseModel): script: Optional[str] = None tls: Optional[MCPClientTLSConfig] = None - @field_validator(URL, mode=AFTER) + @field_validator(URL, mode="after") @classmethod def validate_url(cls, url: str | None) -> str | None: """Validate a MCP url for streamable HTTP connections. @@ -518,7 +503,7 @@ def validate_url(cls, url: str | None) -> str | None: return result return url - @field_validator(SCRIPT, mode=AFTER) + @field_validator(SCRIPT, mode="after") @classmethod def validate_script(cls, script: str | None) -> str | None: """Validate an MCP stdio script. @@ -542,7 +527,7 @@ def validate_script(cls, script: str | None) -> str | None: raise ValueError(f"MCP server script {script} must have a .py or .sh suffix.") return script - @model_validator(mode=AFTER) + @model_validator(mode="after") def validate_tls_usage(self) -> Self: # pylint: disable=bad-classmethod-argument """Ensure TLS configuration is only used with HTTP-based transports. @@ -568,10 +553,10 @@ class PluginConfig(BaseModel): kind (str): The kind or type of plugin. Usually a fully qualified object type. namespace (str): The namespace where the plugin resides. version (str): version of the plugin. - hooks (list[str]): a list of the hook points where the plugin will be called. + hooks (list[str]): a list of the hook points where the plugin will be called. Default: []. tags (list[str]): a list of tags for making the plugin searchable. mode (bool): whether the plugin is active. - priority (int): indicates the order in which the plugin is run. Lower = higher priority. + priority (int): indicates the order in which the plugin is run. Lower = higher priority. Default: 100. conditions (Optional[list[PluginCondition]]): the conditions on which the plugin is run. applied_to (Optional[list[AppliedTo]]): the tools, fields, that the plugin is applied to. config (dict[str, Any]): the plugin specific configurations. @@ -584,16 +569,16 @@ class PluginConfig(BaseModel): kind: str namespace: Optional[str] = None version: Optional[str] = None - hooks: Optional[list[HookType]] = None - tags: Optional[list[str]] = None + hooks: list[str] = Field(default_factory=list) + tags: list[str] = Field(default_factory=list) mode: PluginMode = PluginMode.ENFORCE - priority: Optional[int] = None # Lower = higher priority - conditions: Optional[list[PluginCondition]] = None # When to apply + priority: int = 100 # Lower = higher priority + conditions: list[PluginCondition] = Field(default_factory=list) # When to apply applied_to: Optional[AppliedTo] = None # Fields to apply to. config: Optional[dict[str, Any]] = None mcp: Optional[MCPClientConfig] = None - @model_validator(mode=AFTER) + @model_validator(mode="after") def check_url_or_script_filled(self) -> Self: # pylint: disable=bad-classmethod-argument """Checks to see that at least one of url or script are set depending on MCP server configuration. @@ -613,7 +598,7 @@ def check_url_or_script_filled(self) -> Self: # pylint: disable=bad-classmethod raise ValueError(f"Plugin {self.name} must set transport type to either SSE or STREAMABLEHTTP or STDIO") return self - @model_validator(mode=AFTER) + @model_validator(mode="after") def check_config_and_external(self, info: ValidationInfo) -> Self: # pylint: disable=bad-classmethod-argument """Checks to see that a plugin's 'config' section is not defined if the kind is 'external'. This is because developers cannot override items in the plugin config section for external plugins. @@ -671,9 +656,9 @@ class PluginErrorModel(BaseModel): """ message: str + plugin_name: str code: Optional[str] = "" details: Optional[dict[str, Any]] = Field(default_factory=dict) - plugin_name: str mcp_error_code: int = -32603 @@ -707,7 +692,7 @@ class PluginViolation(BaseModel): reason: str description: str code: str - details: dict[str, Any] + details: Optional[dict[str, Any]] = Field(default_factory=dict) _plugin_name: str = PrivateAttr(default="") mcp_error_code: Optional[int] = None @@ -769,61 +754,6 @@ class Config(BaseModel): server_settings: Optional[MCPServerConfig] = None -class PromptPrehookPayload(BaseModel): - """A prompt payload for a prompt prehook. - - Attributes: - prompt_id (str): The ID of the prompt template. - args (dic[str,str]): The prompt template arguments. - - Examples: - >>> payload = PromptPrehookPayload(prompt_id="123", args={"user": "alice"}) - >>> payload.prompt_id - '123' - >>> payload.args - {'user': 'alice'} - >>> payload2 = PromptPrehookPayload(prompt_id="empty") - >>> payload2.args - {} - >>> p = PromptPrehookPayload(prompt_id="123", args={"name": "Bob", "time": "morning"}) - >>> p.prompt_id - '123' - >>> p.args["name"] - 'Bob' - """ - - prompt_id: str - args: Optional[dict[str, str]] = Field(default_factory=dict) - - -class PromptPosthookPayload(BaseModel): - """A prompt payload for a prompt posthook. - - Attributes: - prompt_id (str): The prompt ID. - result (PromptResult): The prompt after its template is rendered. - - Examples: - >>> from mcpgateway.models import PromptResult, Message, TextContent - >>> msg = Message(role="user", content=TextContent(type="text", text="Hello World")) - >>> result = PromptResult(messages=[msg]) - >>> payload = PromptPosthookPayload(prompt_id="123", result=result) - >>> payload.prompt_id - '123' - >>> payload.result.messages[0].content.text - 'Hello World' - >>> from mcpgateway.models import PromptResult, Message, TextContent - >>> msg = Message(role="assistant", content=TextContent(type="text", text="Test output")) - >>> r = PromptResult(messages=[msg]) - >>> p = PromptPosthookPayload(prompt_id="123", result=r) - >>> p.prompt_id - '123' - """ - - prompt_id: str - result: PromptResult - - class PluginResult(BaseModel, Generic[T]): """A result of the plugin hook processing. The actual type is dependent on the hook. @@ -862,111 +792,6 @@ class PluginResult(BaseModel, Generic[T]): metadata: Optional[dict[str, Any]] = Field(default_factory=dict) -PromptPrehookResult = PluginResult[PromptPrehookPayload] -PromptPosthookResult = PluginResult[PromptPosthookPayload] - - -class HttpHeaderPayload(RootModel[dict[str, str]]): - """An HTTP dictionary of headers used in the pre/post HTTP forwarding hooks.""" - - def __iter__(self): - """Custom iterator function to override root attribute. - - Returns: - A custom iterator for header dictionary. - """ - return iter(self.root) - - def __getitem__(self, item: str) -> str: - """Custom getitem function to override root attribute. - - Args: - item: The http header key. - - Returns: - A custom accesser for the header dictionary. - """ - return self.root[item] - - def __setitem__(self, key: str, value: str) -> None: - """Custom setitem function to override root attribute. - - Args: - key: The http header key. - value: The http header value to be set. - """ - self.root[key] = value - - def __len__(self): - """Custom len function to override root attribute. - - Returns: - The len of the header dictionary. - """ - return len(self.root) - - -HttpHeaderPayloadResult = PluginResult[HttpHeaderPayload] - - -class ToolPreInvokePayload(BaseModel): - """A tool payload for a tool pre-invoke hook. - - Args: - name: The tool name. - args: The tool arguments for invocation. - headers: The http pass through headers. - - Examples: - >>> payload = ToolPreInvokePayload(name="test_tool", args={"input": "data"}) - >>> payload.name - 'test_tool' - >>> payload.args - {'input': 'data'} - >>> payload2 = ToolPreInvokePayload(name="empty") - >>> payload2.args - {} - >>> p = ToolPreInvokePayload(name="calculator", args={"operation": "add", "a": 5, "b": 3}) - >>> p.name - 'calculator' - >>> p.args["operation"] - 'add' - - """ - - name: str - args: Optional[dict[str, Any]] = Field(default_factory=dict) - headers: Optional[HttpHeaderPayload] = None - - -class ToolPostInvokePayload(BaseModel): - """A tool payload for a tool post-invoke hook. - - Args: - name: The tool name. - result: The tool invocation result. - - Examples: - >>> payload = ToolPostInvokePayload(name="calculator", result={"result": 8, "status": "success"}) - >>> payload.name - 'calculator' - >>> payload.result - {'result': 8, 'status': 'success'} - >>> p = ToolPostInvokePayload(name="analyzer", result={"confidence": 0.95, "sentiment": "positive"}) - >>> p.name - 'analyzer' - >>> p.result["confidence"] - 0.95 - """ - - name: str - result: Any - - -ToolPreInvokeResult = PluginResult[ToolPreInvokePayload] -ToolPostInvokeResult = PluginResult[ToolPostInvokePayload] - - class GlobalContext(BaseModel): """The global context, which shared across all plugins. @@ -1065,58 +890,4 @@ def is_empty(self) -> bool: PluginContextTable = dict[str, PluginContext] - -class ResourcePreFetchPayload(BaseModel): - """A resource payload for a resource pre-fetch hook. - - Attributes: - uri: The resource URI. - metadata: Optional metadata for the resource request. - - Examples: - >>> payload = ResourcePreFetchPayload(uri="file:///data.txt") - >>> payload.uri - 'file:///data.txt' - >>> payload2 = ResourcePreFetchPayload(uri="http://api/data", metadata={"Accept": "application/json"}) - >>> payload2.metadata - {'Accept': 'application/json'} - >>> p = ResourcePreFetchPayload(uri="file:///docs/readme.md", metadata={"version": "1.0"}) - >>> p.uri - 'file:///docs/readme.md' - >>> p.metadata["version"] - '1.0' - """ - - uri: str - metadata: Optional[dict[str, Any]] = Field(default_factory=dict) - - -class ResourcePostFetchPayload(BaseModel): - """A resource payload for a resource post-fetch hook. - - Attributes: - uri: The resource URI. - content: The fetched resource content. - - Examples: - >>> from mcpgateway.models import ResourceContent - >>> content = ResourceContent(type="resource", id="res-1", uri="file:///data.txt", - ... text="Hello World") - >>> payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) - >>> payload.uri - 'file:///data.txt' - >>> payload.content.text - 'Hello World' - >>> from mcpgateway.models import ResourceContent - >>> resource_content = ResourceContent(type="resource", id="res-2", uri="test://resource", text="Test data") - >>> p = ResourcePostFetchPayload(uri="test://resource", content=resource_content) - >>> p.uri - 'test://resource' - """ - - uri: str - content: Any - - -ResourcePreFetchResult = PluginResult[ResourcePreFetchPayload] -ResourcePostFetchResult = PluginResult[ResourcePostFetchPayload] +PluginPayload: TypeAlias = BaseModel diff --git a/mcpgateway/plugins/framework/registry.py b/mcpgateway/plugins/framework/registry.py index 519c26ada..28c5259dc 100644 --- a/mcpgateway/plugins/framework/registry.py +++ b/mcpgateway/plugins/framework/registry.py @@ -14,8 +14,8 @@ from typing import Optional # First-Party -from mcpgateway.plugins.framework.base import Plugin, PluginRef -from mcpgateway.plugins.framework.models import HookType +from mcpgateway.plugins.framework.base import HookRef, Plugin, PluginRef +from mcpgateway.plugins.framework.external.mcp.client import ExternalHookRef, ExternalPlugin # Use standard logging to avoid circular imports (plugins -> services -> plugins) logger = logging.getLogger(__name__) @@ -25,7 +25,8 @@ class PluginInstanceRegistry: """Registry for managing loaded plugins. Examples: - >>> from mcpgateway.plugins.framework import Plugin, PluginConfig, HookType + >>> from mcpgateway.plugins.framework import Plugin, PluginConfig + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptHookType >>> registry = PluginInstanceRegistry() >>> config = PluginConfig( ... name="test", @@ -33,14 +34,16 @@ class PluginInstanceRegistry: ... author="test", ... kind="test.Plugin", ... version="1.0", - ... hooks=[HookType.PROMPT_PRE_FETCH], + ... hooks=[PromptHookType.PROMPT_PRE_FETCH], ... tags=[] ... ) + >>> async def prompt_pre_fetch(payload, context): ... >>> plugin = Plugin(config) + >>> plugin.prompt_pre_fetch = prompt_pre_fetch >>> registry.register(plugin) >>> registry.get_plugin("test").name 'test' - >>> len(registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH)) + >>> len(registry.get_hook_refs_for_hook(PromptHookType.PROMPT_PRE_FETCH)) 1 >>> registry.unregister("test") >>> registry.get_plugin("test") is None @@ -60,8 +63,9 @@ def __init__(self) -> None: 0 """ self._plugins: dict[str, PluginRef] = {} - self._hooks: dict[HookType, list[PluginRef]] = defaultdict(list) - self._priority_cache: dict[HookType, list[PluginRef]] = {} + self._hooks: dict[str, list[HookRef]] = defaultdict(list) + self._hooks_by_name: dict[str, dict[str, HookRef]] = {} + self._priority_cache: dict[str, list[HookRef]] = {} def register(self, plugin: Plugin) -> None: """Register a plugin instance. @@ -79,13 +83,24 @@ def register(self, plugin: Plugin) -> None: self._plugins[plugin.name] = plugin_ref + plugin_hooks = {} + + external = isinstance(plugin, ExternalPlugin) + # Register hooks for hook_type in plugin.hooks: - self._hooks[hook_type].append(plugin_ref) + hook_ref: HookRef + if external: + hook_ref = ExternalHookRef(hook_type, plugin_ref) + else: + hook_ref = HookRef(hook_type, plugin_ref) + self._hooks[hook_type].append(hook_ref) + plugin_hooks[hook_type] = hook_ref # Invalidate priority cache for this hook self._priority_cache.pop(hook_type, None) + self._hooks_by_name[plugin.name] = plugin_hooks - logger.info(f"Registered plugin: {plugin.name} with hooks: {[h.name for h in plugin.hooks]}") + logger.info(f"Registered plugin: {plugin.name} with hooks: {list(plugin.hooks)}") def unregister(self, plugin_name: str) -> None: """Unregister a plugin given its name. @@ -102,9 +117,12 @@ def unregister(self, plugin_name: str) -> None: plugin = self._plugins.pop(plugin_name) # Remove from hooks for hook_type in plugin.hooks: - self._hooks[hook_type] = [p for p in self._hooks[hook_type] if p.name != plugin_name] + self._hooks[hook_type] = [p for p in self._hooks[hook_type] if p.plugin_ref.name != plugin_name] self._priority_cache.pop(hook_type, None) + # Remove from hooks by name + self._hooks_by_name.pop(plugin_name, None) + logger.info(f"Unregistered plugin: {plugin_name}") def get_plugin(self, name: str) -> Optional[PluginRef]: @@ -118,7 +136,23 @@ def get_plugin(self, name: str) -> Optional[PluginRef]: """ return self._plugins.get(name) - def get_plugins_for_hook(self, hook_type: HookType) -> list[PluginRef]: + def get_plugin_hook_by_name(self, name: str, hook_type: str) -> Optional[HookRef]: + """Gets a hook reference for a particular plugin and hook type. + + Args: + name: plugin name. + hook_type: the hook type. + + Returns: + A hook reference for the plugin or None if not found. + """ + if name in self._hooks_by_name: + hooks = self._hooks_by_name[name] + if hook_type in hooks: + return hooks[hook_type] + return None + + def get_hook_refs_for_hook(self, hook_type: str) -> list[HookRef]: """Get all plugins for a specific hook, sorted by priority. Args: @@ -128,8 +162,8 @@ def get_plugins_for_hook(self, hook_type: HookType) -> list[PluginRef]: A list of plugin instances. """ if hook_type not in self._priority_cache: - plugins = sorted(self._hooks[hook_type], key=lambda p: p.priority) - self._priority_cache[hook_type] = plugins + hook_refs = sorted(self._hooks[hook_type], key=lambda p: p.plugin_ref.priority) + self._priority_cache[hook_type] = hook_refs return self._priority_cache[hook_type] def get_all_plugins(self) -> list[PluginRef]: diff --git a/mcpgateway/plugins/framework/utils.py b/mcpgateway/plugins/framework/utils.py index 17f561fb1..0d40e01ac 100644 --- a/mcpgateway/plugins/framework/utils.py +++ b/mcpgateway/plugins/framework/utils.py @@ -13,19 +13,23 @@ from functools import cache import importlib from types import ModuleType +from typing import Any, Optional # First-Party from mcpgateway.plugins.framework.models import ( GlobalContext, PluginCondition, - PromptPosthookPayload, - PromptPrehookPayload, - ResourcePostFetchPayload, - ResourcePreFetchPayload, - ToolPostInvokePayload, - ToolPreInvokePayload, ) +# from mcpgateway.plugins.mcp.entities import ( +# PromptPosthookPayload, +# PromptPrehookPayload, +# ResourcePostFetchPayload, +# ResourcePreFetchPayload, +# ToolPostInvokePayload, +# ToolPreInvokePayload, +# ) + @cache # noqa def import_module(mod_name: str) -> ModuleType: @@ -111,208 +115,326 @@ def matches(condition: PluginCondition, context: GlobalContext) -> bool: return True -def pre_prompt_matches(payload: PromptPrehookPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: - """Check for a match on pre-prompt hooks. - - Args: - payload: the prompt prehook payload. - conditions: the conditions on the plugin that are required for execution. - context: the global context. - - Returns: - True if the plugin matches criteria. - - Examples: - >>> from mcpgateway.plugins.framework import PluginCondition, PromptPrehookPayload, GlobalContext - >>> payload = PromptPrehookPayload(prompt_id="id1", args={}) - >>> cond = PluginCondition(prompts={"id1"}) - >>> ctx = GlobalContext(request_id="req1") - >>> pre_prompt_matches(payload, [cond], ctx) - True - >>> payload2 = PromptPrehookPayload(prompt_id="id2", args={}) - >>> pre_prompt_matches(payload2, [cond], ctx) - False - """ - current_result = True - for index, condition in enumerate(conditions): - if not matches(condition, context): - current_result = False - - if condition.prompts and payload.prompt_id not in condition.prompts: - current_result = False - if current_result: - return True - if index < len(conditions) - 1: - current_result = True - return current_result - - -def post_prompt_matches(payload: PromptPosthookPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: - """Check for a match on pre-prompt hooks. - - Args: - payload: the prompt posthook payload. - conditions: the conditions on the plugin that are required for execution. - context: the global context. - - Returns: - True if the plugin matches criteria. - """ - current_result = True - for index, condition in enumerate(conditions): - if not matches(condition, context): - current_result = False - - if condition.prompts and payload.prompt_id not in condition.prompts: - current_result = False - if current_result: - return True - if index < len(conditions) - 1: - current_result = True - return current_result - +def get_matchable_value(payload: Any, hook_type: str) -> Optional[str]: + """Extract the matchable value from a payload based on hook type. -def pre_tool_matches(payload: ToolPreInvokePayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: - """Check for a match on pre-tool hooks. + This function maps hook types to their corresponding payload attributes + that should be used for conditional matching. Args: - payload: the tool pre-invoke payload. - conditions: the conditions on the plugin that are required for execution. - context: the global context. + payload: The payload object (e.g., ToolPreInvokePayload, AgentPreInvokePayload). + hook_type: The hook type identifier. Returns: - True if the plugin matches criteria. + The matchable value (e.g., tool name, agent ID, resource URI) or None. Examples: - >>> from mcpgateway.plugins.framework import PluginCondition, ToolPreInvokePayload, GlobalContext + >>> from mcpgateway.plugins.framework import GlobalContext + >>> from mcpgateway.plugins.framework.hooks.tools import ToolPreInvokePayload >>> payload = ToolPreInvokePayload(name="calculator", args={}) - >>> cond = PluginCondition(tools={"calculator"}) - >>> ctx = GlobalContext(request_id="req1") - >>> pre_tool_matches(payload, [cond], ctx) - True - >>> payload2 = ToolPreInvokePayload(name="other", args={}) - >>> pre_tool_matches(payload2, [cond], ctx) - False + >>> get_matchable_value(payload, "tool_pre_invoke") + 'calculator' + >>> get_matchable_value(payload, "unknown_hook") """ - current_result = True - for index, condition in enumerate(conditions): - if not matches(condition, context): - current_result = False - - if condition.tools and payload.name not in condition.tools: - current_result = False - if current_result: - return True - if index < len(conditions) - 1: - current_result = True - return current_result - - -def post_tool_matches(payload: ToolPostInvokePayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: - """Check for a match on post-tool hooks. + # Mapping: hook_type -> payload attribute name + field_map = { + "tool_pre_invoke": "name", + "tool_post_invoke": "name", + "prompt_pre_fetch": "prompt_id", + "prompt_post_fetch": "prompt_id", + "resource_pre_fetch": "uri", + "resource_post_fetch": "uri", + "agent_pre_invoke": "agent_id", + "agent_post_invoke": "agent_id", + } + + field_name = field_map.get(hook_type) + if field_name: + return getattr(payload, field_name, None) + return None + + +def payload_matches( + payload: Any, + hook_type: str, + conditions: list[PluginCondition], + context: GlobalContext, +) -> bool: + """Check if a payload matches any of the plugin conditions. + + This function provides generic conditional matching for all hook types. + It checks both GlobalContext conditions (via matches()) and payload-specific + conditions (tools, prompts, resources, agents). Args: - payload: the tool post-invoke payload. - conditions: the conditions on the plugin that are required for execution. - context: the global context. + payload: The payload object. + hook_type: The hook type identifier. + conditions: List of conditions to check against. + context: The global context. Returns: - True if the plugin matches criteria. + True if the payload matches any condition or if no conditions are specified. Examples: - >>> from mcpgateway.plugins.framework import PluginCondition, ToolPostInvokePayload, GlobalContext - >>> payload = ToolPostInvokePayload(name="calculator", result={"result": 8}) + >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext + >>> from mcpgateway.plugins.framework.hooks.tools import ToolPreInvokePayload + >>> payload = ToolPreInvokePayload(name="calculator", args={}) >>> cond = PluginCondition(tools={"calculator"}) >>> ctx = GlobalContext(request_id="req1") - >>> post_tool_matches(payload, [cond], ctx) + >>> payload_matches(payload, "tool_pre_invoke", [cond], ctx) True - >>> payload2 = ToolPostInvokePayload(name="other", result={"result": 8}) - >>> post_tool_matches(payload2, [cond], ctx) + >>> cond2 = PluginCondition(tools={"other_tool"}) + >>> payload_matches(payload, "tool_pre_invoke", [cond2], ctx) False - """ - current_result = True - for index, condition in enumerate(conditions): - if not matches(condition, context): - current_result = False - - if condition.tools and payload.name not in condition.tools: - current_result = False - if current_result: - return True - if index < len(conditions) - 1: - current_result = True - return current_result - - -def pre_resource_matches(payload: ResourcePreFetchPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: - """Check for a match on pre-resource hooks. - - Args: - payload: the resource pre-fetch payload. - conditions: the conditions on the plugin that are required for execution. - context: the global context. - - Returns: - True if the plugin matches criteria. - - Examples: - >>> from mcpgateway.plugins.framework import PluginCondition, ResourcePreFetchPayload, GlobalContext - >>> payload = ResourcePreFetchPayload(uri="file:///data.txt") - >>> cond = PluginCondition(resources={"file:///data.txt"}) - >>> ctx = GlobalContext(request_id="req1") - >>> pre_resource_matches(payload, [cond], ctx) + >>> payload_matches(payload, "tool_pre_invoke", [], ctx) True - >>> payload2 = ResourcePreFetchPayload(uri="http://api/other") - >>> pre_resource_matches(payload2, [cond], ctx) - False - """ - current_result = True - for index, condition in enumerate(conditions): - if not matches(condition, context): - current_result = False - - if condition.resources and payload.uri not in condition.resources: - current_result = False - if current_result: - return True - if index < len(conditions) - 1: - current_result = True - return current_result - - -def post_resource_matches(payload: ResourcePostFetchPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: - """Check for a match on post-resource hooks. - - Args: - payload: the resource post-fetch payload. - conditions: the conditions on the plugin that are required for execution. - context: the global context. - - Returns: - True if the plugin matches criteria. - - Examples: - >>> from mcpgateway.plugins.framework import PluginCondition, ResourcePostFetchPayload, GlobalContext - >>> from mcpgateway.models import ResourceContent - >>> content = ResourceContent(type="resource", id="123", uri="file:///data.txt", text="Test") - >>> payload = ResourcePostFetchPayload(id="123",uri="file:///data.txt", content=content) - >>> cond = PluginCondition(resources={"file:///data.txt"}) - >>> ctx = GlobalContext(request_id="req1") - >>> post_resource_matches(payload, [cond], ctx) - True - >>> payload2 = ResourcePostFetchPayload(uri="http://api/other", content=content) - >>> post_resource_matches(payload2, [cond], ctx) - False """ - current_result = True - for index, condition in enumerate(conditions): + # Mapping: hook_type -> PluginCondition attribute name + condition_attr_map = { + "tool_pre_invoke": "tools", + "tool_post_invoke": "tools", + "prompt_pre_fetch": "prompts", + "prompt_post_fetch": "prompts", + "resource_pre_fetch": "resources", + "resource_post_fetch": "resources", + "agent_pre_invoke": "agents", + "agent_post_invoke": "agents", + } + + # If no conditions, match everything + if not conditions: + return True + + # Check each condition (OR logic between conditions) + for condition in conditions: + # First check GlobalContext conditions if not matches(condition, context): - current_result = False - - if condition.resources and payload.uri not in condition.resources: - current_result = False - if current_result: - return True - if index < len(conditions) - 1: - current_result = True - return current_result + continue + + # Then check payload-specific conditions + condition_attr = condition_attr_map.get(hook_type) + if condition_attr: + condition_set = getattr(condition, condition_attr, None) + if condition_set: + # Extract the matchable value from the payload + payload_value = get_matchable_value(payload, hook_type) + if payload_value and payload_value not in condition_set: + # Payload value doesn't match this condition's set + continue + + # If we get here, this condition matched + return True + + # No conditions matched + return False + + +# def pre_prompt_matches(payload: PromptPrehookPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: +# """Check for a match on pre-prompt hooks. + +# Args: +# payload: the prompt prehook payload. +# conditions: the conditions on the plugin that are required for execution. +# context: the global context. + +# Returns: +# True if the plugin matches criteria. + +# Examples: +# >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext +# >>> from mcpgateway.plugins.mcp.entities import PromptPrehookPayload +# >>> payload = PromptPrehookPayload(name="greeting", args={}) +# >>> cond = PluginCondition(prompts={"greeting"}) +# >>> ctx = GlobalContext(request_id="req1") +# >>> pre_prompt_matches(payload, [cond], ctx) +# True +# >>> payload2 = PromptPrehookPayload(name="other", args={}) +# >>> pre_prompt_matches(payload2, [cond], ctx) +# False +# """ +# current_result = True +# for index, condition in enumerate(conditions): +# if not matches(condition, context): +# current_result = False + +# if condition.prompts and payload.name not in condition.prompts: +# current_result = False +# if current_result: +# return True +# if index < len(conditions) - 1: +# current_result = True +# return current_result + + +# def post_prompt_matches(payload: PromptPosthookPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: +# """Check for a match on pre-prompt hooks. + +# Args: +# payload: the prompt posthook payload. +# conditions: the conditions on the plugin that are required for execution. +# context: the global context. + +# Returns: +# True if the plugin matches criteria. +# """ +# current_result = True +# for index, condition in enumerate(conditions): +# if not matches(condition, context): +# current_result = False + +# if condition.prompts and payload.name not in condition.prompts: +# current_result = False +# if current_result: +# return True +# if index < len(conditions) - 1: +# current_result = True +# return current_result + + +# def pre_tool_matches(payload: ToolPreInvokePayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: +# """Check for a match on pre-tool hooks. + +# Args: +# payload: the tool pre-invoke payload. +# conditions: the conditions on the plugin that are required for execution. +# context: the global context. + +# Returns: +# True if the plugin matches criteria. + +# Examples: +# >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext +# >>> from mcpgateway.plugins.mcp.entities import ToolPreInvokePayload +# >>> payload = ToolPreInvokePayload(name="calculator", args={}) +# >>> cond = PluginCondition(tools={"calculator"}) +# >>> ctx = GlobalContext(request_id="req1") +# >>> pre_tool_matches(payload, [cond], ctx) +# True +# >>> payload2 = ToolPreInvokePayload(name="other", args={}) +# >>> pre_tool_matches(payload2, [cond], ctx) +# False +# """ +# current_result = True +# for index, condition in enumerate(conditions): +# if not matches(condition, context): +# current_result = False + +# if condition.tools and payload.name not in condition.tools: +# current_result = False +# if current_result: +# return True +# if index < len(conditions) - 1: +# current_result = True +# return current_result + + +# def post_tool_matches(payload: ToolPostInvokePayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: +# """Check for a match on post-tool hooks. + +# Args: +# payload: the tool post-invoke payload. +# conditions: the conditions on the plugin that are required for execution. +# context: the global context. + +# Returns: +# True if the plugin matches criteria. + +# Examples: +# >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext +# >>> from mcpgateway.plugins.mcp.entities import ToolPostInvokePayload +# >>> payload = ToolPostInvokePayload(name="calculator", result={"result": 8}) +# >>> cond = PluginCondition(tools={"calculator"}) +# >>> ctx = GlobalContext(request_id="req1") +# >>> post_tool_matches(payload, [cond], ctx) +# True +# >>> payload2 = ToolPostInvokePayload(name="other", result={"result": 8}) +# >>> post_tool_matches(payload2, [cond], ctx) +# False +# """ +# current_result = True +# for index, condition in enumerate(conditions): +# if not matches(condition, context): +# current_result = False + +# if condition.tools and payload.name not in condition.tools: +# current_result = False +# if current_result: +# return True +# if index < len(conditions) - 1: +# current_result = True +# return current_result + + +# def pre_resource_matches(payload: ResourcePreFetchPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: +# """Check for a match on pre-resource hooks. + +# Args: +# payload: the resource pre-fetch payload. +# conditions: the conditions on the plugin that are required for execution. +# context: the global context. + +# Returns: +# True if the plugin matches criteria. + +# Examples: +# >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext +# >>> from mcpgateway.plugins.mcp.entities import ResourcePreFetchPayload +# >>> payload = ResourcePreFetchPayload(uri="file:///data.txt") +# >>> cond = PluginCondition(resources={"file:///data.txt"}) +# >>> ctx = GlobalContext(request_id="req1") +# >>> pre_resource_matches(payload, [cond], ctx) +# True +# >>> payload2 = ResourcePreFetchPayload(uri="http://api/other") +# >>> pre_resource_matches(payload2, [cond], ctx) +# False +# """ +# current_result = True +# for index, condition in enumerate(conditions): +# if not matches(condition, context): +# current_result = False + +# if condition.resources and payload.uri not in condition.resources: +# current_result = False +# if current_result: +# return True +# if index < len(conditions) - 1: +# current_result = True +# return current_result + + +# def post_resource_matches(payload: ResourcePostFetchPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: +# """Check for a match on post-resource hooks. + +# Args: +# payload: the resource post-fetch payload. +# conditions: the conditions on the plugin that are required for execution. +# context: the global context. + +# Returns: +# True if the plugin matches criteria. + +# Examples: +# >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext +# >>> from mcpgateway.plugins.mcp.entities import ResourcePostFetchPayload, ResourceContent +# >>> content = ResourceContent(type="resource", uri="file:///data.txt", text="Test") +# >>> payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) +# >>> cond = PluginCondition(resources={"file:///data.txt"}) +# >>> ctx = GlobalContext(request_id="req1") +# >>> post_resource_matches(payload, [cond], ctx) +# True +# >>> payload2 = ResourcePostFetchPayload(uri="http://api/other", content=content) +# >>> post_resource_matches(payload2, [cond], ctx) +# False +# """ +# current_result = True +# for index, condition in enumerate(conditions): +# if not matches(condition, context): +# current_result = False + +# if condition.resources and payload.uri not in condition.resources: +# current_result = False +# if current_result: +# return True +# if index < len(conditions) - 1: +# current_result = True +# return current_result diff --git a/mcpgateway/plugins/tools/cli.py b/mcpgateway/plugins/tools/cli.py index 3029cf0d6..01a2b5cd0 100644 --- a/mcpgateway/plugins/tools/cli.py +++ b/mcpgateway/plugins/tools/cli.py @@ -73,7 +73,7 @@ # --------------------------------------------------------------------------- -def command_exists(command_name): +def command_exists(command_name: str) -> bool: """Check if a given command-line utility exists and is executable. Args: @@ -132,7 +132,7 @@ def bootstrap( answers_file: Optional[Annotated[typer.FileText, typer.Option("--answers_file", "-a", help="The answers file to be used for bootstrapping.")]] = None, defaults: Annotated[bool, typer.Option("--defaults", help="Bootstrap with defaults.")] = False, dry_run: Annotated[bool, typer.Option("--dry_run", help="Run but do not make any changes.")] = False, -): +) -> None: """Boostrap a new plugin project from a template. Args: @@ -161,7 +161,7 @@ def bootstrap( @app.callback() -def callback(): # pragma: no cover +def callback() -> None: # pragma: no cover """This function exists to force 'bootstrap' to be a subcommand.""" diff --git a/mcpgateway/schemas.py b/mcpgateway/schemas.py index 01961e234..bc219511e 100644 --- a/mcpgateway/schemas.py +++ b/mcpgateway/schemas.py @@ -33,16 +33,16 @@ from pydantic import AnyHttpUrl, BaseModel, ConfigDict, EmailStr, Field, field_serializer, field_validator, model_validator, ValidationInfo # First-Party +from mcpgateway.common.models import Annotations, ImageContent +from mcpgateway.common.models import Prompt as MCPPrompt +from mcpgateway.common.models import Resource as MCPResource +from mcpgateway.common.models import ResourceContent, TextContent +from mcpgateway.common.models import Tool as MCPTool +from mcpgateway.common.validators import SecurityValidator from mcpgateway.config import settings -from mcpgateway.models import Annotations, ImageContent -from mcpgateway.models import Prompt as MCPPrompt -from mcpgateway.models import Resource as MCPResource -from mcpgateway.models import ResourceContent, TextContent -from mcpgateway.models import Tool as MCPTool from mcpgateway.utils.base_models import BaseModelWithConfigDict from mcpgateway.utils.services_auth import decode_auth, encode_auth from mcpgateway.validation.tags import validate_tags_field -from mcpgateway.validators import SecurityValidator logger = logging.getLogger(__name__) diff --git a/mcpgateway/services/completion_service.py b/mcpgateway/services/completion_service.py index bee038abd..89b99c9d9 100644 --- a/mcpgateway/services/completion_service.py +++ b/mcpgateway/services/completion_service.py @@ -25,9 +25,9 @@ from sqlalchemy.orm import Session # First-Party +from mcpgateway.common.models import CompleteResult from mcpgateway.db import Prompt as DbPrompt from mcpgateway.db import Resource as DbResource -from mcpgateway.models import CompleteResult from mcpgateway.services.logging_service import LoggingService # Initialize logging service first diff --git a/mcpgateway/services/elicitation_service.py b/mcpgateway/services/elicitation_service.py index b226aa87f..095663518 100644 --- a/mcpgateway/services/elicitation_service.py +++ b/mcpgateway/services/elicitation_service.py @@ -18,7 +18,7 @@ from uuid import uuid4 # First-Party -from mcpgateway.models import ElicitResult +from mcpgateway.common.models import ElicitResult logger = logging.getLogger(__name__) diff --git a/mcpgateway/services/log_storage_service.py b/mcpgateway/services/log_storage_service.py index ed4631c9d..36dca4fb1 100644 --- a/mcpgateway/services/log_storage_service.py +++ b/mcpgateway/services/log_storage_service.py @@ -18,8 +18,8 @@ import uuid # First-Party +from mcpgateway.common.models import LogLevel from mcpgateway.config import settings -from mcpgateway.models import LogLevel class LogEntryDict(TypedDict, total=False): @@ -108,7 +108,7 @@ def to_dict(self) -> LogEntryDict: Dictionary representation of the log entry Examples: - >>> from mcpgateway.models import LogLevel + >>> from mcpgateway.common.models import LogLevel >>> entry = LogEntry(LogLevel.INFO, "Test message", entity_type="tool", entity_id="123") >>> d = entry.to_dict() >>> str(d['level']) @@ -371,7 +371,7 @@ def _meets_level_threshold(self, log_level: LogLevel, min_level: LogLevel) -> bo True if log level meets or exceeds minimum Examples: - >>> from mcpgateway.models import LogLevel + >>> from mcpgateway.common.models import LogLevel >>> service = LogStorageService() >>> service._meets_level_threshold(LogLevel.ERROR, LogLevel.WARNING) True @@ -462,7 +462,7 @@ def clear(self) -> int: Number of logs cleared Examples: - >>> from mcpgateway.models import LogLevel + >>> from mcpgateway.common.models import LogLevel >>> service = LogStorageService() >>> import asyncio >>> entry = asyncio.run(service.add_log(LogLevel.INFO, "Test")) diff --git a/mcpgateway/services/logging_service.py b/mcpgateway/services/logging_service.py index 1d3191791..4f21111c0 100644 --- a/mcpgateway/services/logging_service.py +++ b/mcpgateway/services/logging_service.py @@ -22,8 +22,8 @@ from pythonjsonlogger import json as jsonlogger # You may need to install python-json-logger package # First-Party +from mcpgateway.common.models import LogLevel from mcpgateway.config import settings -from mcpgateway.models import LogLevel from mcpgateway.services.log_storage_service import LogStorageService AnyioClosedResourceError: Optional[type] # pylint: disable=invalid-name @@ -414,7 +414,7 @@ async def set_level(self, level: LogLevel) -> None: Examples: >>> from mcpgateway.services.logging_service import LoggingService - >>> from mcpgateway.models import LogLevel + >>> from mcpgateway.common.models import LogLevel >>> import asyncio >>> service = LoggingService() >>> asyncio.run(service.set_level(LogLevel.DEBUG)) @@ -454,7 +454,7 @@ async def notify( # pylint: disable=too-many-positional-arguments Examples: >>> from mcpgateway.services.logging_service import LoggingService - >>> from mcpgateway.models import LogLevel + >>> from mcpgateway.common.models import LogLevel >>> import asyncio >>> service = LoggingService() >>> asyncio.run(service.notify('test', LogLevel.INFO)) @@ -547,7 +547,7 @@ def _should_log(self, level: LogLevel) -> bool: True if should log Examples: - >>> from mcpgateway.models import LogLevel + >>> from mcpgateway.common.models import LogLevel >>> service = LoggingService() >>> service._level = LogLevel.WARNING >>> service._should_log(LogLevel.ERROR) diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index ef169bade..8cb936b05 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -30,13 +30,13 @@ from sqlalchemy.orm import Session # First-Party +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.config import settings from mcpgateway.db import EmailTeam from mcpgateway.db import Prompt as DbPrompt from mcpgateway.db import PromptMetric, server_prompt_association -from mcpgateway.models import Message, PromptResult, Role, TextContent from mcpgateway.observability import create_span -from mcpgateway.plugins.framework import GlobalContext, PluginManager, PromptPosthookPayload, PromptPrehookPayload +from mcpgateway.plugins.framework import GlobalContext, PluginManager, PromptHookType, PromptPosthookPayload, PromptPrehookPayload from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.observability_service import current_trace_id, ObservabilityService @@ -750,8 +750,12 @@ async def get_prompt( if not request_id: request_id = uuid.uuid4().hex global_context = GlobalContext(request_id=request_id, user=user, server_id=server_id, tenant_id=tenant_id) - pre_result, context_table = await self._plugin_manager.prompt_pre_fetch( - payload=PromptPrehookPayload(prompt_id=str(prompt_id), args=arguments), global_context=global_context, local_contexts=None, violations_as_exceptions=True + pre_result, context_table = await self._plugin_manager.invoke_hook( + PromptHookType.PROMPT_PRE_FETCH, + payload=PromptPrehookPayload(prompt_id=str(prompt_id), args=arguments), + global_context=global_context, + local_contexts=None, + violations_as_exceptions=True, ) # Use modified payload if provided @@ -815,8 +819,12 @@ async def get_prompt( raise PromptError(f"Failed to process prompt: {str(e)}") if self._plugin_manager: - post_result, _ = await self._plugin_manager.prompt_post_fetch( - payload=PromptPosthookPayload(prompt_id=str(prompt.id), result=result), global_context=global_context, local_contexts=context_table, violations_as_exceptions=True + post_result, _ = await self._plugin_manager.invoke_hook( + PromptHookType.PROMPT_POST_FETCH, + payload=PromptPosthookPayload(prompt_id=str(prompt.id), result=result), + global_context=global_context, + local_contexts=context_table, + violations_as_exceptions=True, ) # Use modified payload if provided result = post_result.modified_payload.result if post_result.modified_payload else result diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index b6bdbb277..97ea6e250 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -41,13 +41,13 @@ from sqlalchemy.orm import Session # First-Party +from mcpgateway.common.models import ResourceContent, ResourceTemplate, TextContent from mcpgateway.config import settings from mcpgateway.db import EmailTeam from mcpgateway.db import Resource as DbResource from mcpgateway.db import ResourceMetric from mcpgateway.db import ResourceSubscription as DbSubscription from mcpgateway.db import server_resource_association -from mcpgateway.models import ResourceContent, ResourceTemplate, TextContent from mcpgateway.observability import create_span from mcpgateway.schemas import ResourceCreate, ResourceMetrics, ResourceRead, ResourceSubscription, ResourceUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService @@ -59,7 +59,7 @@ # Plugin support imports (conditional) try: # First-Party - from mcpgateway.plugins.framework import GlobalContext, PluginManager, ResourcePostFetchPayload, ResourcePreFetchPayload + from mcpgateway.plugins.framework import GlobalContext, PluginManager, ResourceHookType, ResourcePostFetchPayload, ResourcePreFetchPayload PLUGINS_AVAILABLE = True except ImportError: @@ -693,7 +693,7 @@ async def read_resource(self, db: Session, resource_id: Union[int, str], request Examples: >>> from mcpgateway.services.resource_service import ResourceService >>> from unittest.mock import MagicMock - >>> from mcpgateway.models import ResourceContent + >>> from mcpgateway.common.models import ResourceContent >>> service = ResourceService() >>> db = MagicMock() >>> uri = 'http://example.com/resource.txt' @@ -797,7 +797,7 @@ async def read_resource(self, db: Session, resource_id: Union[int, str], request pre_payload = ResourcePreFetchPayload(uri=uri, metadata={}) # Execute pre-fetch hooks - pre_result, contexts = await self._plugin_manager.resource_pre_fetch(pre_payload, global_context, violations_as_exceptions=True) + pre_result, contexts = await self._plugin_manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, pre_payload, global_context, violations_as_exceptions=True) # Use modified URI if plugin changed it if pre_result.modified_payload: uri = pre_result.modified_payload.uri @@ -827,7 +827,9 @@ async def read_resource(self, db: Session, resource_id: Union[int, str], request post_payload = ResourcePostFetchPayload(uri=original_uri, content=content) # Execute post-fetch hooks - post_result, _ = await self._plugin_manager.resource_post_fetch(post_payload, global_context, contexts, violations_as_exceptions=True) # Pass contexts from pre-fetch + post_result, _ = await self._plugin_manager.invoke_hook( + ResourceHookType.RESOURCE_POST_FETCH, post_payload, global_context, contexts, violations_as_exceptions=True + ) # Pass contexts from pre-fetch # Use modified content if plugin changed it if post_result.modified_payload: diff --git a/mcpgateway/services/root_service.py b/mcpgateway/services/root_service.py index 1e88e62e1..3f97b87c7 100644 --- a/mcpgateway/services/root_service.py +++ b/mcpgateway/services/root_service.py @@ -16,8 +16,8 @@ from urllib.parse import urlparse # First-Party +from mcpgateway.common.models import Root from mcpgateway.config import settings -from mcpgateway.models import Root from mcpgateway.services.logging_service import LoggingService # Initialize logging service first @@ -296,7 +296,7 @@ async def _notify_root_added(self, root: Root) -> None: Examples: >>> import asyncio >>> from mcpgateway.services.root_service import RootService - >>> from mcpgateway.models import Root + >>> from mcpgateway.common.models import Root >>> service = RootService() >>> queue = asyncio.Queue() >>> service._subscribers.append(queue) @@ -320,7 +320,7 @@ async def _notify_root_removed(self, root: Root) -> None: Examples: >>> import asyncio >>> from mcpgateway.services.root_service import RootService - >>> from mcpgateway.models import Root + >>> from mcpgateway.common.models import Root >>> service = RootService() >>> queue = asyncio.Queue() >>> service._subscribers.append(queue) diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 888e36fb5..fb9b1c1a0 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -38,6 +38,9 @@ from sqlalchemy.orm import Session # First-Party +from mcpgateway.common.models import Gateway as PydanticGateway +from mcpgateway.common.models import TextContent +from mcpgateway.common.models import Tool as PydanticTool from mcpgateway.config import settings from mcpgateway.db import A2AAgent as DbA2AAgent from mcpgateway.db import EmailTeam @@ -45,13 +48,11 @@ from mcpgateway.db import server_tool_association from mcpgateway.db import Tool as DbTool from mcpgateway.db import ToolMetric -from mcpgateway.models import Gateway as PydanticGateway -from mcpgateway.models import TextContent -from mcpgateway.models import Tool as PydanticTool -from mcpgateway.models import ToolResult -from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, PluginError, PluginManager, PluginViolationError, ToolPostInvokePayload, ToolPreInvokePayload +from mcpgateway.plugins.framework import GlobalContext, PluginError, PluginManager, PluginViolationError from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA -from mcpgateway.schemas import ToolCreate, ToolRead, ToolUpdate, TopPerformer +from mcpgateway.plugins.framework.hooks.http import HttpHeaderPayload +from mcpgateway.plugins.framework.hooks.tools import ToolHookType, ToolPostInvokePayload, ToolPreInvokePayload +from mcpgateway.schemas import ToolCreate, ToolRead, ToolResult, ToolUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager from mcpgateway.services.observability_service import current_trace_id, ObservabilityService @@ -444,7 +445,7 @@ def _extract_and_validate_structured_content(self, tool: DbTool, tool_result: "T Examples: >>> from mcpgateway.services.tool_service import ToolService - >>> from mcpgateway.models import TextContent, ToolResult + >>> from mcpgateway.common.models import TextContent, ToolResult >>> import json >>> service = ToolService() >>> # No schema declared -> nothing to validate @@ -1195,8 +1196,9 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r if self._plugin_manager: tool_metadata = PydanticTool.model_validate(tool) global_context.metadata[TOOL_METADATA] = tool_metadata - pre_result, context_table = await self._plugin_manager.tool_pre_invoke( - payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(headers)), + pre_result, context_table = await self._plugin_manager.invoke_hook( + ToolHookType.TOOL_PRE_INVOKE, + payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)), global_context=global_context, local_contexts=None, violations_as_exceptions=True, @@ -1350,8 +1352,9 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head if tool_gateway: gateway_metadata = PydanticGateway.model_validate(tool_gateway) global_context.metadata[GATEWAY_METADATA] = gateway_metadata - pre_result, context_table = await self._plugin_manager.tool_pre_invoke( - payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(headers)), + pre_result, context_table = await self._plugin_manager.invoke_hook( + ToolHookType.TOOL_PRE_INVOKE, + payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)), global_context=global_context, local_contexts=None, violations_as_exceptions=True, @@ -1382,7 +1385,8 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head # Plugin hook: tool post-invoke if self._plugin_manager: - post_result, _ = await self._plugin_manager.tool_post_invoke( + post_result, _ = await self._plugin_manager.invoke_hook( + ToolHookType.TOOL_POST_INVOKE, payload=ToolPostInvokePayload(name=name, result=tool_result.model_dump(by_alias=True)), global_context=global_context, local_contexts=context_table, diff --git a/mcpgateway/transports/streamablehttp_transport.py b/mcpgateway/transports/streamablehttp_transport.py index f52d8faf3..599f20f03 100644 --- a/mcpgateway/transports/streamablehttp_transport.py +++ b/mcpgateway/transports/streamablehttp_transport.py @@ -53,9 +53,9 @@ from starlette.types import Receive, Scope, Send # First-Party +from mcpgateway.common.models import LogLevel from mcpgateway.config import settings from mcpgateway.db import SessionLocal -from mcpgateway.models import LogLevel from mcpgateway.services.completion_service import CompletionService from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.prompt_service import PromptService diff --git a/mcpgateway/utils/pagination.py b/mcpgateway/utils/pagination.py index 38816b813..b0d9ee2cd 100644 --- a/mcpgateway/utils/pagination.py +++ b/mcpgateway/utils/pagination.py @@ -22,7 +22,7 @@ from mcpgateway.utils.pagination import paginate_query from sqlalchemy import select - from mcpgateway.models import Tool + from mcpgateway.common.models import Tool async def list_tools(db: Session): query = select(Tool).where(Tool.enabled == True) @@ -312,7 +312,7 @@ async def offset_paginate( from mcpgateway.utils.pagination import offset_paginate from sqlalchemy import select - from mcpgateway.models import Tool + from mcpgateway.common.models import Tool async def list_tools_offset(db: Session, page: int = 1): query = select(Tool).where(Tool.enabled == True) @@ -411,7 +411,7 @@ async def cursor_paginate( from mcpgateway.utils.pagination import cursor_paginate from sqlalchemy import select - from mcpgateway.models import Tool + from mcpgateway.common.models import Tool async def list_tools_cursor(db: Session, cursor: Optional[str] = None): query = select(Tool).order_by(Tool.created_at.desc()) @@ -533,7 +533,7 @@ async def paginate_query( from mcpgateway.utils.pagination import paginate_query from sqlalchemy import select - from mcpgateway.models import Tool + from mcpgateway.common.models import Tool async def list_tools_auto(db: Session, page: int = 1): query = select(Tool) diff --git a/mcpgateway/utils/passthrough_headers.py b/mcpgateway/utils/passthrough_headers.py index c3f7c1f91..a260dc0b0 100644 --- a/mcpgateway/utils/passthrough_headers.py +++ b/mcpgateway/utils/passthrough_headers.py @@ -350,7 +350,7 @@ async def set_global_passthrough_headers(db: Session) -> None: Config already exists (no DB write): >>> import pytest >>> from unittest.mock import Mock, patch - >>> from mcpgateway.models import GlobalConfig + >>> from mcpgateway.common.models import GlobalConfig >>> @pytest.mark.asyncio ... @patch("mcpgateway.utils.passthrough_headers.settings") ... async def test_existing_config(mock_settings): diff --git a/mcpgateway/validators.py b/mcpgateway/validators.py index 743cf489f..45b0259dc 100644 --- a/mcpgateway/validators.py +++ b/mcpgateway/validators.py @@ -5,37 +5,14 @@ Authors: Mihai Criveti, Madhav Kandukuri SecurityValidator for MCP Gateway -This module defines the `SecurityValidator` class, which provides centralized, configurable -validation logic for user-generated content in MCP-based applications. +This module re-exports the SecurityValidator class from mcpgateway.common.validators +for backward compatibility. -The validator enforces strict security and structural rules across common input types such as: -- Display text (e.g., names, descriptions) -- Identifiers and tool names -- URIs and URLs -- JSON object depth -- Templates (including limited HTML/Jinja2) -- MIME types - -Key Features: -- Pattern-based validation using settings-defined regex for HTML/script safety -- Configurable max lengths and depth limits -- Whitelist-based URL scheme and MIME type validation -- Safe escaping of user-visible text fields -- Reusable static/class methods for field-level and form-level validation - -Intended to be used with Pydantic or similar schema-driven systems to validate and sanitize -user input in a consistent, centralized way. - -Dependencies: -- Standard Library: re, html, logging, urllib.parse -- First-party: `settings` from `mcpgateway.config` +The canonical location for SecurityValidator is mcpgateway.common.validators. +This module exists to maintain backward compatibility with code that imports from +mcpgateway.validators. Example usage: - SecurityValidator.validate_name("my_tool", field_name="Tool Name") - SecurityValidator.validate_url("https://example.com") - SecurityValidator.validate_json_depth({...}) - -Examples: >>> from mcpgateway.validators import SecurityValidator >>> SecurityValidator.sanitize_display_text('Test', 'test') '<b>Test</b>' @@ -47,1144 +24,9 @@ >>> SecurityValidator.validate_json_depth({'a': 1}) """ -# Standard -import html -import logging -import re -from urllib.parse import urlparse -import uuid - # First-Party -from mcpgateway.config import settings - -logger = logging.getLogger(__name__) - - -class SecurityValidator: - """Configurable validation with MCP-compliant limits""" - - # Configurable patterns (from settings) - DANGEROUS_HTML_PATTERN = ( - settings.validation_dangerous_html_pattern - ) # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|' - DANGEROUS_JS_PATTERN = settings.validation_dangerous_js_pattern # Default: javascript:|vbscript:|on\w+\s*=|data:.*script - ALLOWED_URL_SCHEMES = settings.validation_allowed_url_schemes # Default: ["http://", "https://", "ws://", "wss://"] - - # Character type patterns - NAME_PATTERN = settings.validation_name_pattern # Default: ^[a-zA-Z0-9_\-\s]+$ - IDENTIFIER_PATTERN = settings.validation_identifier_pattern # Default: ^[a-zA-Z0-9_\-\.]+$ - VALIDATION_SAFE_URI_PATTERN = settings.validation_safe_uri_pattern # Default: ^[a-zA-Z0-9_\-.:/?=&%]+$ - VALIDATION_UNSAFE_URI_PATTERN = settings.validation_unsafe_uri_pattern # Default: [<>"\'\\] - TOOL_NAME_PATTERN = settings.validation_tool_name_pattern # Default: ^[a-zA-Z][a-zA-Z0-9_-]*$ - - # MCP-compliant limits (configurable) - MAX_NAME_LENGTH = settings.validation_max_name_length # Default: 255 - MAX_DESCRIPTION_LENGTH = settings.validation_max_description_length # Default: 8192 (8KB) - MAX_TEMPLATE_LENGTH = settings.validation_max_template_length # Default: 65536 - MAX_CONTENT_LENGTH = settings.validation_max_content_length # Default: 1048576 (1MB) - MAX_JSON_DEPTH = settings.validation_max_json_depth # Default: 10 - MAX_URL_LENGTH = settings.validation_max_url_length # Default: 2048 - - @classmethod - def sanitize_display_text(cls, value: str, field_name: str) -> str: - """Ensure text is safe for display in UI by escaping special characters - - Args: - value (str): Value to validate - field_name (str): Name of field being validated - - Returns: - str: Value if acceptable - - Raises: - ValueError: When input is not acceptable - - Examples: - Basic HTML escaping: - - >>> SecurityValidator.sanitize_display_text('Hello World', 'test') - 'Hello World' - >>> SecurityValidator.sanitize_display_text('Hello World', 'test') - 'Hello <b>World</b>' - - Empty/None handling: - - >>> SecurityValidator.sanitize_display_text('', 'test') - '' - >>> SecurityValidator.sanitize_display_text(None, 'test') #doctest: +SKIP - - Dangerous script patterns: - - >>> SecurityValidator.sanitize_display_text('alert();', 'test') - 'alert();' - >>> SecurityValidator.sanitize_display_text('javascript:alert(1)', 'test') - Traceback (most recent call last): - ... - ValueError: test contains script patterns that may cause display issues - - Polyglot attack patterns: - - >>> SecurityValidator.sanitize_display_text('"; alert()', 'test') - Traceback (most recent call last): - ... - ValueError: test contains potentially dangerous character sequences - >>> SecurityValidator.sanitize_display_text('-->test', 'test') - '-->test' - >>> SecurityValidator.sanitize_display_text('-->') - Traceback (most recent call last): - ... - ValueError: Template contains HTML tags that may interfere with proper display - >>> SecurityValidator.validate_template('Test ') - Traceback (most recent call last): - ... - ValueError: Template contains HTML tags that may interfere with proper display - >>> SecurityValidator.validate_template('') - Traceback (most recent call last): - ... - ValueError: Template contains HTML tags that may interfere with proper display - - Event handlers blocked: - - >>> SecurityValidator.validate_template('
Test
') - Traceback (most recent call last): - ... - ValueError: Template contains event handlers that may cause display issues - >>> SecurityValidator.validate_template('onload = "alert(1)"') - Traceback (most recent call last): - ... - ValueError: Template contains event handlers that may cause display issues - - SSTI prevention patterns: - - >>> SecurityValidator.validate_template('{{ __import__ }}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('{{ config }}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('{% import os %}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('{{ 7*7 }}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('{{ 10/2 }}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('{{ 5+5 }}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('{{ 10-5 }}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - - Other template injection patterns: - - >>> SecurityValidator.validate_template('${evil}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('#{evil}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('%{evil}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - - Length limit testing: - - >>> long_template = 'a' * 65537 - >>> SecurityValidator.validate_template(long_template) - Traceback (most recent call last): - ... - ValueError: Template exceeds maximum length of 65536 - """ - if not value: - return value - - if len(value) > cls.MAX_TEMPLATE_LENGTH: - raise ValueError(f"Template exceeds maximum length of {cls.MAX_TEMPLATE_LENGTH}") - - # Block dangerous tags but allow Jinja2 syntax {{ }} and {% %} - dangerous_tags = r"<(script|iframe|object|embed|link|meta|base|form)\b" - if re.search(dangerous_tags, value, re.IGNORECASE): - raise ValueError("Template contains HTML tags that may interfere with proper display") - - # Check for event handlers that could cause issues - if re.search(r"on\w+\s*=", value, re.IGNORECASE): - raise ValueError("Template contains event handlers that may cause display issues") - - # SSTI Prevention - block dangerous template expressions - ssti_patterns = [ - r"\{\{.*(__|\.|config|self|request|application|globals|builtins|import).*\}\}", # Jinja2 dangerous patterns - r"\{%.*(__|\.|config|self|request|application|globals|builtins|import).*%\}", # Jinja2 tags - r"\$\{.*\}", # ${} expressions - r"#\{.*\}", # #{} expressions - r"%\{.*\}", # %{} expressions - r"\{\{.*\*.*\}\}", # Math operations in templates (like {{7*7}}) - r"\{\{.*\/.*\}\}", # Division operations - r"\{\{.*\+.*\}\}", # Addition operations - r"\{\{.*\-.*\}\}", # Subtraction operations - ] - - for pattern in ssti_patterns: - if re.search(pattern, value, re.IGNORECASE): - raise ValueError("Template contains potentially dangerous expressions") - - return value - - @classmethod - def validate_url(cls, value: str, field_name: str = "URL") -> str: - """Validate URLs for allowed schemes and safe display - - Args: - value (str): Value to validate - field_name (str): Name of field being validated - - Returns: - str: Value if acceptable - - Raises: - ValueError: When input is not acceptable - - Examples: - Valid URLs: - - >>> SecurityValidator.validate_url('https://example.com') - 'https://example.com' - >>> SecurityValidator.validate_url('http://example.com') - 'http://example.com' - >>> SecurityValidator.validate_url('ws://example.com') - 'ws://example.com' - >>> SecurityValidator.validate_url('wss://example.com') - 'wss://example.com' - >>> SecurityValidator.validate_url('https://example.com:8080/path') - 'https://example.com:8080/path' - >>> SecurityValidator.validate_url('https://example.com/path?query=value') - 'https://example.com/path?query=value' - - Empty URL handling: - - >>> SecurityValidator.validate_url('') - Traceback (most recent call last): - ... - ValueError: URL cannot be empty - - Length validation: - - >>> long_url = 'https://example.com/' + 'a' * 2100 - >>> SecurityValidator.validate_url(long_url) - Traceback (most recent call last): - ... - ValueError: URL exceeds maximum length of 2048 - - Scheme validation: - - >>> SecurityValidator.validate_url('ftp://example.com') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - >>> SecurityValidator.validate_url('file:///etc/passwd') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - >>> SecurityValidator.validate_url('javascript:alert(1)') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - >>> SecurityValidator.validate_url('data:text/plain,hello') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - >>> SecurityValidator.validate_url('vbscript:alert(1)') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - >>> SecurityValidator.validate_url('about:blank') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - >>> SecurityValidator.validate_url('chrome://settings') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - >>> SecurityValidator.validate_url('mailto:test@example.com') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - - IPv6 URL blocking: - - >>> SecurityValidator.validate_url('https://[::1]:8080/') - Traceback (most recent call last): - ... - ValueError: URL contains IPv6 address which is not supported - >>> SecurityValidator.validate_url('https://[2001:db8::1]/') - Traceback (most recent call last): - ... - ValueError: URL contains IPv6 address which is not supported - - Protocol-relative URL blocking: - - >>> SecurityValidator.validate_url('//example.com/path') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - - Line break injection: - - >>> SecurityValidator.validate_url('https://example.com\\rHost: evil.com') - Traceback (most recent call last): - ... - ValueError: URL contains line breaks which are not allowed - >>> SecurityValidator.validate_url('https://example.com\\nHost: evil.com') - Traceback (most recent call last): - ... - ValueError: URL contains line breaks which are not allowed - - Space validation: - - >>> SecurityValidator.validate_url('https://exam ple.com') - Traceback (most recent call last): - ... - ValueError: URL contains spaces which are not allowed in URLs - >>> SecurityValidator.validate_url('https://example.com/path?query=hello world') - 'https://example.com/path?query=hello world' - - Malformed URLs: - - >>> SecurityValidator.validate_url('https://') - Traceback (most recent call last): - ... - ValueError: URL is not a valid URL - >>> SecurityValidator.validate_url('not-a-url') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - - Restricted IP addresses: - - >>> SecurityValidator.validate_url('https://0.0.0.0/') - Traceback (most recent call last): - ... - ValueError: URL contains invalid IP address (0.0.0.0) - >>> SecurityValidator.validate_url('https://169.254.169.254/') - Traceback (most recent call last): - ... - ValueError: URL contains restricted IP address - - Invalid port numbers: - - >>> SecurityValidator.validate_url('https://example.com:0/') - Traceback (most recent call last): - ... - ValueError: URL contains invalid port number - >>> try: - ... SecurityValidator.validate_url('https://example.com:65536/') - ... except ValueError as e: - ... 'Port out of range' in str(e) or 'invalid port' in str(e) - True - - Credentials in URL: - - >>> SecurityValidator.validate_url('https://user:pass@example.com/') - Traceback (most recent call last): - ... - ValueError: URL contains credentials which are not allowed - >>> SecurityValidator.validate_url('https://user@example.com/') - Traceback (most recent call last): - ... - ValueError: URL contains credentials which are not allowed - - XSS patterns in URLs: - - >>> SecurityValidator.validate_url('https://example.com/', 'test_field') - Traceback (most recent call last): - ... - ValueError: test_field contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'content') - Traceback (most recent call last): - ... - ValueError: content contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'data') - Traceback (most recent call last): - ... - ValueError: data contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'embed') - Traceback (most recent call last): - ... - ValueError: embed contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'style') - Traceback (most recent call last): - ... - ValueError: style contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'meta') - Traceback (most recent call last): - ... - ValueError: meta contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'base') - Traceback (most recent call last): - ... - ValueError: base contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('
', 'form') - Traceback (most recent call last): - ... - ValueError: form contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'image') - Traceback (most recent call last): - ... - ValueError: image contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'svg') - Traceback (most recent call last): - ... - ValueError: svg contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'video') - Traceback (most recent call last): - ... - ValueError: video contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'audio') - Traceback (most recent call last): - ... - ValueError: audio contains HTML tags that may cause security issues - """ - if not value: - return # Empty values are considered safe - # Check for dangerous HTML tags - if re.search(cls.DANGEROUS_HTML_PATTERN, value, re.IGNORECASE): - raise ValueError(f"{field_name} contains HTML tags that may cause security issues") - - @classmethod - def validate_json_depth( - cls, - obj: object, - max_depth: int | None = None, - current_depth: int = 0, - ) -> None: - """Validate that a JSON‑like structure does not exceed a depth limit. - - A *depth* is counted **only** when we enter a container (`dict` or - `list`). Primitive values (`str`, `int`, `bool`, `None`, etc.) do not - increase the depth, but an *empty* container still counts as one level. - - Args: - obj: Any Python object to inspect recursively. - max_depth: Maximum allowed depth (defaults to - :pyattr:`SecurityValidator.MAX_JSON_DEPTH`). - current_depth: Internal recursion counter. **Do not** set this - from user code. - - Raises: - ValueError: If the nesting level exceeds *max_depth*. - - Examples: - Simple flat dictionary – depth 1: :: - - >>> SecurityValidator.validate_json_depth({'name': 'Alice'}) - - Nested dict – depth 2: :: - - >>> SecurityValidator.validate_json_depth( - ... {'user': {'name': 'Alice'}} - ... ) - - Mixed dict/list – depth 3: :: - - >>> SecurityValidator.validate_json_depth( - ... {'users': [{'name': 'Alice', 'meta': {'age': 30}}]} - ... ) - - Exactly at the default limit (10) – allowed: :: - - >>> deep_10 = {'1': {'2': {'3': {'4': {'5': {'6': {'7': {'8': - ... {'9': {'10': 'end'}}}}}}}}}} - >>> SecurityValidator.validate_json_depth(deep_10) - - One level deeper – rejected: :: - - >>> deep_11 = {'1': {'2': {'3': {'4': {'5': {'6': {'7': {'8': - ... {'9': {'10': {'11': 'end'}}}}}}}}}}} - >>> SecurityValidator.validate_json_depth(deep_11) - Traceback (most recent call last): - ... - ValueError: JSON structure exceeds maximum depth of 10 - """ - if max_depth is None: - max_depth = cls.MAX_JSON_DEPTH - - # Only containers count toward depth; primitives are ignored - if not isinstance(obj, (dict, list)): - return - - next_depth = current_depth + 1 - if next_depth > max_depth: - raise ValueError(f"JSON structure exceeds maximum depth of {max_depth}") - - if isinstance(obj, dict): - for value in obj.values(): - cls.validate_json_depth(value, max_depth, next_depth) - else: # obj is a list - for item in obj: - cls.validate_json_depth(item, max_depth, next_depth) - - @classmethod - def validate_mime_type(cls, value: str) -> str: - """Validate MIME type format - - Args: - value (str): Value to validate - - Returns: - str: Value if acceptable - - Raises: - ValueError: When input is not acceptable - - Examples: - Empty/None handling: - - >>> SecurityValidator.validate_mime_type('') - '' - >>> SecurityValidator.validate_mime_type(None) #doctest: +SKIP - - Valid standard MIME types: - - >>> SecurityValidator.validate_mime_type('text/plain') - 'text/plain' - >>> SecurityValidator.validate_mime_type('application/json') - 'application/json' - >>> SecurityValidator.validate_mime_type('image/jpeg') - 'image/jpeg' - >>> SecurityValidator.validate_mime_type('text/html') - 'text/html' - >>> SecurityValidator.validate_mime_type('application/pdf') - 'application/pdf' - - Valid vendor-specific MIME types: - - >>> SecurityValidator.validate_mime_type('application/x-custom') - 'application/x-custom' - >>> SecurityValidator.validate_mime_type('text/x-log') - 'text/x-log' - - Valid MIME types with suffixes: - - >>> SecurityValidator.validate_mime_type('application/vnd.api+json') - 'application/vnd.api+json' - >>> SecurityValidator.validate_mime_type('image/svg+xml') - 'image/svg+xml' - - Invalid MIME type formats: - - >>> SecurityValidator.validate_mime_type('invalid') - Traceback (most recent call last): - ... - ValueError: Invalid MIME type format - >>> SecurityValidator.validate_mime_type('text/') - Traceback (most recent call last): - ... - ValueError: Invalid MIME type format - >>> SecurityValidator.validate_mime_type('/plain') - Traceback (most recent call last): - ... - ValueError: Invalid MIME type format - >>> SecurityValidator.validate_mime_type('text//plain') - Traceback (most recent call last): - ... - ValueError: Invalid MIME type format - >>> SecurityValidator.validate_mime_type('text/plain/extra') - Traceback (most recent call last): - ... - ValueError: Invalid MIME type format - >>> SecurityValidator.validate_mime_type('text plain') - Traceback (most recent call last): - ... - ValueError: Invalid MIME type format - >>> SecurityValidator.validate_mime_type('') - Traceback (most recent call last): - ... - ValueError: Invalid MIME type format - - Disallowed MIME types (not in whitelist - line 620): - - >>> try: - ... SecurityValidator.validate_mime_type('application/evil') - ... except ValueError as e: - ... 'not in the allowed list' in str(e) - True - >>> try: - ... SecurityValidator.validate_mime_type('text/evil') - ... except ValueError as e: - ... 'not in the allowed list' in str(e) - True - - Test MIME type with parameters (line 618): - - >>> try: - ... SecurityValidator.validate_mime_type('application/evil; charset=utf-8') - ... except ValueError as e: - ... 'Invalid MIME type format' in str(e) - True - """ - if not value: - return value - - # Basic MIME type pattern - mime_pattern = r"^[a-zA-Z0-9][a-zA-Z0-9!#$&\-\^_+\.]*\/[a-zA-Z0-9][a-zA-Z0-9!#$&\-\^_+\.]*$" - if not re.match(mime_pattern, value): - raise ValueError("Invalid MIME type format") - - # Common safe MIME types - safe_mime_types = settings.validation_allowed_mime_types - if value not in safe_mime_types: - # Allow x- vendor types and + suffixes - base_type = value.split(";")[0].strip() - if not (base_type.startswith("application/x-") or base_type.startswith("text/x-") or "+" in base_type): - raise ValueError(f"MIME type '{value}' is not in the allowed list") +# Re-export SecurityValidator from canonical location +# pylint: disable=unused-import +from mcpgateway.common.validators import SecurityValidator # noqa: F401 - return value +__all__ = ["SecurityValidator"] diff --git a/plugin_templates/external/tests/test_all.py b/plugin_templates/external/tests/test_all.py index 39987cbe7..b439b5136 100644 --- a/plugin_templates/external/tests/test_all.py +++ b/plugin_templates/external/tests/test_all.py @@ -8,7 +8,7 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework import ( GlobalContext, PluginManager, diff --git a/plugins/README.md b/plugins/README.md index e3d2cc3d5..7dc19ba41 100644 --- a/plugins/README.md +++ b/plugins/README.md @@ -43,9 +43,11 @@ Plugins can implement hooks at these lifecycle points: | `prompt_pre_fetch` | Before prompt template retrieval | `PromptPrehookPayload` | Input validation, access control | | `prompt_post_fetch` | After prompt template retrieval | `PromptPosthookPayload` | Content filtering, transformation | | `tool_pre_invoke` | Before tool execution | `ToolPreInvokePayload` | Parameter validation, safety checks | -| `tool_post_invoke` | After tool execution | `ToolPostInvokeResult` | Result filtering, audit logging | +| `tool_post_invoke` | After tool execution | `ToolPostInvokePayload` | Result filtering, audit logging | | `resource_pre_fetch` | Before resource retrieval | `ResourcePreFetchPayload` | Protocol/domain validation | -| `resource_post_fetch` | After resource retrieval | `ResourcePostFetchResult` | Content scanning, size limits | +| `resource_post_fetch` | After resource retrieval | `ResourcePostFetchPayload` | Content scanning, size limits | +| `agent_pre_invoke` | Before agent invocation | `AgentPreInvokePayload` | Message filtering, access control | +| `agent_post_invoke` | After agent response | `AgentPostInvokePayload` | Response filtering, audit logging | Future hooks (in development): - `server_pre_register` / `server_post_register` - Virtual server verification @@ -159,80 +161,279 @@ Validate and filter resource requests: ## Writing Custom Plugins -### 1. Plugin Structure +### Understanding the Plugin Base Class -Create a new directory under `plugins/`: +The `Plugin` class is an abstract base class (ABC) that provides the foundation for all plugins. You **must** subclass it and implement at least one hook method to create a functional plugin. -``` -plugins/my_plugin/ -├── __init__.py -├── plugin-manifest.yaml -├── my_plugin.py -└── README.md +```python +from abc import ABC +from mcpgateway.plugins.framework import Plugin + +class MyPlugin(Plugin): + """Your plugin must inherit from Plugin.""" + # Implement hook methods (see patterns below) ``` -### 2. Plugin Manifest (`plugin-manifest.yaml`) +### Three Hook Registration Patterns -```yaml -description: "My custom plugin" -author: "Your Name" -version: "1.0.0" -available_hooks: - - "tool_pre_invoke" - - "tool_post_invoke" -default_configs: - my_setting: true - threshold: 0.8 -``` +The plugin framework supports three flexible patterns for registering hook methods: + +#### Pattern 1: Convention-Based (Recommended for Standard Hooks) -### 3. Plugin Implementation +The simplest approach - just name your method to match the hook type: ```python -# my_plugin.py -from mcpgateway.plugins.framework.base import Plugin -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( + Plugin, + PluginContext, ToolPreInvokePayload, ToolPreInvokeResult, - PluginResult ) class MyPlugin(Plugin): - """Custom plugin implementation.""" + """Convention-based hook - method name matches hook type.""" + + async def tool_pre_invoke( + self, + payload: ToolPreInvokePayload, + context: PluginContext + ) -> ToolPreInvokeResult: + """This hook is automatically discovered by its name.""" + + # Your logic here + modified_args = {**payload.args, "processed": True} + + modified_payload = ToolPreInvokePayload( + name=payload.name, + args=modified_args, + headers=payload.headers + ) + + return ToolPreInvokeResult( + modified_payload=modified_payload, + metadata={"processed_by": self.name} + ) +``` + +**When to use:** Default choice for implementing standard framework hooks. + +#### Pattern 2: Decorator-Based (Custom Method Names) + +Use the `@hook` decorator to register a hook with a custom method name: + +```python +from mcpgateway.plugins.framework import Plugin, PluginContext +from mcpgateway.plugins.framework.decorator import hook +from mcpgateway.plugins.framework import ( + ToolHookType, + ToolPostInvokePayload, + ToolPostInvokeResult, +) + +class MyPlugin(Plugin): + """Decorator-based hook with custom method name.""" + + @hook(ToolHookType.TOOL_POST_INVOKE) + async def my_custom_handler_name( + self, + payload: ToolPostInvokePayload, + context: PluginContext + ) -> ToolPostInvokeResult: + """Method name doesn't match hook type, but @hook decorator registers it.""" + + # Your logic here + return ToolPostInvokeResult(continue_processing=True) +``` + +**When to use:** When you want descriptive method names that better match your plugin's purpose. - async def tool_pre_invoke(self, payload: ToolPreInvokePayload) -> ToolPreInvokeResult: - """Process tool invocation before execution.""" +#### Pattern 3: Custom Hooks (Advanced) - # Get plugin configuration - my_setting = self.config.get("my_setting", False) - threshold = self.config.get("threshold", 0.5) +Register completely new hook types with custom payload and result types: + +```python +from mcpgateway.plugins.framework import Plugin, PluginContext, PluginPayload, PluginResult +from mcpgateway.plugins.framework.decorator import hook - # Implement your logic - if my_setting and self._should_block(payload): - return ToolPreInvokeResult( - result=PluginResult.BLOCK, - message="Request blocked by custom logic", - modified_payload=payload +# Define custom payload type +class EmailPayload(PluginPayload): + recipient: str + subject: str + body: str + +# Define custom result type +class EmailResult(PluginResult[EmailPayload]): + pass + +class MyPlugin(Plugin): + """Custom hook with new hook type.""" + + @hook("email_pre_send", EmailPayload, EmailResult) + async def validate_email( + self, + payload: EmailPayload, + context: PluginContext + ) -> EmailResult: + """Completely new hook type: 'email_pre_send'""" + + # Validate email address + if "@" not in payload.recipient: + # Fix invalid email + modified_payload = EmailPayload( + recipient=f"{payload.recipient}@example.com", + subject=payload.subject, + body=payload.body + ) + return EmailResult( + modified_payload=modified_payload, + metadata={"fixed_email": True} ) - # Modify payload if needed - modified_payload = self._transform_payload(payload) + return EmailResult(continue_processing=True) +``` + +**When to use:** When extending the framework with domain-specific hook points not covered by standard hooks. + +### Hook Method Signature Requirements + +All hook methods must follow these rules: + +1. **Must be async**: All hooks are asynchronous +2. **Three parameters**: `self`, `payload`, `context` +3. **Type hints required** (for validation): Payload and result types must be properly typed +4. **Return appropriate result type**: Each hook returns a `PluginResult` typed with the hook's payload type + +```python +async def hook_name( + self, + payload: PayloadType, # Specific to the hook (e.g., ToolPreInvokePayload) + context: PluginContext # Always PluginContext +) -> PluginResult[PayloadType]: # PluginResult generic, parameterized by the payload type + """Hook implementation.""" + pass +``` + +**Understanding Result Types:** + +Each hook has a corresponding result type that is actually a type alias for `PluginResult[PayloadType]`: + +```python +# These are type aliases defined in the framework +ToolPreInvokeResult = PluginResult[ToolPreInvokePayload] +ToolPostInvokeResult = PluginResult[ToolPostInvokePayload] +PromptPrehookResult = PluginResult[PromptPrehookPayload] +# ... and so on for each hook type +``` + +This means when you return a result, you're returning a `PluginResult` instance that knows about the specific payload type: + +```python +# All of these are valid ways to construct results: +return ToolPreInvokeResult(continue_processing=True) +return ToolPreInvokeResult(modified_payload=new_payload) +return ToolPreInvokeResult( + modified_payload=new_payload, + metadata={"processed": True} +) +``` + +### Complete Plugin Example + +Here's a complete plugin showing all patterns: + +```python +# plugins/my_plugin/my_plugin.py +from mcpgateway.plugins.framework import ( + Plugin, + PluginContext, + PluginPayload, + PluginResult, + ToolPreInvokePayload, + ToolPreInvokeResult, + ToolPostInvokePayload, + ToolPostInvokeResult, + ToolHookType, +) +from mcpgateway.plugins.framework.decorator import hook + +class MyPlugin(Plugin): + """Example plugin demonstrating all three patterns.""" + + # Pattern 1: Convention-based + async def tool_pre_invoke( + self, + payload: ToolPreInvokePayload, + context: PluginContext + ) -> ToolPreInvokeResult: + """Pre-process tool invocation - found by naming convention.""" + + # Access plugin configuration + threshold = self.config.config.get("threshold", 0.5) + + # Modify payload + modified_args = {**payload.args, "plugin_processed": True} + modified_payload = ToolPreInvokePayload( + name=payload.name, + args=modified_args, + headers=payload.headers + ) return ToolPreInvokeResult( - result=PluginResult.CONTINUE, - modified_payload=modified_payload + modified_payload=modified_payload, + metadata={"threshold": threshold} ) - def _should_block(self, payload: ToolPreInvokePayload) -> bool: - """Custom blocking logic.""" - # Implement your validation logic here - return False + # Pattern 2: Decorator with custom name + @hook(ToolHookType.TOOL_POST_INVOKE) + async def process_tool_result( + self, + payload: ToolPostInvokePayload, + context: PluginContext + ) -> ToolPostInvokeResult: + """Post-process tool result - found via decorator.""" + + # Transform result + if isinstance(payload.result, dict): + modified_result = { + **payload.result, + "processed_by": self.name + } + modified_payload = ToolPostInvokePayload( + name=payload.name, + result=modified_result + ) + return ToolPostInvokeResult(modified_payload=modified_payload) - def _transform_payload(self, payload: ToolPreInvokePayload) -> ToolPreInvokePayload: - """Transform payload if needed.""" - return payload + return ToolPostInvokeResult(continue_processing=True) ``` -### 4. Register Your Plugin +### Plugin Structure + +Create a new directory under `plugins/`: + +``` +plugins/my_plugin/ +├── __init__.py +├── plugin-manifest.yaml +├── my_plugin.py +└── README.md +``` + +### Plugin Manifest (`plugin-manifest.yaml`) + +```yaml +description: "My custom plugin" +author: "Your Name" +version: "1.0.0" +available_hooks: + - "tool_pre_invoke" + - "tool_post_invoke" +default_configs: + threshold: 0.8 + enable_logging: true +``` + +### Register Your Plugin Add to `plugins/config.yaml`: @@ -243,34 +444,88 @@ plugins: description: "My custom plugin description" version: "1.0.0" author: "Your Name" - hooks: ["tool_pre_invoke"] + hooks: ["tool_pre_invoke", "tool_post_invoke"] mode: "enforce" priority: 100 config: - my_setting: true threshold: 0.8 + enable_logging: true ``` ## Plugin Development Best Practices +### Hook Results and Control Flow + +Each hook returns a result object that controls execution flow: + +```python +# Allow processing to continue +return ToolPreInvokeResult(continue_processing=True) + +# Modify the payload +return ToolPreInvokeResult( + modified_payload=modified_payload, + metadata={"processed": True} +) + +# Block execution with a violation +from mcpgateway.plugins.framework import PluginViolation + +return ToolPreInvokeResult( + continue_processing=False, + violation=PluginViolation( + code="POLICY_VIOLATION", + reason="Request blocked by security policy", + description="Detected prohibited content" + ) +) +``` + ### Error Handling -Errors inside a plugin should be raised as exceptions. The plugin manager will catch the error, and its behavior depends on both the gateway's and plugin's configuration as follows: +Errors inside a plugin should be raised as exceptions. The plugin manager will catch the error, and its behavior depends on both the gateway's and plugin's configuration as follows: + +1. If `plugin_settings.fail_on_plugin_error` in the plugin `config.yaml` is set to `true`, the exception is bubbled up as a PluginError and the error is passed to the client of ContextForge regardless of the plugin mode. +2. If `plugin_settings.fail_on_plugin_error` is set to false, the error is handled based off of the plugin mode in the plugin's config as follows: + * If `mode` is `enforce`, both violations and errors are bubbled up as exceptions and the execution is blocked. + * If `mode` is `enforce_ignore_error`, violations are bubbled up as exceptions and execution is blocked, but errors are logged and execution continues. + * If `mode` is `permissive`, execution is allowed to proceed whether there are errors or violations. Both are logged. -1. if `plugin_settings.fail_on_plugin_error` in the plugin `config.yaml` is set to `true` the exception is bubbled up as a PluginError and the error is passed to the client of ContextForge regardless of the plugin mode. -2. if `plugin_settings.fail_on_plugin_error` is set to false the error is handled based off of the plugin mode in the plugin's config as follows: - * if `mode` is `enforce`, both violations and errors are bubbled up as exceptions and the execution is blocked. - * if `mode` is `enforce_ignore_error`, violations are bubbled up as exceptions and execution is blocked, but errors are logged and execution continues. - * if `mode` is `permissive`, execution is allowed to proceed whether there are errors or violations. Both are logged. +### Accessing Plugin Context + +The `context` parameter provides access to request-scoped and global state: + +```python +async def tool_pre_invoke( + self, + payload: ToolPreInvokePayload, + context: PluginContext +) -> ToolPreInvokeResult: + # Access request ID + request_id = context.global_context.request_id + + # Access user information + user = context.global_context.user + tenant_id = context.global_context.tenant_id + + # Store plugin-specific state (persists across pre/post hooks) + context.state["invocation_count"] = context.state.get("invocation_count", 0) + 1 + + # Add metadata + context.metadata["processing_time"] = 0.123 + + return ToolPreInvokeResult(continue_processing=True) +``` ### Logging and Monitoring + ```python def __init__(self, config: PluginConfig): super().__init__(config) self.logger.info(f"Initialized {self.name} v{self.version}") -async def tool_pre_invoke(self, payload: ToolPreInvokePayload) -> ToolPreInvokeResult: - self.logger.debug(f"Processing tool: {payload.tool_name}") +async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: + self.logger.debug(f"Processing tool: {payload.name}") # ... plugin logic self.metrics.increment("requests_processed") ``` @@ -278,14 +533,19 @@ async def tool_pre_invoke(self, payload: ToolPreInvokePayload) -> ToolPreInvokeR ### Configuration Validation ```python -def validate_config(self) -> None: +def __init__(self, config: PluginConfig): + super().__init__(config) + self._validate_config() + +def _validate_config(self) -> None: """Validate plugin configuration.""" required_keys = ["threshold", "api_key"] for key in required_keys: - if key not in self.config: + if key not in self.config.config: raise ValueError(f"Missing required config key: {key}") - if not 0 <= self.config["threshold"] <= 1: + threshold = self.config.config.get("threshold") + if not 0 <= threshold <= 1: raise ValueError("threshold must be between 0 and 1") ``` @@ -304,11 +564,12 @@ class MyPlugin(Plugin): super().__init__(config) self._session = None - async def __aenter__(self): + async def initialize(self): + """Called when plugin is loaded.""" self._session = aiohttp.ClientSession() - return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def shutdown(self): + """Called when plugin manager shuts down.""" if self._session: await self._session.close() ``` @@ -316,30 +577,48 @@ class MyPlugin(Plugin): ## Testing Plugins ### Unit Testing + ```python import pytest -from mcpgateway.plugins.framework.models import ToolPreInvokePayload, PluginConfig +from mcpgateway.plugins.framework import ( + PluginConfig, + PluginContext, + GlobalContext, + ToolPreInvokePayload, +) from plugins.my_plugin.my_plugin import MyPlugin @pytest.fixture def plugin(): config = PluginConfig( name="test_plugin", - config={"my_setting": True} + description="Test", + version="1.0", + author="Test", + kind="plugins.my_plugin.my_plugin.MyPlugin", + hooks=["tool_pre_invoke"], + config={"threshold": 0.8} ) return MyPlugin(config) +@pytest.mark.asyncio async def test_tool_pre_invoke(plugin): payload = ToolPreInvokePayload( - tool_name="test_tool", - arguments={"arg1": "value1"} + name="test_tool", + args={"arg1": "value1"} + ) + context = PluginContext( + global_context=GlobalContext(request_id="test-123") ) - result = await plugin.tool_pre_invoke(payload) - assert result.result == PluginResult.CONTINUE + result = await plugin.tool_pre_invoke(payload, context) + + assert result.continue_processing is True + assert result.modified_payload.args["plugin_processed"] is True ``` ### Integration Testing + ```bash # Test with live gateway make dev @@ -356,20 +635,39 @@ curl -X POST http://localhost:4444/tools/invoke \ 2. **Configuration errors**: Validate YAML syntax and required fields 3. **Performance issues**: Profile plugin execution time and optimize bottlenecks 4. **Hook not triggering**: Verify hook name matches available hooks in manifest +5. **Method signature errors**: Ensure hooks have correct parameters (self, payload, context) and are async ### Debug Mode + ```bash LOG_LEVEL=DEBUG make serve # port 4444 # Or with reloading dev server: LOG_LEVEL=DEBUG make dev # port 8000 ``` +### Testing Hook Discovery + +To verify your hooks are properly registered: + +```python +from mcpgateway.plugins.framework import PluginManager + +manager = PluginManager("path/to/config.yaml") +await manager.initialize() + +# Check loaded plugins +for plugin_config in manager.config.plugins: + print(f"Plugin: {plugin_config.name}") + print(f" Hooks: {plugin_config.hooks}") +``` + ## Documentation Links - **Plugin Usage Guide**: https://ibm.github.io/mcp-context-forge/using/plugins/ - **Plugin Lifecycle**: https://ibm.github.io/mcp-context-forge/using/plugins/lifecycle/ - **API Reference**: Generated from code docstrings - **Examples**: See `plugins/` directory for complete implementations +- **Hook Patterns Test**: `tests/unit/mcpgateway/plugins/framework/hooks/test_hook_patterns.py` ## Performance Metrics @@ -387,3 +685,4 @@ The framework supports high-performance operations: - Error isolation between plugins - Comprehensive audit logging - Plugin configuration validation +- Hook signature validation at plugin load time diff --git a/plugins/argument_normalizer/argument_normalizer.py b/plugins/argument_normalizer/argument_normalizer.py index b25732e25..e06eb5df8 100644 --- a/plugins/argument_normalizer/argument_normalizer.py +++ b/plugins/argument_normalizer/argument_normalizer.py @@ -145,7 +145,15 @@ class EffectiveCfg: def _merge_overrides(base: ArgumentNormalizerConfig, path: str) -> EffectiveCfg: - """Compute an effective configuration for a given field path.""" + """Compute an effective configuration for a given field path. + + Args: + base: Base configuration to start from. + path: Field path to compute configuration for. + + Returns: + Effective configuration for the given field path. + """ cfg = base # Start with base values eff = EffectiveCfg( @@ -444,6 +452,13 @@ def repl(m: re.Match[str]) -> str: def _normalize_text(text: str, eff: EffectiveCfg) -> str: """Normalize a text value using an effective configuration. + Args: + text: Text value to normalize. + eff: Effective configuration to use for normalization. + + Returns: + Normalized text value. + Examples: Normalize unicode and whitespace: diff --git a/plugins/config.yaml b/plugins/config.yaml index bc67a5d6f..7c821daf6 100644 --- a/plugins/config.yaml +++ b/plugins/config.yaml @@ -179,11 +179,11 @@ plugins: priority: 119 conditions: [] config: - allowed_tags: ["a","p","div","span","strong","em","code","pre","ul","ol","li","h1","h2","h3","h4","h5","h6","blockquote","img","br","hr","table","thead","tbody","tr","th","td"] + allowed_tags: ["a", "p", "div", "span", "strong", "em", "code", "pre", "ul", "ol", "li", "h1", "h2", "h3", "h4", "h5", "h6", "blockquote", "img", "br", "hr", "table", "thead", "tbody", "tr", "th", "td"] allowed_attrs: - "*": ["id","class","title","alt"] - a: ["href","rel","target"] - img: ["src","width","height","alt","title"] + "*": ["id", "class", "title", "alt"] + a: ["href", "rel", "target"] + img: ["src", "width", "height", "alt", "title"] remove_comments: true drop_unknown_tags: true strip_event_handlers: true diff --git a/plugins/content_moderation/content_moderation.py b/plugins/content_moderation/content_moderation.py index 18f9dbbbc..877654aee 100644 --- a/plugins/content_moderation/content_moderation.py +++ b/plugins/content_moderation/content_moderation.py @@ -189,7 +189,15 @@ def __init__(self, config: PluginConfig) -> None: self._cache: Dict[str, ModerationResult] = {} if self._cfg.enable_caching else None async def _get_cache_key(self, text: str, provider: ModerationProvider) -> str: - """Generate cache key for content.""" + """Generate cache key for content. + + Args: + text: Content text to generate key for. + provider: Moderation provider being used. + + Returns: + Cache key string. + """ # Standard import hashlib @@ -197,7 +205,15 @@ async def _get_cache_key(self, text: str, provider: ModerationProvider) -> str: return f"{provider.value}:{content_hash}" async def _get_cached_result(self, text: str, provider: ModerationProvider) -> Optional[ModerationResult]: - """Get cached moderation result.""" + """Get cached moderation result. + + Args: + text: Content text to check cache for. + provider: Moderation provider being used. + + Returns: + Cached moderation result if available, None otherwise. + """ if not self._cfg.enable_caching or not self._cache: return None @@ -205,7 +221,13 @@ async def _get_cached_result(self, text: str, provider: ModerationProvider) -> O return self._cache.get(cache_key) async def _cache_result(self, text: str, provider: ModerationProvider, result: ModerationResult) -> None: - """Cache moderation result.""" + """Cache moderation result. + + Args: + text: Content text being cached. + provider: Moderation provider being used. + result: Moderation result to cache. + """ if not self._cfg.enable_caching or not self._cache: return @@ -213,7 +235,18 @@ async def _cache_result(self, text: str, provider: ModerationProvider, result: M self._cache[cache_key] = result async def _moderate_with_ibm_watson(self, text: str) -> ModerationResult: - """Moderate content using IBM Watson Natural Language Understanding.""" + """Moderate content using IBM Watson Natural Language Understanding. + + Args: + text: Content text to moderate. + + Returns: + Moderation result from IBM Watson. + + Raises: + ValueError: If IBM Watson configuration not provided. + Exception: If API call fails. + """ if not self._cfg.ibm_watson: raise ValueError("IBM Watson configuration not provided") @@ -284,7 +317,18 @@ async def _moderate_with_ibm_watson(self, text: str) -> ModerationResult: raise async def _moderate_with_ibm_granite(self, text: str) -> ModerationResult: - """Moderate content using IBM Granite Guardian via Ollama.""" + """Moderate content using IBM Granite Guardian via Ollama. + + Args: + text: Content text to moderate. + + Returns: + Moderation result from IBM Granite. + + Raises: + ValueError: If IBM Granite configuration not provided. + Exception: If API call fails. + """ if not self._cfg.ibm_granite: raise ValueError("IBM Granite configuration not provided") @@ -351,7 +395,18 @@ async def _moderate_with_ibm_granite(self, text: str) -> ModerationResult: raise async def _moderate_with_openai(self, text: str) -> ModerationResult: - """Moderate content using OpenAI Moderation API.""" + """Moderate content using OpenAI Moderation API. + + Args: + text: Content text to moderate. + + Returns: + Moderation result from OpenAI. + + Raises: + ValueError: If OpenAI configuration not provided. + Exception: If API call fails. + """ if not self._cfg.openai: raise ValueError("OpenAI configuration not provided") @@ -413,7 +468,15 @@ async def _moderate_with_openai(self, text: str) -> ModerationResult: raise async def _apply_moderation_action(self, text: str, result: ModerationResult) -> str: - """Apply the moderation action to the text.""" + """Apply the moderation action to the text. + + Args: + text: Original content text. + result: Moderation result with action to apply. + + Returns: + Modified text based on moderation action. + """ if result.action == ModerationAction.BLOCK: return "" # Empty content elif result.action == ModerationAction.REDACT: @@ -432,7 +495,14 @@ async def _apply_moderation_action(self, text: str, result: ModerationResult) -> return text # Return original text async def _moderate_content(self, text: str) -> ModerationResult: - """Moderate content using the configured provider.""" + """Moderate content using the configured provider. + + Args: + text: Content text to moderate. + + Returns: + Moderation result from the configured provider. + """ if len(text) > self._cfg.max_text_length: text = text[: self._cfg.max_text_length] @@ -482,7 +552,14 @@ async def _moderate_content(self, text: str) -> ModerationResult: return result async def _moderate_with_patterns(self, text: str) -> ModerationResult: - """Fallback moderation using regex patterns.""" + """Fallback moderation using regex patterns. + + Args: + text: Content text to moderate. + + Returns: + Moderation result based on pattern matching. + """ categories = {} # Basic pattern matching for different categories @@ -532,7 +609,14 @@ async def _moderate_with_patterns(self, text: str) -> ModerationResult: ) async def _extract_text_content(self, payload: Any) -> List[str]: - """Extract text content from various payload types.""" + """Extract text content from various payload types. + + Args: + payload: Payload to extract text from. + + Returns: + List of extracted text strings. + """ texts = [] if hasattr(payload, "args") and payload.args: @@ -551,7 +635,15 @@ async def _extract_text_content(self, payload: Any) -> List[str]: return [text for text in texts if len(text.strip()) > 3] # Filter very short texts async def prompt_pre_fetch(self, payload: PromptPrehookPayload, _context: PluginContext) -> PromptPrehookResult: - """Moderate prompt content before fetching.""" + """Moderate prompt content before fetching. + + Args: + payload: Prompt payload to moderate. + _context: Plugin context (unused). + + Returns: + Result indicating whether to continue processing. + """ texts = await self._extract_text_content(payload) for text in texts: @@ -595,7 +687,15 @@ async def prompt_pre_fetch(self, payload: PromptPrehookPayload, _context: Plugin return PromptPrehookResult() async def tool_pre_invoke(self, payload: ToolPreInvokePayload, _context: PluginContext) -> ToolPreInvokeResult: - """Moderate tool arguments before invocation.""" + """Moderate tool arguments before invocation. + + Args: + payload: Tool invocation payload to moderate. + _context: Plugin context (unused). + + Returns: + Result indicating whether to continue processing. + """ texts = await self._extract_text_content(payload) for text in texts: @@ -634,7 +734,15 @@ async def tool_pre_invoke(self, payload: ToolPreInvokePayload, _context: PluginC return ToolPreInvokeResult(metadata={"moderation_checked": True}) async def tool_post_invoke(self, payload: ToolPostInvokePayload, _context: PluginContext) -> ToolPostInvokeResult: - """Moderate tool output after invocation.""" + """Moderate tool output after invocation. + + Args: + payload: Tool result payload to moderate. + _context: Plugin context (unused). + + Returns: + Result indicating whether to continue processing. + """ # Extract text from tool results result_text = "" if hasattr(payload.result, "content"): @@ -681,7 +789,11 @@ async def tool_post_invoke(self, payload: ToolPostInvokePayload, _context: Plugi return ToolPostInvokeResult(metadata={"output_checked": True}) async def __aenter__(self): - """Async context manager entry.""" + """Async context manager entry. + + Returns: + ContentModerationPlugin: The plugin instance. + """ return self async def __aexit__(self, _exc_type, _exc_val, _exc_tb): diff --git a/plugins/external/clamav_server/clamav_plugin.py b/plugins/external/clamav_server/clamav_plugin.py index efb1962a1..7fdce282f 100644 --- a/plugins/external/clamav_server/clamav_plugin.py +++ b/plugins/external/clamav_server/clamav_plugin.py @@ -52,8 +52,14 @@ def _has_eicar(data: bytes) -> bool: - """Has Eicar implementation.""" + """Check if data contains EICAR test virus signature. + Args: + data: Bytes to scan for EICAR signature. + + Returns: + True if EICAR signature found, False otherwise. + """ blob = data.decode("latin1", errors="ignore") return any(sig in blob for sig in EICAR_SIGNATURES) @@ -62,8 +68,11 @@ class ClamAVConfig: """ClamAVConfig implementation.""" def __init__(self, cfg: dict[str, Any] | None) -> None: - """Initialize the instance.""" + """Initialize the instance. + Args: + cfg: Configuration dictionary. + """ c = cfg or {} self.mode: str = c.get("mode", "eicar_only") # eicar_only|clamd_tcp|clamd_unix self.host: str | None = c.get("clamd_host") @@ -75,8 +84,17 @@ def __init__(self, cfg: dict[str, Any] | None) -> None: def _clamd_instream_scan_tcp(host: str, port: int, data: bytes, timeout: float) -> str: - """Clamd Instream Scan Tcp implementation.""" + """Scan data using ClamAV daemon via TCP connection. + Args: + host: ClamAV daemon host address. + port: ClamAV daemon port number. + data: Bytes to scan. + timeout: Connection timeout in seconds. + + Returns: + Scan response from ClamAV daemon. + """ # Minimal INSTREAM protocol: https://linux.die.net/man/8/clamd s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.settimeout(timeout) @@ -99,8 +117,16 @@ def _clamd_instream_scan_tcp(host: str, port: int, data: bytes, timeout: float) def _clamd_instream_scan_unix(path: str, data: bytes, timeout: float) -> str: - """Clamd Instream Scan Unix implementation.""" + """Scan data using ClamAV daemon via Unix socket connection. + + Args: + path: Unix socket path. + data: Bytes to scan. + timeout: Connection timeout in seconds. + Returns: + Scan response from ClamAV daemon. + """ s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) s.settimeout(timeout) s.connect(path) @@ -123,23 +149,35 @@ class ClamAVRemotePlugin(Plugin): """External ClamAV plugin for scanning resources and content.""" def __init__(self, config: PluginConfig) -> None: - """Initialize the instance.""" + """Initialize the instance. + Args: + config: Plugin configuration. + """ super().__init__(config) self._cfg = ClamAVConfig(config.config) self._stats: dict[str, int] = {"attempted": 0, "infected": 0, "blocked": 0, "errors": 0} def _bump(self, key: str) -> None: - """Bump implementation.""" + """Increment statistics counter. + Args: + key: Statistics key to increment. + """ try: self._stats[key] = int(self._stats.get(key, 0)) + 1 except Exception: pass def _scan_bytes(self, data: bytes) -> tuple[bool, str]: - """Scan Bytes implementation.""" + """Scan bytes for malware using configured scan method. + Args: + data: Bytes to scan for malware. + + Returns: + Tuple of (infected: bool, detail: str) indicating if malware was found and scan details. + """ if len(data) > self._cfg.max_bytes: return False, "SKIPPED: too large" @@ -284,8 +322,14 @@ async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: Plugin # Recursively scan string values in tool outputs def iter_strings(obj): - """Iter Strings implementation.""" + """Recursively iterate over all string values in an object. + Args: + obj: Object to iterate over (str, dict, list, or other). + + Yields: + String values found in the object. + """ if isinstance(obj, str): yield obj elif isinstance(obj, dict): @@ -320,7 +364,11 @@ def iter_strings(obj): return ToolPostInvokeResult(metadata={"clamav": {"error": str(exc)}}) def health(self) -> dict[str, Any]: - """Return plugin health and metrics; try clamd connectivity when configured.""" + """Return plugin health and metrics; try clamd connectivity when configured. + + Returns: + Dictionary containing plugin health status and metrics. + """ status = {"mode": self._cfg.mode, "block_on_positive": self._cfg.block_on_positive, "stats": dict(self._stats)} reachable = None try: diff --git a/plugins/external/llmguard/docker-compose.yaml b/plugins/external/llmguard/docker-compose.yaml index 9cd8afbd3..1e0399f05 100644 --- a/plugins/external/llmguard/docker-compose.yaml +++ b/plugins/external/llmguard/docker-compose.yaml @@ -17,7 +17,7 @@ services: llmguardplugin: container_name: llmguardplugin - image: mcpgateway/llmguardplugin:latest # Use the local latest image. Run `make docker-prod` to build it. + image: mcpgateway/llmguardplugin:latest # Use the local latest image. Run `make docker-prod` to build it. restart: always env_file: - .env @@ -33,7 +33,7 @@ services: llmguardplugin-testing: container_name: llmguardplugin-testing - image: mcpgateway/llmguardplugin-testing:latest # Use the local latest image. Run `make docker-prod` to build it. + image: mcpgateway/llmguardplugin-testing:latest # Use the local latest image. Run `make docker-prod` to build it. env_file: - .env ports: diff --git a/plugins/external/llmguard/examples/config-all-in-one.yaml b/plugins/external/llmguard/examples/config-all-in-one.yaml index 62679b563..d5ee7666a 100644 --- a/plugins/external/llmguard/examples/config-all-in-one.yaml +++ b/plugins/external/llmguard/examples/config-all-in-one.yaml @@ -5,7 +5,7 @@ plugins: description: "A plugin for running input and output through llmguard scanners " version: "0.1" author: "ContextForge" - hooks: ["prompt_pre_fetch","prompt_post_fetch"] + hooks: ["prompt_pre_fetch", "prompt_post_fetch"] tags: ["plugin", "guardrails", "llmguard", "pre-post", "filters", "sanitizers"] mode: "enforce" # enforce | permissive | disabled priority: 20 @@ -18,11 +18,11 @@ plugins: cache_ttl: 120 #defined in seconds input: filters: - PromptInjection: - threshold: 0.6 - use_onnx: false - policy: PromptInjection - policy_message: I'm sorry, I cannot allow this input. + PromptInjection: + threshold: 0.6 + use_onnx: false + policy: PromptInjection + policy_message: I'm sorry, I cannot allow this input. sanitizers: Anonymize: language: "en" @@ -34,7 +34,7 @@ plugins: matching_strategy: exact filters: Toxicity: - threshold: 0.5 + threshold: 0.5 policy: Toxicity policy_message: I'm sorry, I cannot allow this output. diff --git a/plugins/external/llmguard/examples/config-complex-policy.yaml b/plugins/external/llmguard/examples/config-complex-policy.yaml index 199588ec9..dad5720d1 100644 --- a/plugins/external/llmguard/examples/config-complex-policy.yaml +++ b/plugins/external/llmguard/examples/config-complex-policy.yaml @@ -61,7 +61,7 @@ plugins: output: filters: Toxicity: - threshold: 0.5 + threshold: 0.5 policy: Toxicity policy_message: I'm sorry, I cannot allow this output. diff --git a/plugins/external/llmguard/examples/config-input-output-filter.yaml b/plugins/external/llmguard/examples/config-input-output-filter.yaml index 1d5272e2f..153d843bf 100644 --- a/plugins/external/llmguard/examples/config-input-output-filter.yaml +++ b/plugins/external/llmguard/examples/config-input-output-filter.yaml @@ -42,7 +42,7 @@ plugins: output: filters: Toxicity: - threshold: 0.5 + threshold: 0.5 policy: Toxicity policy_message: I'm sorry, I cannot allow this output. diff --git a/plugins/external/llmguard/examples/config-separate-plugins-filters-sanitizers.yaml b/plugins/external/llmguard/examples/config-separate-plugins-filters-sanitizers.yaml index 1ce1222fd..785a34a08 100644 --- a/plugins/external/llmguard/examples/config-separate-plugins-filters-sanitizers.yaml +++ b/plugins/external/llmguard/examples/config-separate-plugins-filters-sanitizers.yaml @@ -87,7 +87,7 @@ plugins: output: filters: Toxicity: - threshold: 0.5 + threshold: 0.5 policy: Toxicity policy_message: I'm sorry, I cannot allow this output. diff --git a/plugins/external/llmguard/llmguardplugin/cache.py b/plugins/external/llmguard/llmguardplugin/cache.py index 454b6ee65..68dfa9ff4 100644 --- a/plugins/external/llmguard/llmguardplugin/cache.py +++ b/plugins/external/llmguard/llmguardplugin/cache.py @@ -40,7 +40,7 @@ def __init__(self, ttl: int = 0) -> None: """init block for cache. This initializes a redit client. Args: - ttl: Time to live in seconds for cache + ttl: Time to live in seconds for cache """ self.cache_ttl = ttl self.cache = redis.Redis(host=redis_host, port=redis_port) @@ -53,6 +53,9 @@ def update_cache(self, key: int = None, value: tuple = None) -> tuple[bool]: Args: key: The id of vault in string value: The tuples in the vault + + Returns: + tuple[bool]: A tuple containing (success_set, success_expiry) booleans. """ serialized_obj = pickle.dumps(value) logger.info(f"Update cache in cache: {key} {serialized_obj}") @@ -73,10 +76,9 @@ def retrieve_cache(self, key: int = None) -> tuple: Args: key: The id of vault in string - value: The tuples in the vault Returns: - retrieved_obj: Return the retrieved object from cache + tuple: The retrieved object from cache or None if not found. """ value = self.cache.get(key) if value: @@ -87,14 +89,10 @@ def retrieve_cache(self, key: int = None) -> tuple: logger.error(f"Cache retrieval unsuccessful for id: {key}") def delete_cache(self, key: int = None) -> None: - """Retrieves cache for a key value + """Deletes cache for a key value Args: key: The id of vault in string - value: The tuples in the vault - - Returns: - retrieved_obj: Return the retrieved object from cache """ logger.info(f"Deleting cache for key : {key}") deleted_count = self.cache.delete(key) diff --git a/plugins/external/llmguard/llmguardplugin/llmguard.py b/plugins/external/llmguard/llmguardplugin/llmguard.py index 9612b3abe..8219ff199 100644 --- a/plugins/external/llmguard/llmguardplugin/llmguard.py +++ b/plugins/external/llmguard/llmguardplugin/llmguard.py @@ -36,7 +36,11 @@ class LLMGuardBase: """ def __init__(self, config: Optional[dict[str, Any]]) -> None: - """Initialize the instance.""" + """Initialize the instance. + + Args: + config: Configuration for guardrails. + """ self.lgconfig = LLMGuardConfig.model_validate(config) self.scanners = {"input": {"sanitizers": [], "filters": []}, "output": {"sanitizers": [], "filters": []}} @@ -65,7 +69,11 @@ def _create_new_vault_on_expiry(self, vault) -> bool: return False def _create_vault(self) -> Vault: - """This function creates a new vault and sets it's creation time as it's attribute""" + """This function creates a new vault and sets it's creation time as it's attribute + + Returns: + Vault: A new vault object with creation time set. + """ logger.info("Vault creation") vault = Vault() vault.creation_time = datetime.now() @@ -76,7 +84,11 @@ def _retreive_vault(self, sanitizer_names: list = ["Anonymize"]) -> tuple[Vault, """This function is responsible for retrieving vault for given sanitizer names Args: - sanitizer_names: list of names for sanitizers""" + sanitizer_names: list of names for sanitizers + + Returns: + tuple[Vault, int, tuple]: A tuple containing the vault object, vault ID, and vault tuples. + """ vault_id = None vault_tuples = None length = len(self.scanners["input"]["sanitizers"]) @@ -112,7 +124,9 @@ def _update_output_sanitizers(self, config, sanitizer_names: list = ["Deanonymiz """This function is responsible for updating vault for given sanitizer names in output Args: - sanitizer_names: list of names for sanitizers""" + config: Configuration containing sanitizer settings. + sanitizer_names: list of names for sanitizers + """ length = len(self.scanners["output"]["sanitizers"]) for i in range(length): scanner_name = type(self.scanners["output"]["sanitizers"][i]).__name__ @@ -131,7 +145,7 @@ def _load_policy_scanners(self, config: dict = None) -> list: config: configuration for scanner Returns: - policy_filters: Either None or a list of scanners defined in the policy + list: Either None or a list of scanners defined in the policy. """ config_keys = get_policy_filters(config) if "policy" in config: @@ -259,9 +273,10 @@ def _apply_output_filters(self, original_input, model_response) -> dict[str, dic Args: original_input: The original input prompt for which model produced a response + model_response: The model's response to apply filters on Returns: - result: A dictionary with key as scanner_name which is the name of the scanner applied to the output and value as a dictionary with keys "sanitized_prompt" which is the actual prompt, + dict[str, dict[str, Any]]: A dictionary with key as scanner_name which is the name of the scanner applied to the output and value as a dictionary with keys "sanitized_prompt" which is the actual prompt, "is_valid" which is boolean that says if the prompt is valid or not based on a scanner applied and "risk_score" which gives the risk score assigned by the scanner to the prompt. """ result = {} @@ -280,10 +295,11 @@ def _apply_output_sanitizers(self, input_prompt, model_response) -> dict[str, di """Takes in model_response and applies sanitizers on it Args: - original_input: The original input prompt for which model produced a response + input_prompt: The original input prompt for which model produced a response + model_response: The model's response to apply sanitizers on Returns: - result: A dictionary with key as scanner_name which is the name of the scanner applied to the output and value as a dictionary with keys "sanitized_prompt" which is the actual prompt, + dict[str, dict[str, Any]]: A dictionary with key as scanner_name which is the name of the scanner applied to the output and value as a dictionary with keys "sanitized_prompt" which is the actual prompt, "is_valid" which is boolean that says if the prompt is valid or not based on a scanner applied and "risk_score" which gives the risk score assigned by the scanner to the prompt. """ result = scan_output(self.scanners["output"]["sanitizers"], input_prompt, model_response) diff --git a/plugins/external/llmguard/llmguardplugin/plugin.py b/plugins/external/llmguard/llmguardplugin/plugin.py index bf9d2a985..4e52fd90a 100644 --- a/plugins/external/llmguard/llmguardplugin/plugin.py +++ b/plugins/external/llmguard/llmguardplugin/plugin.py @@ -50,7 +50,10 @@ def __init__(self, config: PluginConfig) -> None: """Entry init block for plugin. Validates the configuration of plugin and initializes an instance of LLMGuardBase with the config Args: - config: the skill configuration + config: the skill configuration + + Raises: + PluginError: If the configuration is invalid for plugin initialization. """ super().__init__(config) self.lgconfig = LLMGuardConfig.model_validate(self._config.config) @@ -62,14 +65,28 @@ def __init__(self, config: PluginConfig) -> None: raise PluginError(error=PluginErrorModel(message="Invalid configuration for plugin initilialization", plugin_name=self.name)) def __verify_lgconfig(self): - """Checks if the configuration provided for plugin is valid or not. It should either have input or output key atleast""" + """Checks if the configuration provided for plugin is valid or not. It should either have input or output key atleast + + Returns: + bool: True if configuration is valid (has input or output), False otherwise. + """ return self.lgconfig.input or self.lgconfig.output def __update_context(self, context, key, value) -> dict: - """Update Context implementation.""" + """Update Context implementation. + + Args: + context: The plugin context to update. + key: The key to set in context. + value: The value to set for the key. + """ def update_context(context): - """Update Context implementation.""" + """Update Context implementation. + + Args: + context: The plugin context to update. + """ plugin_name = self.__class__.__name__ if plugin_name not in context.state[self.guardrails_context_key]: diff --git a/plugins/external/llmguard/llmguardplugin/policy.py b/plugins/external/llmguard/llmguardplugin/policy.py index db0c1fdbe..047dbee72 100644 --- a/plugins/external/llmguard/llmguardplugin/policy.py +++ b/plugins/external/llmguard/llmguardplugin/policy.py @@ -34,7 +34,10 @@ def evaluate(self, policy: str, scan_result: dict) -> Union[bool, str]: scan_result: The result of scanners applied Returns: - A union of bool (if true or false). However, if the policy expression is invalid returns string with invalid expression + Union[bool, str]: A union of bool (if true or false). However, if the policy expression is invalid returns string with invalid expression + + Raises: + ValueError: If the policy expression contains invalid operations. """ policy_variables = {key: value["is_valid"] for key, value in scan_result.items()} try: @@ -97,10 +100,9 @@ def get_policy_filters(policy_expression) -> Union[list, None]: Args: policy_expression: The expression of policy - sentence2: The second sentence Returns: - None if no policy expression is defined, else a comma separated list of filters defined in the policy + Union[list, None]: None if no policy expression is defined, else a comma separated list of filters defined in the policy """ if isinstance(policy_expression, str): pattern = r"\b(and|or|not)\b|[()]" diff --git a/plugins/external/llmguard/resources/plugins/config.yaml b/plugins/external/llmguard/resources/plugins/config.yaml index bbb7b4d64..d583c9cf7 100644 --- a/plugins/external/llmguard/resources/plugins/config.yaml +++ b/plugins/external/llmguard/resources/plugins/config.yaml @@ -5,7 +5,7 @@ plugins: description: "A plugin for running input through llmguard scanners " version: "0.1.0" author: "Shriti Priya" - hooks: ["prompt_pre_fetch","prompt_post_fetch"] + hooks: ["prompt_pre_fetch", "prompt_post_fetch"] tags: ["plugin", "guardrails", "llmguard", "pre-post"] mode: "enforce" # enforce | permissive | disabled priority: 10 diff --git a/plugins/external/llmguard/tests/test_llmguardplugin.py b/plugins/external/llmguard/tests/test_llmguardplugin.py index 7107e5afd..6615e08ae 100644 --- a/plugins/external/llmguard/tests/test_llmguardplugin.py +++ b/plugins/external/llmguard/tests/test_llmguardplugin.py @@ -15,7 +15,7 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework import GlobalContext, PluginConfig, PluginContext, PromptPosthookPayload, PromptPrehookPayload diff --git a/plugins/external/opa/opapluginfilter/plugin.py b/plugins/external/opa/opapluginfilter/plugin.py index 59826a9a5..408153bed 100644 --- a/plugins/external/opa/opapluginfilter/plugin.py +++ b/plugins/external/opa/opapluginfilter/plugin.py @@ -70,8 +70,7 @@ def __init__(self, config: PluginConfig): """Entry init block for plugin. Args: - logger: logger that the skill can make use of - config: the skill configuration + config: the skill configuration """ super().__init__(config) self.opa_config = OPAConfig.model_validate(self._config.config) @@ -105,16 +104,25 @@ def _evaluate_opa_policy(self, url: str, input: OPAInput, policy_input_data_map: Args: url: The url to call opa server input: Contains the payload of input to be sent to opa server for policy evaluation. + policy_input_data_map: Mapping of policy input data keys. Returns: - True, json_response if the opa policy is allowed else false. The json response is the actual response returned by OPA server. + tuple[bool, Any]: True, json_response if the opa policy is allowed else false. The json response is the actual response returned by OPA server. If OPA server encountered any error, the return would be True (to gracefully exit) and None would be the json_response, marking an issue with the OPA server running. """ def _key(k: str, m: str) -> str: - """Key implementation.""" + """Key implementation. + + Args: + k: The key string. + m: The mapping string. + + Returns: + str: Combined key string. + """ return f"{k}.{m}" if k.split(".")[0] == "context" else k @@ -222,10 +230,6 @@ def _extract_payload_key(self, content: Any = None, key: str = None, result: dic content: The content of post hook results. key: The key for which value needs to be extracted for. result: A list of all the values for a key. - - Returns: - None - """ if isinstance(content, list): for element in content: diff --git a/plugins/external/opa/resources/plugins/config.yaml b/plugins/external/opa/resources/plugins/config.yaml index a4499b13c..c033e8498 100644 --- a/plugins/external/opa/resources/plugins/config.yaml +++ b/plugins/external/opa/resources/plugins/config.yaml @@ -4,7 +4,7 @@ plugins: description: "An OPA plugin that enforces rego policies on requests and allows/denies requests as per policies" version: "0.1.0" author: "Shriti Priya" - hooks: ["tool_pre_invoke","tool_post_invoke", "prompt_pre_fetch", "prompt_post_fetch", "resource_pre_fetch", "resource_post_fetch"] + hooks: ["tool_pre_invoke", "tool_post_invoke", "prompt_pre_fetch", "prompt_post_fetch", "resource_pre_fetch", "resource_post_fetch"] tags: ["plugin"] mode: "permissive" # enforce | permissive | disabled priority: 30 diff --git a/plugins/external/opa/tests/test_all.py b/plugins/external/opa/tests/test_all.py index 227abaebc..71cbca5c8 100644 --- a/plugins/external/opa/tests/test_all.py +++ b/plugins/external/opa/tests/test_all.py @@ -8,13 +8,13 @@ import pytest # First-Party -from mcpgateway.models import Message, ResourceContent, Role, TextContent +from mcpgateway.common.models import Message, ResourceContent, Role, TextContent from mcpgateway.plugins.framework import ( GlobalContext, PluginManager, + PluginResult, PromptPosthookPayload, PromptPrehookPayload, - PromptResult, ResourcePostFetchPayload, ResourcePreFetchPayload, ToolPostInvokePayload, @@ -24,7 +24,11 @@ @pytest.fixture(scope="module", autouse=True) def plugin_manager(): - """Initialize plugin manager.""" + """Initialize plugin manager. + + Yields: + PluginManager: An initialized plugin manager instance. + """ plugin_manager = PluginManager("./resources/plugins/config.yaml") asyncio.run(plugin_manager.initialize()) yield plugin_manager @@ -33,7 +37,11 @@ def plugin_manager(): @pytest.mark.asyncio async def test_prompt_pre_hook(plugin_manager: PluginManager): - """Test prompt pre hook across all registered plugins.""" + """Test prompt pre hook across all registered plugins. + + Args: + plugin_manager: The plugin manager instance. + """ # Customize payload for testing payload = PromptPrehookPayload(name="test_prompt", args={"arg0": "This is an argument"}) global_context = GlobalContext(request_id="1") @@ -44,10 +52,14 @@ async def test_prompt_pre_hook(plugin_manager: PluginManager): @pytest.mark.asyncio async def test_prompt_post_hook(plugin_manager: PluginManager): - """Test prompt post hook across all registered plugins.""" + """Test prompt post hook across all registered plugins. + + Args: + plugin_manager: The plugin manager instance. + """ # Customize payload for testing message = Message(content=TextContent(type="text", text="prompt"), role=Role.USER) - prompt_result = PromptResult(messages=[message]) + prompt_result = PluginResult(messages=[message]) payload = PromptPosthookPayload(name="test_prompt", result=prompt_result) global_context = GlobalContext(request_id="1") result, _ = await plugin_manager.prompt_post_fetch(payload, global_context) @@ -57,7 +69,11 @@ async def test_prompt_post_hook(plugin_manager: PluginManager): @pytest.mark.asyncio async def test_tool_pre_hook(plugin_manager: PluginManager): - """Test tool pre hook across all registered plugins.""" + """Test tool pre hook across all registered plugins. + + Args: + plugin_manager: The plugin manager instance. + """ # Customize payload for testing payload = ToolPreInvokePayload(name="test_prompt", args={"arg0": "This is an argument"}) global_context = GlobalContext(request_id="1") @@ -68,7 +84,11 @@ async def test_tool_pre_hook(plugin_manager: PluginManager): @pytest.mark.asyncio async def test_tool_post_hook(plugin_manager: PluginManager): - """Test tool post hook across all registered plugins.""" + """Test tool post hook across all registered plugins. + + Args: + plugin_manager: The plugin manager instance. + """ # Customize payload for testing payload = ToolPostInvokePayload(name="test_tool", result={"output0": "output value"}) global_context = GlobalContext(request_id="1") @@ -79,7 +99,11 @@ async def test_tool_post_hook(plugin_manager: PluginManager): @pytest.mark.asyncio async def test_resource_pre_hook(plugin_manager: PluginManager): - """Test tool post hook across all registered plugins.""" + """Test tool post hook across all registered plugins. + + Args: + plugin_manager: The plugin manager instance. + """ # Customize payload for testing payload = ResourcePreFetchPayload(uri="https://test_resource.com", metadata={}) global_context = GlobalContext(request_id="1", server_id="2") @@ -90,7 +114,11 @@ async def test_resource_pre_hook(plugin_manager: PluginManager): @pytest.mark.asyncio async def test_resource_post_hook(plugin_manager: PluginManager): - """Test tool post hook across all registered plugins.""" + """Test tool post hook across all registered plugins. + + Args: + plugin_manager: The plugin manager instance. + """ # Customize payload for testing content = ResourceContent( type="resource", diff --git a/plugins/external/opa/tests/test_opapluginfilter.py b/plugins/external/opa/tests/test_opapluginfilter.py index 046b5df2e..075d9e54f 100644 --- a/plugins/external/opa/tests/test_opapluginfilter.py +++ b/plugins/external/opa/tests/test_opapluginfilter.py @@ -16,14 +16,14 @@ import pytest # First-Party -from mcpgateway.models import Message, ResourceContent, Role, TextContent +from mcpgateway.common.models import Message, ResourceContent, Role, TextContent from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, PluginContext, + PluginResult, PromptPosthookPayload, PromptPrehookPayload, - PromptResult, ResourcePostFetchPayload, ResourcePreFetchPayload, ToolPostInvokePayload, @@ -160,7 +160,7 @@ async def test_post_prompt_fetch_opapluginfilter(): # Benign payload (allowed by OPA (rego) policy) message = Message(content=TextContent(type="text", text="abc"), role=Role.USER) - prompt_result = PromptResult(messages=[message]) + prompt_result = PluginResult(messages=[message]) payload = PromptPosthookPayload(name="test_prompt", result=prompt_result) context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) result = await plugin.prompt_post_fetch(payload, context) @@ -168,7 +168,7 @@ async def test_post_prompt_fetch_opapluginfilter(): # Malign payload (denied by OPA (rego) policy) message = Message(content=TextContent(type="text", text="abc@example.com"), role=Role.USER) - prompt_result = PromptResult(messages=[message]) + prompt_result = PluginResult(messages=[message]) payload = PromptPosthookPayload(name="test_prompt", result=prompt_result) context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) result = await plugin.prompt_post_fetch(payload, context) diff --git a/plugins/file_type_allowlist/file_type_allowlist.py b/plugins/file_type_allowlist/file_type_allowlist.py index d344c52f3..aa7b20143 100644 --- a/plugins/file_type_allowlist/file_type_allowlist.py +++ b/plugins/file_type_allowlist/file_type_allowlist.py @@ -20,7 +20,7 @@ from pydantic import BaseModel, Field # First-Party -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent from mcpgateway.plugins.framework import ( Plugin, PluginConfig, diff --git a/plugins/html_to_markdown/html_to_markdown.py b/plugins/html_to_markdown/html_to_markdown.py index d3b92b11f..9a87dfe23 100644 --- a/plugins/html_to_markdown/html_to_markdown.py +++ b/plugins/html_to_markdown/html_to_markdown.py @@ -18,7 +18,7 @@ from typing import Any # First-Party -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent from mcpgateway.plugins.framework import ( Plugin, PluginConfig, diff --git a/plugins/markdown_cleaner/markdown_cleaner.py b/plugins/markdown_cleaner/markdown_cleaner.py index be1c8e216..16d48c5f9 100644 --- a/plugins/markdown_cleaner/markdown_cleaner.py +++ b/plugins/markdown_cleaner/markdown_cleaner.py @@ -17,7 +17,7 @@ from typing import Any # First-Party -from mcpgateway.models import Message, PromptResult, ResourceContent, TextContent +from mcpgateway.common.models import Message, PromptResult, ResourceContent, TextContent from mcpgateway.plugins.framework import ( Plugin, PluginConfig, diff --git a/plugins/pii_filter/pii_filter.py b/plugins/pii_filter/pii_filter.py index 0f7215467..69609c06e 100644 --- a/plugins/pii_filter/pii_filter.py +++ b/plugins/pii_filter/pii_filter.py @@ -43,7 +43,10 @@ _RustPIIDetector = None try: - from .pii_filter_rust import RustPIIDetector as _RustPIIDetector, RUST_AVAILABLE as _RUST_AVAILABLE + # Local + from .pii_filter_rust import RUST_AVAILABLE as _RUST_AVAILABLE + from .pii_filter_rust import RustPIIDetector as _RustPIIDetector + if _RUST_AVAILABLE: logger.info("🦀 Rust PII filter available - using high-performance implementation (5-100x speedup)") else: @@ -805,6 +808,9 @@ def _apply_pii_masking_to_parsed_json(self, data: Any, base_path: str, all_detec data: The parsed JSON data structure base_path: The base path for this JSON data all_detections: Dictionary containing all PII detections + + Returns: + None: Modifies data in place. """ if isinstance(data, str): # Check if this path has detections diff --git a/plugins/pii_filter/pii_filter_rust.py b/plugins/pii_filter/pii_filter_rust.py index c0d9a34e2..180f3a5c7 100644 --- a/plugins/pii_filter/pii_filter_rust.py +++ b/plugins/pii_filter/pii_filter_rust.py @@ -9,11 +9,13 @@ Thin Python wrapper around the Rust implementation for seamless integration. """ -from typing import Dict, List, Any, TYPE_CHECKING +# Standard import logging +from typing import Any, Dict, List, TYPE_CHECKING # Use TYPE_CHECKING to avoid circular import at runtime if TYPE_CHECKING: + # Local from .pii_filter import PIIFilterConfig logger = logging.getLogger(__name__) @@ -21,20 +23,23 @@ # Try to import Rust implementation # Fix sys.path to prioritize site-packages over source directory try: - import sys + # Standard import os + import sys # Temporarily remove current directory from path if it contains plugins_rust source original_path = sys.path.copy() project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - plugins_rust_src = os.path.join(project_root, 'plugins_rust') + plugins_rust_src = os.path.join(project_root, "plugins_rust") # Remove source directory from path temporarily filtered_path = [p for p in sys.path if not p.startswith(plugins_rust_src)] sys.path = filtered_path try: + # First-Party from plugins_rust import PIIDetectorRust as _RustDetector + RUST_AVAILABLE = True logger.info("🦀 Rust PII filter module imported successfully") finally: @@ -69,16 +74,15 @@ def __init__(self, config: "PIIFilterConfig"): Raises: ImportError: If Rust implementation is not available + TypeError: If configuration type is invalid ValueError: If configuration is invalid """ # Import here to avoid circular dependency + # Local from .pii_filter import PIIFilterConfig # pylint: disable=import-outside-toplevel if not RUST_AVAILABLE: - raise ImportError( - "Rust implementation not available. " - "Install with: pip install mcpgateway[rust]" - ) + raise ImportError("Rust implementation not available. " "Install with: pip install mcpgateway[rust]") # Validate config type if not isinstance(config, PIIFilterConfig): @@ -114,6 +118,9 @@ def detect(self, text: str) -> Dict[str, List[Dict]]: ] } + Raises: + RuntimeError: If PII detection fails. + Example: >>> detector.detect("SSN: 123-45-6789") {'ssn': [{'value': '123-45-6789', 'start': 5, 'end': 16, 'mask_strategy': 'partial'}]} @@ -132,7 +139,10 @@ def mask(self, text: str, detections: Dict[str, List[Dict]]) -> str: detections: Detection results from detect() Returns: - Masked text with PII replaced according to strategies + str: Masked text with PII replaced according to strategies + + Raises: + RuntimeError: If PII masking fails. Example: >>> text = "SSN: 123-45-6789" @@ -157,11 +167,14 @@ def process_nested(self, data: Any, path: str = "") -> tuple[bool, Any, Dict]: path: Current path in the structure (for logging) Returns: - Tuple of (modified, new_data, detections) where: + tuple[bool, Any, Dict]: Tuple of (modified, new_data, detections) where: - modified: True if any PII was found and masked - new_data: The data structure with masked PII - detections: Dictionary of all detections found + Raises: + RuntimeError: If nested processing fails. + Example: >>> data = {"user": {"ssn": "123-45-6789", "name": "John"}} >>> modified, new_data, detections = detector.process_nested(data) @@ -176,4 +189,4 @@ def process_nested(self, data: Any, path: str = "") -> tuple[bool, Any, Dict]: # Export module-level availability flag -__all__ = ['RustPIIDetector', 'RUST_AVAILABLE'] +__all__ = ["RustPIIDetector", "RUST_AVAILABLE"] diff --git a/plugins/privacy_notice_injector/privacy_notice_injector.py b/plugins/privacy_notice_injector/privacy_notice_injector.py index 31e1d503e..f619dbaad 100644 --- a/plugins/privacy_notice_injector/privacy_notice_injector.py +++ b/plugins/privacy_notice_injector/privacy_notice_injector.py @@ -19,7 +19,7 @@ from pydantic import BaseModel # First-Party -from mcpgateway.models import Message, Role, TextContent +from mcpgateway.common.models import Message, Role, TextContent from mcpgateway.plugins.framework import ( Plugin, PluginConfig, diff --git a/plugins/resource_filter/resource_filter.py b/plugins/resource_filter/resource_filter.py index d5c191b35..98121db25 100644 --- a/plugins/resource_filter/resource_filter.py +++ b/plugins/resource_filter/resource_filter.py @@ -176,7 +176,7 @@ async def resource_post_fetch(self, payload: ResourcePostFetchPayload, context: if filtered_text != original_text: # Create new content object with filtered text # First-Party - from mcpgateway.models import ResourceContent + from mcpgateway.common.models import ResourceContent modified_content = ResourceContent( type=payload.content.type, diff --git a/plugins/vault/vault_plugin.py b/plugins/vault/vault_plugin.py index 994b9eaf4..115168419 100644 --- a/plugins/vault/vault_plugin.py +++ b/plugins/vault/vault_plugin.py @@ -22,13 +22,13 @@ # First-Party from mcpgateway.db import get_db from mcpgateway.plugins.framework import ( + HttpHeaderPayload, Plugin, PluginConfig, PluginContext, ToolPreInvokePayload, ToolPreInvokeResult, ) -from mcpgateway.plugins.framework.models import HttpHeaderPayload from mcpgateway.services.gateway_service import GatewayService from mcpgateway.services.logging_service import LoggingService diff --git a/plugins/virus_total_checker/virus_total_checker.py b/plugins/virus_total_checker/virus_total_checker.py index 5b10f696f..1754a86c6 100644 --- a/plugins/virus_total_checker/virus_total_checker.py +++ b/plugins/virus_total_checker/virus_total_checker.py @@ -313,7 +313,13 @@ def _ip_in_cidrs(ip: str, cidrs: list[str]) -> bool: def _apply_overrides(url: str, host: str | None, cfg: VirusTotalConfig) -> str | None: """Return 'deny', 'allow', or None based on local overrides and precedence. - Precedence order is controlled by cfg.override_precedence. + Args: + url: The URL to check for overrides. + host: The host to check for overrides (optional). + cfg: The VirusTotal configuration. + + Returns: + str | None: 'deny', 'allow', or None based on overrides. Precedence order is controlled by cfg.override_precedence. """ host_l = (host or "").lower() allow = _url_matches(url, cfg.allow_url_patterns) or (host_l and _domain_matches(host_l, cfg.allow_domains)) or (host_l and _ip_in_cidrs(host_l, cfg.allow_ip_cidrs)) diff --git a/plugins/webhook_notification/webhook_notification.py b/plugins/webhook_notification/webhook_notification.py index f76577b0c..37eeb61da 100644 --- a/plugins/webhook_notification/webhook_notification.py +++ b/plugins/webhook_notification/webhook_notification.py @@ -131,7 +131,15 @@ def __init__(self, config: PluginConfig) -> None: self._client = httpx.AsyncClient() async def _render_template(self, template: str, context: Dict[str, Any]) -> str: - """Render a Jinja2-style template with the given context.""" + """Render a Jinja2-style template with the given context. + + Args: + template: The template string to render. + context: The context dictionary for template rendering. + + Returns: + str: The rendered template string. + """ # Simple template substitution for now - could be enhanced with Jinja2 result = template for key, value in context.items(): @@ -145,7 +153,16 @@ async def _render_template(self, template: str, context: Dict[str, Any]) -> str: return result def _create_hmac_signature(self, payload: str, secret: str, algorithm: str) -> str: - """Create HMAC signature for the payload.""" + """Create HMAC signature for the payload. + + Args: + payload: The payload to sign. + secret: The secret key for HMAC. + algorithm: The hash algorithm to use. + + Returns: + str: The HMAC signature string. + """ hash_func = getattr(hashlib, algorithm, hashlib.sha256) signature = hmac.new(secret.encode("utf-8"), payload.encode("utf-8"), hash_func).hexdigest() return f"{algorithm}={signature}" @@ -159,7 +176,16 @@ async def _send_webhook( metadata: Optional[Dict[str, Any]] = None, payload_data: Optional[Dict[str, Any]] = None, ) -> None: - """Send a webhook notification with retry logic.""" + """Send a webhook notification with retry logic. + + Args: + webhook: The webhook configuration. + event: The event type to notify. + context: The plugin context. + violation: Optional violation details. + metadata: Optional metadata dictionary. + payload_data: Optional payload data dictionary. + """ if not webhook.enabled or event not in webhook.events: return @@ -229,7 +255,15 @@ async def _send_webhook( async def _notify_webhooks( self, event: EventType, context: PluginContext, violation: Optional[PluginViolation] = None, metadata: Optional[Dict[str, Any]] = None, payload_data: Optional[Dict[str, Any]] = None ) -> None: - """Send notifications to all configured webhooks.""" + """Send notifications to all configured webhooks. + + Args: + event: The event type to notify. + context: The plugin context. + violation: Optional violation details. + metadata: Optional metadata dictionary. + payload_data: Optional payload data dictionary. + """ if not self._cfg.webhooks: return @@ -240,7 +274,14 @@ async def _notify_webhooks( await asyncio.gather(*tasks, return_exceptions=True) def _determine_event_type(self, violation: Optional[PluginViolation]) -> EventType: - """Determine event type based on violation details.""" + """Determine event type based on violation details. + + Args: + violation: Optional violation details. + + Returns: + EventType: The determined event type. + """ if not violation: return EventType.TOOL_SUCCESS @@ -254,20 +295,52 @@ def _determine_event_type(self, violation: Optional[PluginViolation]) -> EventTy return EventType.VIOLATION async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: - """Hook for prompt pre-fetch events.""" + """Hook for prompt pre-fetch events. + + Args: + payload: The prompt pre-hook payload. + context: The plugin context. + + Returns: + PromptPrehookResult: The result of the pre-fetch hook. + """ return PromptPrehookResult() async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: - """Hook for prompt post-fetch events.""" + """Hook for prompt post-fetch events. + + Args: + payload: The prompt post-hook payload. + context: The plugin context. + + Returns: + PromptPosthookResult: The result of the post-fetch hook. + """ await self._notify_webhooks(EventType.PROMPT_SUCCESS, context, metadata={"prompt_id": payload.prompt_id}) return PromptPosthookResult() async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: - """Hook for tool pre-invoke events.""" + """Hook for tool pre-invoke events. + + Args: + payload: The tool pre-invoke payload. + context: The plugin context. + + Returns: + ToolPreInvokeResult: The result of the pre-invoke hook. + """ return ToolPreInvokeResult() async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: - """Hook for tool post-invoke events.""" + """Hook for tool post-invoke events. + + Args: + payload: The tool post-invoke payload. + context: The plugin context. + + Returns: + ToolPostInvokeResult: The result of the post-invoke hook. + """ # Check if there was an error in the result event = EventType.TOOL_SUCCESS metadata = {"tool_name": payload.name} @@ -284,16 +357,36 @@ async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: Plugin return ToolPostInvokeResult() async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, context: PluginContext) -> ResourcePreFetchResult: - """Hook for resource pre-fetch events.""" + """Hook for resource pre-fetch events. + + Args: + payload: The resource pre-fetch payload. + context: The plugin context. + + Returns: + ResourcePreFetchResult: The result of the pre-fetch hook. + """ return ResourcePreFetchResult() async def resource_post_fetch(self, payload: ResourcePostFetchPayload, context: PluginContext) -> ResourcePostFetchResult: - """Hook for resource post-fetch events.""" + """Hook for resource post-fetch events. + + Args: + payload: The resource post-fetch payload. + context: The plugin context. + + Returns: + ResourcePostFetchResult: The result of the post-fetch hook. + """ await self._notify_webhooks(EventType.RESOURCE_SUCCESS, context, metadata={"resource_uri": payload.uri}) return ResourcePostFetchResult() async def __aenter__(self): - """Async context manager entry.""" + """Async context manager entry. + + Returns: + WebhookNotificationPlugin: The plugin instance. + """ return self async def __aexit__(self, _exc_type, _exc_val, _exc_tb): diff --git a/pyproject.toml b/pyproject.toml index a5decfa62..8a083f640 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -516,7 +516,6 @@ warn_unreachable = true # Warn about unreachable code warn_unused_ignores = true # Warn if a "# type: ignore" is unnecessary warn_unused_configs = true # Warn about unused config options warn_redundant_casts = true # Warn if a cast does nothing -warn_unused_coroutine = true # Warn if an unused async coroutine is defined strict_equality = true # Disallow ==/!= between incompatible types # Output formatting diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index f7cb0f997..b9b3ec299 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -32,7 +32,7 @@ # First-Party from mcpgateway.main import app, require_auth -from mcpgateway.models import InitializeResult, ResourceContent, ServerCapabilities +from mcpgateway.common.models import InitializeResult, ResourceContent, ServerCapabilities from mcpgateway.schemas import ResourceRead, ServerRead, ToolMetrics, ToolRead # Local diff --git a/tests/integration/test_resource_plugin_integration.py b/tests/integration/test_resource_plugin_integration.py index 33293b45b..ae9013503 100644 --- a/tests/integration/test_resource_plugin_integration.py +++ b/tests/integration/test_resource_plugin_integration.py @@ -18,7 +18,7 @@ # First-Party from mcpgateway.db import Base -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent from mcpgateway.schemas import ResourceCreate from mcpgateway.services.resource_service import ResourceService @@ -44,9 +44,19 @@ def resource_service_with_mock_plugins(self): # Standard from unittest.mock import AsyncMock + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + mock_manager = MagicMock() mock_manager._initialized = True mock_manager.initialize = AsyncMock() + # Add default invoke_hook mock that returns success + mock_manager.invoke_hook = AsyncMock( + return_value=( + PluginResult(continue_processing=True, modified_payload=None), + None # contexts + ) + ) MockPluginManager.return_value = mock_manager service = ResourceService() service._plugin_manager = mock_manager @@ -57,20 +67,7 @@ async def test_full_resource_lifecycle_with_plugins(self, test_db, resource_serv """Test complete resource lifecycle with plugin hooks.""" service, mock_manager = resource_service_with_mock_plugins - # Configure mock plugin manager for all operations - # Standard - from unittest.mock import AsyncMock - - pre_result = MagicMock() - pre_result.continue_processing = True - pre_result.modified_payload = None - - post_result = MagicMock() - post_result.continue_processing = True - post_result.modified_payload = None - - mock_manager.resource_pre_fetch = AsyncMock(return_value=(pre_result, {"context": "data"})) - mock_manager.resource_post_fetch = AsyncMock(return_value=(post_result, None)) + # The default invoke_hook from fixture will work fine for this test # 1. Create a resource resource_data = ResourceCreate( @@ -96,8 +93,8 @@ async def test_full_resource_lifecycle_with_plugins(self, test_db, resource_serv ) assert content is not None - mock_manager.resource_pre_fetch.assert_called_once() - mock_manager.resource_post_fetch.assert_called_once() + # Verify hooks were called (pre and post fetch) + assert mock_manager.invoke_hook.call_count >= 2 # 3. List resources resources, _ = await service.list_resources(test_db) @@ -135,7 +132,7 @@ async def test_resource_filtering_integration(self, test_db): # Use real plugin manager but mock its initialization with patch("mcpgateway.services.resource_service.PluginManager") as MockPluginManager: # First-Party - from mcpgateway.plugins.framework.models import ( + from mcpgateway.plugins.framework import ( ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchResult, @@ -153,58 +150,67 @@ async def initialize(self): def initialized(self) -> bool: return self._initialized - async def resource_pre_fetch(self, payload, global_context, violations_as_exceptions): - # Allow test:// protocol - if payload.uri.startswith("test://"): + async def invoke_hook(self, hook_type, payload, global_context, local_contexts=None, **kwargs): + # First-Party + from mcpgateway.plugins.framework import ResourceHookType + + if hook_type == ResourceHookType.RESOURCE_PRE_FETCH: + # Allow test:// protocol + if payload.uri.startswith("test://"): + return ( + ResourcePreFetchResult( + continue_processing=True, + modified_payload=payload, + ), + {"validated": True}, + ) + else: + # First-Party + from mcpgateway.plugins.framework.models import PluginViolation + + raise PluginViolationError( + message="Protocol not allowed", + violation=PluginViolation( + reason="Protocol not allowed", + description="Protocol is not in the allowed list", + code="PROTOCOL_BLOCKED", + details={"protocol": payload.uri.split(":")[0], "uri": payload.uri}, + ), + ) + elif hook_type == ResourceHookType.RESOURCE_POST_FETCH: + # Filter sensitive content + if payload.content and payload.content.text: + filtered_text = payload.content.text.replace( + "password: secret123", + "password: [REDACTED]", + ) + filtered_content = ResourceContent( + id=payload.content.id, + type=payload.content.type, + uri=payload.content.uri, + text=filtered_text, + ) + modified_payload = ResourcePostFetchPayload( + uri=payload.uri, + content=filtered_content, + ) + return ( + ResourcePostFetchResult( + continue_processing=True, + modified_payload=modified_payload, + ), + None, + ) return ( - ResourcePreFetchResult( - continue_processing=True, - modified_payload=payload, - ), - {"validated": True}, + ResourcePostFetchResult(continue_processing=True), + None, ) else: + # Other hook types - just return success # First-Party - from mcpgateway.plugins.framework.models import PluginViolation - - raise PluginViolationError( - message="Protocol not allowed", - violation=PluginViolation( - reason="Protocol not allowed", - description="Protocol is not in the allowed list", - code="PROTOCOL_BLOCKED", - details={"protocol": payload.uri.split(":")[0], "uri": payload.uri}, - ), - ) + from mcpgateway.plugins.framework.models import PluginResult - async def resource_post_fetch(self, payload, global_context, contexts, violations_as_exceptions): - # Filter sensitive content - if payload.content and payload.content.text: - filtered_text = payload.content.text.replace( - "password: secret123", - "password: [REDACTED]", - ) - filtered_content = ResourceContent( - id=payload.content.id, - type=payload.content.type, - uri=payload.content.uri, - text=filtered_text, - ) - modified_payload = ResourcePostFetchPayload( - uri=payload.uri, - content=filtered_content, - ) - return ( - ResourcePostFetchResult( - continue_processing=True, - modified_payload=modified_payload, - ), - None, - ) - return ( - ResourcePostFetchResult(continue_processing=True), - None, - ) + return (PluginResult(continue_processing=True), None) MockPluginManager.return_value = MockFilterManager("test.yaml") service = ResourceService() @@ -257,29 +263,37 @@ async def test_plugin_context_flow(self, test_db, resource_service_with_mock_plu service, mock_manager = resource_service_with_mock_plugins # Track context flow + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.framework import ResourceHookType + contexts_from_pre = {"plugin_data": "test_value", "validated": True} - async def pre_fetch_side_effect(payload, global_context, violations_as_exceptions): - # Verify global context - assert global_context.request_id == "integration-test-123" - assert global_context.user == "integration-user" - assert global_context.server_id == "server-123" - return ( - MagicMock(continue_processing=True, modified_payload=None), - contexts_from_pre, - ) - - async def post_fetch_side_effect(payload, global_context, contexts, violations_as_exceptions): - # Verify contexts from pre-fetch - assert contexts == contexts_from_pre - assert contexts["plugin_data"] == "test_value" - return ( - MagicMock(continue_processing=True), - None, - ) - - mock_manager.resource_pre_fetch.side_effect = pre_fetch_side_effect - mock_manager.resource_post_fetch.side_effect = post_fetch_side_effect + async def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == ResourceHookType.RESOURCE_PRE_FETCH: + # Verify global context + assert global_context.request_id == "integration-test-123" + assert global_context.user == "integration-user" + assert global_context.server_id == "server-123" + return ( + PluginResult(continue_processing=True, modified_payload=None), + contexts_from_pre, + ) + elif hook_type == ResourceHookType.RESOURCE_POST_FETCH: + # Verify contexts from pre-fetch + assert local_contexts == contexts_from_pre + assert local_contexts["plugin_data"] == "test_value" + return ( + PluginResult(continue_processing=True), + None, + ) + else: + return (PluginResult(continue_processing=True), None) + + # Standard + from unittest.mock import AsyncMock + + mock_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) # Create and read a resource resource = ResourceCreate( @@ -297,29 +311,15 @@ async def post_fetch_side_effect(payload, global_context, contexts, violations_a server_id="server-123", ) - mock_manager.resource_pre_fetch.assert_called_once() - mock_manager.resource_post_fetch.assert_called_once() + # Verify hooks were called + assert mock_manager.invoke_hook.call_count >= 2 @pytest.mark.asyncio async def test_template_resource_with_plugins(self, test_db, resource_service_with_mock_plugins): """Test resources work with plugins using template-like content.""" service, mock_manager = resource_service_with_mock_plugins - # Configure plugin manager - # Standard - from unittest.mock import AsyncMock - - # Create proper mock results - pre_result = MagicMock() - pre_result.continue_processing = True - pre_result.modified_payload = None - - post_result = MagicMock() - post_result.continue_processing = True - post_result.modified_payload = None - - mock_manager.resource_pre_fetch = AsyncMock(return_value=(pre_result, {"context": "data"})) - mock_manager.resource_post_fetch = AsyncMock(return_value=(post_result, None)) + # The default invoke_hook from fixture will work fine # Create a regular resource with template-like content resource = ResourceCreate( @@ -332,24 +332,15 @@ async def test_template_resource_with_plugins(self, test_db, resource_service_wi content = await service.read_resource(test_db, created.id) assert content.text == "Data for ID: 123" - mock_manager.resource_pre_fetch.assert_called_once() - mock_manager.resource_post_fetch.assert_called_once() + # Verify hooks were called + assert mock_manager.invoke_hook.call_count >= 2 @pytest.mark.asyncio async def test_inactive_resource_handling(self, test_db, resource_service_with_mock_plugins): """Test that inactive resources are handled correctly with plugins.""" service, mock_manager = resource_service_with_mock_plugins - # Configure mock plugin manager - # Standard - from unittest.mock import AsyncMock - - pre_result = MagicMock() - pre_result.continue_processing = True - pre_result.modified_payload = None - - mock_manager.resource_pre_fetch = AsyncMock(return_value=(pre_result, None)) - mock_manager.resource_post_fetch = AsyncMock() + # The default invoke_hook from fixture will work fine # Create a resource resource = ResourceCreate( @@ -373,5 +364,5 @@ async def test_inactive_resource_handling(self, test_db, resource_service_with_m assert "exists but is inactive" in str(exc_info.value) # Pre-fetch is called but post-fetch should not be called for inactive resources - mock_manager.resource_pre_fetch.assert_called_once() - mock_manager.resource_post_fetch.assert_not_called() + # Only one invoke_hook call (pre-fetch) since error occurs before post-fetch + assert mock_manager.invoke_hook.call_count == 1 diff --git a/tests/unit/mcpgateway/db/test_observability_migrations.py b/tests/unit/mcpgateway/db/test_observability_migrations.py new file mode 100644 index 000000000..4e887d199 --- /dev/null +++ b/tests/unit/mcpgateway/db/test_observability_migrations.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/db/test_observability_migrations.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +Unit tests for observability Alembic migrations. + +Tests verify: +- Migration modules can be imported +- Upgrade and downgrade functions exist +- Migration revision IDs are correct +- Dependencies are properly defined +- No syntax errors in migration code +- Cross-database SQL compatibility +""" + +# Standard +import importlib +import inspect as pyinspect +import re + +# Third-Party +import pytest + + +# Migration module information +OBSERVABILITY_MIGRATIONS = [ + { + "module": "mcpgateway.alembic.versions.a23a08d61eb0_add_observability_tables", + "revision": "a23a08d61eb0", + "down_revision": "a706a3320c56", + "description": "add_observability_tables", + }, + { + "module": "mcpgateway.alembic.versions.i3c4d5e6f7g8_add_observability_performance_indexes", + "revision": "i3c4d5e6f7g8", + "down_revision": "a23a08d61eb0", + "description": "add observability performance indexes", + }, + { + "module": "mcpgateway.alembic.versions.j4d5e6f7g8h9_add_observability_saved_queries", + "revision": "j4d5e6f7g8h9", + "down_revision": "i3c4d5e6f7g8", + "description": "add observability saved queries", + }, +] + + +class TestObservabilityMigrationModules: + """Test that all observability migration modules are valid.""" + + @pytest.mark.parametrize("migration_info", OBSERVABILITY_MIGRATIONS) + def test_migration_module_imports(self, migration_info): + """Test that migration module can be imported.""" + module_name = migration_info["module"] + + try: + module = importlib.import_module(module_name) + assert module is not None, f"Module {module_name} imported as None" + except ImportError as e: + pytest.fail(f"Failed to import {module_name}: {e}") + + @pytest.mark.parametrize("migration_info", OBSERVABILITY_MIGRATIONS) + def test_migration_has_upgrade_function(self, migration_info): + """Test that migration has an upgrade() function.""" + module_name = migration_info["module"] + module = importlib.import_module(module_name) + + assert hasattr(module, "upgrade"), f"{module_name} missing upgrade() function" + assert callable(module.upgrade), f"{module_name}.upgrade is not callable" + + @pytest.mark.parametrize("migration_info", OBSERVABILITY_MIGRATIONS) + def test_migration_has_downgrade_function(self, migration_info): + """Test that migration has a downgrade() function.""" + module_name = migration_info["module"] + module = importlib.import_module(module_name) + + assert hasattr(module, "downgrade"), f"{module_name} missing downgrade() function" + assert callable(module.downgrade), f"{module_name}.downgrade is not callable" + + @pytest.mark.parametrize("migration_info", OBSERVABILITY_MIGRATIONS) + def test_migration_revision_id_correct(self, migration_info): + """Test that migration has correct revision ID.""" + module_name = migration_info["module"] + expected_revision = migration_info["revision"] + + module = importlib.import_module(module_name) + + assert hasattr(module, "revision"), f"{module_name} missing revision variable" + assert module.revision == expected_revision, f"{module_name} has incorrect revision: {module.revision} != {expected_revision}" + + @pytest.mark.parametrize("migration_info", OBSERVABILITY_MIGRATIONS) + def test_migration_down_revision_correct(self, migration_info): + """Test that migration has correct down_revision.""" + module_name = migration_info["module"] + expected_down_revision = migration_info["down_revision"] + + module = importlib.import_module(module_name) + + assert hasattr(module, "down_revision"), f"{module_name} missing down_revision variable" + assert module.down_revision == expected_down_revision, f"{module_name} has incorrect down_revision: {module.down_revision} != {expected_down_revision}" + + @pytest.mark.parametrize("migration_info", OBSERVABILITY_MIGRATIONS) + def test_migration_functions_have_no_parameters(self, migration_info): + """Test that upgrade() and downgrade() accept no parameters.""" + module_name = migration_info["module"] + module = importlib.import_module(module_name) + + # Check upgrade function signature + upgrade_sig = pyinspect.signature(module.upgrade) + assert len(upgrade_sig.parameters) == 0, f"{module_name}.upgrade() should have no parameters" + + # Check downgrade function signature + downgrade_sig = pyinspect.signature(module.downgrade) + assert len(downgrade_sig.parameters) == 0, f"{module_name}.downgrade() should have no parameters" + + +class TestObservabilityTablesMigration: + """Test migration a23a08d61eb0 (add observability tables).""" + + def test_creates_four_tables(self): + """Test that migration creates 4 observability tables.""" + module = importlib.import_module("mcpgateway.alembic.versions.a23a08d61eb0_add_observability_tables") + + # Get source code + source = pyinspect.getsource(module.upgrade) + + # Count create_table calls + create_table_count = source.count("op.create_table") + assert create_table_count == 4, f"Expected 4 create_table calls, found {create_table_count}" + + # Verify table names + assert "observability_traces" in source + assert "observability_spans" in source + assert "observability_events" in source + assert "observability_metrics" in source + + def test_downgrade_drops_four_tables(self): + """Test that downgrade drops all 4 tables.""" + module = importlib.import_module("mcpgateway.alembic.versions.a23a08d61eb0_add_observability_tables") + + source = pyinspect.getsource(module.downgrade) + + drop_table_count = source.count("op.drop_table") + assert drop_table_count == 4, f"Expected 4 drop_table calls, found {drop_table_count}" + + def test_uses_datetime_with_timezone(self): + """Test that DateTime columns use timezone=True.""" + module = importlib.import_module("mcpgateway.alembic.versions.a23a08d61eb0_add_observability_tables") + + source = pyinspect.getsource(module.upgrade) + + # Should use DateTime(timezone=True) + assert "DateTime(timezone=True)" in source, "Missing DateTime(timezone=True)" + + def test_uses_json_column_type(self): + """Test that JSON columns are used for attributes.""" + module = importlib.import_module("mcpgateway.alembic.versions.a23a08d61eb0_add_observability_tables") + + source = pyinspect.getsource(module.upgrade) + + # Should use sa.JSON() + assert "sa.JSON()" in source, "Missing sa.JSON() column type" + + def test_foreign_keys_have_cascade_delete(self): + """Test that foreign keys have CASCADE delete.""" + module = importlib.import_module("mcpgateway.alembic.versions.a23a08d61eb0_add_observability_tables") + + source = pyinspect.getsource(module.upgrade) + + # Should have ondelete="CASCADE" + assert 'ondelete="CASCADE"' in source, "Missing CASCADE delete on foreign keys" + + +class TestObservabilityPerformanceIndexes: + """Test migration i3c4d5e6f7g8 (add performance indexes).""" + + def test_uses_op_create_index_not_raw_sql(self): + """Test that migration uses op.create_index() instead of raw SQL.""" + module = importlib.import_module("mcpgateway.alembic.versions.i3c4d5e6f7g8_add_observability_performance_indexes") + + source = pyinspect.getsource(module.upgrade) + + # Should use op.create_index + assert "op.create_index" in source, "Missing op.create_index calls" + + # Should NOT use raw SQL with IF NOT EXISTS + assert "CREATE INDEX IF NOT EXISTS" not in source, "Should not use raw SQL with IF NOT EXISTS" + assert "op.execute" not in source, "Should not use op.execute for index creation" + + def test_uses_op_drop_index_not_raw_sql(self): + """Test that downgrade uses op.drop_index() instead of raw SQL.""" + module = importlib.import_module("mcpgateway.alembic.versions.i3c4d5e6f7g8_add_observability_performance_indexes") + + source = pyinspect.getsource(module.downgrade) + + # Should use op.drop_index + assert "op.drop_index" in source, "Missing op.drop_index calls" + + # Should NOT use raw SQL with IF EXISTS + assert "DROP INDEX IF EXISTS" not in source, "Should not use raw SQL with IF EXISTS" + assert "op.execute" not in source, "Should not use op.execute for index dropping" + + def test_creates_composite_indexes(self): + """Test that migration creates composite indexes.""" + module = importlib.import_module("mcpgateway.alembic.versions.i3c4d5e6f7g8_add_observability_performance_indexes") + + source = pyinspect.getsource(module.upgrade) + + # Check for multi-column indexes + assert '["status", "start_time"]' in source or "['status', 'start_time']" in source, "Missing composite index on status+start_time" + assert '["trace_id", "start_time"]' in source or "['trace_id', 'start_time']" in source, "Missing composite index on trace_id+start_time" + + def test_downgrade_drops_indexes_in_reverse_order(self): + """Test that downgrade drops indexes (reverse order is good practice).""" + module = importlib.import_module("mcpgateway.alembic.versions.i3c4d5e6f7g8_add_observability_performance_indexes") + + source = pyinspect.getsource(module.downgrade) + + # Count drop_index calls + drop_count = source.count("op.drop_index") + create_source = pyinspect.getsource(module.upgrade) + create_count = create_source.count("op.create_index") + + assert drop_count == create_count, f"Downgrade should drop {create_count} indexes, but drops {drop_count}" + + def test_specifies_table_name_in_drop_index(self): + """Test that op.drop_index includes table_name parameter.""" + module = importlib.import_module("mcpgateway.alembic.versions.i3c4d5e6f7g8_add_observability_performance_indexes") + + source = pyinspect.getsource(module.downgrade) + + # Should specify table_name for cross-database compatibility + assert "table_name=" in source, "op.drop_index should specify table_name parameter" + + +class TestObservabilitySavedQueries: + """Test migration j4d5e6f7g8h9 (add saved queries table).""" + + def test_boolean_uses_sa_false_not_string(self): + """Test that Boolean server_default uses sa.false() not string '0'.""" + module = importlib.import_module("mcpgateway.alembic.versions.j4d5e6f7g8h9_add_observability_saved_queries") + + source = pyinspect.getsource(module.upgrade) + + # Should use sa.false() for Boolean + assert "sa.false()" in source, "Boolean server_default should use sa.false()" + + # Should NOT use string "0" for Boolean + assert 'sa.Boolean(), nullable=False, server_default="0"' not in source, "Should not use string '0' for Boolean server_default" + + def test_integer_uses_sa_text_for_default(self): + """Test that Integer server_default uses sa.text('0').""" + module = importlib.import_module("mcpgateway.alembic.versions.j4d5e6f7g8h9_add_observability_saved_queries") + + source = pyinspect.getsource(module.upgrade) + + # Should use sa.text("0") for Integer + assert 'sa.text("0")' in source, "Integer server_default should use sa.text('0')" + + def test_no_duplicate_user_email_index(self): + """Test that there's only ONE index on user_email column.""" + module = importlib.import_module("mcpgateway.alembic.versions.j4d5e6f7g8h9_add_observability_saved_queries") + + source = pyinspect.getsource(module.upgrade) + + # Count how many times we create an index on user_email + user_email_index_count = 0 + + # Look for index creation lines containing user_email + for line in source.split("\n"): + if "op.create_index" in line and "user_email" in line: + user_email_index_count += 1 + + assert user_email_index_count == 1, f"Expected 1 user_email index, found {user_email_index_count}" + + def test_downgrade_drops_correct_number_of_indexes(self): + """Test that downgrade drops the same number of indexes as upgrade creates.""" + module = importlib.import_module("mcpgateway.alembic.versions.j4d5e6f7g8h9_add_observability_saved_queries") + + upgrade_source = pyinspect.getsource(module.upgrade) + downgrade_source = pyinspect.getsource(module.downgrade) + + create_count = upgrade_source.count("op.create_index") + drop_count = downgrade_source.count("op.drop_index") + + assert drop_count == create_count, f"Downgrade should drop {create_count} indexes, but drops {drop_count}" + + def test_uses_current_timestamp_for_datetime_defaults(self): + """Test that DateTime columns use CURRENT_TIMESTAMP for server defaults.""" + module = importlib.import_module("mcpgateway.alembic.versions.j4d5e6f7g8h9_add_observability_saved_queries") + + source = pyinspect.getsource(module.upgrade) + + # Should use sa.text("CURRENT_TIMESTAMP") for DateTime + assert 'sa.text("CURRENT_TIMESTAMP")' in source, "DateTime columns should use sa.text('CURRENT_TIMESTAMP')" + + +class TestCrossDatabaseCompatibility: + """Test cross-database compatibility concerns.""" + + def test_no_mysql_specific_if_not_exists(self): + """Test that migrations don't use MySQL < 8.0.13 incompatible IF NOT EXISTS.""" + for migration_info in OBSERVABILITY_MIGRATIONS: + module = importlib.import_module(migration_info["module"]) + upgrade_source = pyinspect.getsource(module.upgrade) + downgrade_source = pyinspect.getsource(module.downgrade) + + # Should not use raw SQL with IF NOT EXISTS / IF EXISTS + assert "IF NOT EXISTS" not in upgrade_source, f"{migration_info['module']} uses IF NOT EXISTS (MySQL < 8.0.13 incompatible)" + assert "IF EXISTS" not in downgrade_source, f"{migration_info['module']} uses IF EXISTS (MySQL < 8.0.13 incompatible)" + + def test_uses_sqlalchemy_types_not_raw_sql_types(self): + """Test that migrations use SQLAlchemy types (sa.*) not raw SQL types.""" + for migration_info in OBSERVABILITY_MIGRATIONS: + module = importlib.import_module(migration_info["module"]) + source = pyinspect.getsource(module.upgrade) + + # Should use sa.String, sa.Integer, etc. + if "create_table" in source: + assert "sa.String" in source or "sa.Text" in source or "sa.Integer" in source, f"{migration_info['module']} should use SQLAlchemy types" + + def test_datetime_columns_use_timezone_parameter(self): + """Test that DateTime columns specify timezone parameter.""" + module = importlib.import_module("mcpgateway.alembic.versions.a23a08d61eb0_add_observability_tables") + + source = pyinspect.getsource(module.upgrade) + + # All DateTime columns should specify timezone=True + datetime_matches = re.findall(r"sa\.DateTime\([^)]*\)", source) + + for match in datetime_matches: + assert "timezone=True" in match, f"DateTime column missing timezone parameter: {match}" + + +class TestMigrationChain: + """Test that migrations form a proper chain.""" + + def test_migrations_form_continuous_chain(self): + """Test that down_revision of each migration matches previous revision.""" + # Check that chain is continuous + revisions = {m["revision"]: m["down_revision"] for m in OBSERVABILITY_MIGRATIONS} + + # i3c4d5e6f7g8 should depend on a23a08d61eb0 + assert revisions["i3c4d5e6f7g8"] == "a23a08d61eb0" + + # j4d5e6f7g8h9 should depend on i3c4d5e6f7g8 + assert revisions["j4d5e6f7g8h9"] == "i3c4d5e6f7g8" + + def test_no_circular_dependencies(self): + """Test that there are no circular dependencies in migration chain.""" + revisions = {m["revision"]: m["down_revision"] for m in OBSERVABILITY_MIGRATIONS} + + # Build dependency graph and check for cycles + visited = set() + + for revision in revisions: + path = [] + current = revision + + while current and current not in visited: + if current in path: + pytest.fail(f"Circular dependency detected: {' -> '.join(path + [current])}") + path.append(current) + current = revisions.get(current) + + visited.update(path) + + def test_all_migrations_have_unique_revisions(self): + """Test that all migration revisions are unique.""" + revisions = [m["revision"] for m in OBSERVABILITY_MIGRATIONS] + + assert len(revisions) == len(set(revisions)), "Duplicate revision IDs found" diff --git a/tests/unit/mcpgateway/plugins/agent/__init__.py b/tests/unit/mcpgateway/plugins/agent/__init__.py new file mode 100644 index 000000000..5503bed0d --- /dev/null +++ b/tests/unit/mcpgateway/plugins/agent/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/plugins/agent/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for agent plugin framework. +""" diff --git a/tests/unit/mcpgateway/plugins/agent/test_agent_plugins.py b/tests/unit/mcpgateway/plugins/agent/test_agent_plugins.py new file mode 100644 index 000000000..ac7f480e2 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/agent/test_agent_plugins.py @@ -0,0 +1,365 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/plugins/agent/test_agent_plugins.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for agent plugin framework. +""" + +# Third-Party +import pytest + +# First-Party +from mcpgateway.common.models import Message, Role, TextContent +from mcpgateway.plugins.framework import GlobalContext, PluginManager, PluginViolationError +from mcpgateway.plugins.framework import ( + AgentHookType, + AgentPreInvokePayload, + AgentPostInvokePayload, +) + + +@pytest.mark.asyncio +async def test_agent_passthrough_plugin(): + """Test that passthrough agent plugin works correctly.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml") + await manager.initialize() + + # Verify plugin loaded + assert manager.config.plugins[0].name == "PassThroughAgent" + assert manager.config.plugins[0].kind == "tests.unit.mcpgateway.plugins.fixtures.plugins.agent_plugins.PassThroughAgentPlugin" + assert AgentHookType.AGENT_PRE_INVOKE.value in manager.config.plugins[0].hooks + assert AgentHookType.AGENT_POST_INVOKE.value in manager.config.plugins[0].hooks + + # Create test payload + messages = [ + Message(role=Role.USER, content=TextContent(type="text", text="Hello agent!")) + ] + payload = AgentPreInvokePayload( + agent_id="test-agent", + messages=messages, + tools=["search", "calculator"], + model="claude-3-5-sonnet-20241022" + ) + + # Invoke pre-hook + global_context = GlobalContext(request_id="test-req-1") + result, contexts = await manager.invoke_hook( + AgentHookType.AGENT_PRE_INVOKE, + payload, + global_context=global_context + ) + + # Verify passthrough (no modification) + assert result.continue_processing is True + assert result.modified_payload is None + assert result.violation is None + + # Create response payload + response_messages = [ + Message(role=Role.ASSISTANT, content=TextContent(type="text", text="Hello user!")) + ] + post_payload = AgentPostInvokePayload( + agent_id="test-agent", + messages=response_messages + ) + + # Invoke post-hook + result, _ = await manager.invoke_hook( + AgentHookType.AGENT_POST_INVOKE, + post_payload, + global_context=global_context, + local_contexts=contexts + ) + + # Verify passthrough (no modification) + assert result.continue_processing is True + assert result.modified_payload is None + assert result.violation is None + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_agent_filter_plugin_pre_invoke(): + """Test that filter agent plugin blocks messages with banned words in pre-invoke.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml") + await manager.initialize() + + # Create test payload with clean message + clean_messages = [ + Message(role=Role.USER, content=TextContent(type="text", text="Hello agent!")) + ] + payload = AgentPreInvokePayload( + agent_id="test-agent", + messages=clean_messages + ) + + # Invoke pre-hook with clean message + global_context = GlobalContext(request_id="test-req-2") + result, contexts = await manager.invoke_hook( + AgentHookType.AGENT_PRE_INVOKE, + payload, + global_context=global_context + ) + + # Clean message should pass through + assert result.continue_processing is True + assert result.modified_payload is None + + # Create payload with blocked word + blocked_messages = [ + Message(role=Role.USER, content=TextContent(type="text", text="Click here for spam offers!")) + ] + payload = AgentPreInvokePayload( + agent_id="test-agent", + messages=blocked_messages + ) + + # Invoke pre-hook with blocked message - should raise violation + with pytest.raises(PluginViolationError) as exc_info: + result, contexts = await manager.invoke_hook( + AgentHookType.AGENT_PRE_INVOKE, + payload, + global_context=global_context, + violations_as_exceptions=True + ) + + assert exc_info.value.violation.code == "BLOCKED_CONTENT" + assert "blocked content" in exc_info.value.violation.reason.lower() + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_agent_filter_plugin_post_invoke(): + """Test that filter agent plugin blocks messages with banned words in post-invoke.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml") + await manager.initialize() + + # Create test payload with clean response + clean_messages = [ + Message(role=Role.ASSISTANT, content=TextContent(type="text", text="Here is your answer.")) + ] + payload = AgentPostInvokePayload( + agent_id="test-agent", + messages=clean_messages + ) + + # Invoke post-hook with clean message + global_context = GlobalContext(request_id="test-req-3") + result, _ = await manager.invoke_hook( + AgentHookType.AGENT_POST_INVOKE, + payload, + global_context=global_context + ) + + # Clean message should pass through + assert result.continue_processing is True + assert result.modified_payload is None + + # Create payload with blocked word + blocked_messages = [ + Message(role=Role.ASSISTANT, content=TextContent(type="text", text="This looks like malware to me.")) + ] + payload = AgentPostInvokePayload( + agent_id="test-agent", + messages=blocked_messages + ) + + # Invoke post-hook with blocked message - should raise violation + with pytest.raises(PluginViolationError) as exc_info: + result, _ = await manager.invoke_hook( + AgentHookType.AGENT_POST_INVOKE, + payload, + global_context=global_context, + violations_as_exceptions=True + ) + + assert exc_info.value.violation.code == "BLOCKED_CONTENT" + assert "blocked content" in exc_info.value.violation.reason.lower() + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_agent_filter_plugin_partial_filtering(): + """Test that filter plugin removes only blocked messages, keeps others.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml") + await manager.initialize() + + # Create payload with mixed messages + mixed_messages = [ + Message(role=Role.USER, content=TextContent(type="text", text="Hello agent!")), + Message(role=Role.USER, content=TextContent(type="text", text="Check out this spam!")), + Message(role=Role.USER, content=TextContent(type="text", text="What's the weather?")) + ] + payload = AgentPreInvokePayload( + agent_id="test-agent", + messages=mixed_messages + ) + + # Invoke pre-hook + global_context = GlobalContext(request_id="test-req-4") + result, contexts = await manager.invoke_hook( + AgentHookType.AGENT_PRE_INVOKE, + payload, + global_context=global_context + ) + + # Should have modified payload with only 2 messages + assert result.modified_payload is not None + assert len(result.modified_payload.messages) == 2 + assert result.modified_payload.messages[0].content.text == "Hello agent!" + assert result.modified_payload.messages[1].content.text == "What's the weather?" + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_agent_context_persistence(): + """Test that local context persists between pre and post hooks.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_context.yaml") + await manager.initialize() + + # Create pre-invoke payload + messages = [ + Message(role=Role.USER, content=TextContent(type="text", text="Hello!")) + ] + pre_payload = AgentPreInvokePayload( + agent_id="test-agent-123", + messages=messages + ) + + # Invoke pre-hook + global_context = GlobalContext(request_id="test-req-5") + pre_result, contexts = await manager.invoke_hook( + AgentHookType.AGENT_PRE_INVOKE, + pre_payload, + global_context=global_context + ) + + assert pre_result.continue_processing is True + + # Create post-invoke payload + response_messages = [ + Message(role=Role.ASSISTANT, content=TextContent(type="text", text="Hi there!")) + ] + post_payload = AgentPostInvokePayload( + agent_id="test-agent-123", + messages=response_messages + ) + + # Invoke post-hook with same contexts + post_result, _ = await manager.invoke_hook( + AgentHookType.AGENT_POST_INVOKE, + post_payload, + global_context=global_context, + local_contexts=contexts + ) + + # Verify context was verified (metadata added by post hook) + assert post_result.continue_processing is True + # The metadata should be in the contexts, not the result + # Check that invocation_count was incremented + assert contexts is not None + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_agent_plugin_with_tools(): + """Test agent plugin with tools list.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml") + await manager.initialize() + + # Create payload with tools + messages = [ + Message(role=Role.USER, content=TextContent(type="text", text="Search for Python tutorials")) + ] + payload = AgentPreInvokePayload( + agent_id="test-agent", + messages=messages, + tools=["web_search", "code_search", "calculator"] + ) + + # Invoke pre-hook + global_context = GlobalContext(request_id="test-req-6") + result, contexts = await manager.invoke_hook( + AgentHookType.AGENT_PRE_INVOKE, + payload, + global_context=global_context + ) + + # Verify tools are preserved + assert result.continue_processing is True + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_agent_plugin_with_model_override(): + """Test agent plugin with model override.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml") + await manager.initialize() + + # Create payload with model override + messages = [ + Message(role=Role.USER, content=TextContent(type="text", text="Analyze this code")) + ] + payload = AgentPreInvokePayload( + agent_id="test-agent", + messages=messages, + model="claude-3-opus-20240229", + parameters={"temperature": 0.7, "max_tokens": 1000} + ) + + # Invoke pre-hook + global_context = GlobalContext(request_id="test-req-7") + result, contexts = await manager.invoke_hook( + AgentHookType.AGENT_PRE_INVOKE, + payload, + global_context=global_context + ) + + # Verify model and parameters are preserved + assert result.continue_processing is True + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_agent_plugin_with_tool_calls(): + """Test agent plugin with tool calls in response.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml") + await manager.initialize() + + # Create post-invoke payload with tool calls + messages = [ + Message(role=Role.ASSISTANT, content=TextContent(type="text", text="I'll search for that.")) + ] + tool_calls = [ + { + "name": "web_search", + "arguments": {"query": "Python tutorials", "num_results": 5} + } + ] + payload = AgentPostInvokePayload( + agent_id="test-agent", + messages=messages, + tool_calls=tool_calls + ) + + # Invoke post-hook + global_context = GlobalContext(request_id="test-req-8") + result, _ = await manager.invoke_hook( + AgentHookType.AGENT_POST_INVOKE, + payload, + global_context=global_context + ) + + # Verify tool calls are preserved + assert result.continue_processing is True + + await manager.shutdown() diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/agent_context.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_context.yaml new file mode 100644 index 000000000..68d7f400f --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_context.yaml @@ -0,0 +1,29 @@ +plugins: + - name: ContextTrackingAgent + kind: tests.unit.mcpgateway.plugins.fixtures.plugins.agent_plugins.ContextTrackingAgentPlugin + description: An agent plugin that tracks state in local context + version: "1.0.0" + author: Test Suite + hooks: + - agent_pre_invoke + - agent_post_invoke + tags: + - test + - agent + - context + mode: enforce + priority: 50 + +# Plugin directories to scan +plugin_dirs: + - "plugins/native" # Built-in plugins + - "plugins/custom" # Custom organization plugins + - "/etc/mcpgateway/plugins" # System-wide plugins + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: true + enable_plugin_api: true + plugin_health_check_interval: 60 diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml new file mode 100644 index 000000000..9d31a5061 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml @@ -0,0 +1,34 @@ +plugins: + - name: MessageFilterAgent + kind: tests.unit.mcpgateway.plugins.fixtures.plugins.agent_plugins.MessageFilterAgentPlugin + description: An agent plugin that filters blocked words + version: "1.0.0" + author: Test Suite + hooks: + - agent_pre_invoke + - agent_post_invoke + tags: + - test + - agent + - filter + mode: enforce + priority: 50 + config: + blocked_words: + - spam + - malware + - phishing + +# Plugin directories to scan +plugin_dirs: + - "plugins/native" # Built-in plugins + - "plugins/custom" # Custom organization plugins + - "/etc/mcpgateway/plugins" # System-wide plugins + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: true + enable_plugin_api: true + plugin_health_check_interval: 60 diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml new file mode 100644 index 000000000..3a6ec9c16 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml @@ -0,0 +1,28 @@ +plugins: + - name: PassThroughAgent + kind: tests.unit.mcpgateway.plugins.fixtures.plugins.agent_plugins.PassThroughAgentPlugin + description: A simple pass-through agent plugin for testing + version: "1.0.0" + author: Test Suite + hooks: + - agent_pre_invoke + - agent_post_invoke + tags: + - test + - agent + mode: enforce + priority: 50 + +# Plugin directories to scan +plugin_dirs: + - "plugins/native" # Built-in plugins + - "plugins/custom" # Custom organization plugins + - "/etc/mcpgateway/plugins" # System-wide plugins + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: true + enable_plugin_api: true + plugin_health_check_interval: 60 diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml new file mode 100644 index 000000000..072952ded --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml @@ -0,0 +1,26 @@ +plugins: + - name: DemoPlugin + kind: test_hook_patterns.DemoPlugin + description: Demonstration plugin showing all three hook patterns + version: "1.0.0" + author: Demo + hooks: + - tool_pre_invoke + - tool_post_invoke + - email_pre_send + tags: + - demo + - test + mode: enforce + priority: 50 + +# Plugin directories to scan (not needed for this demo) +plugin_dirs: [] + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: true + enable_plugin_api: true + plugin_health_check_interval: 60 diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/agent_plugins.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/agent_plugins.py new file mode 100644 index 000000000..7112a2c11 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/agent_plugins.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/plugins/fixtures/plugins/agent_plugins.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Test agent plugins for unit testing. +""" + +# First-Party +from mcpgateway.common.models import Message, Role, TextContent +from mcpgateway.plugins.framework import ( + Plugin, + PluginContext, + AgentPreInvokePayload, + AgentPreInvokeResult, + AgentPostInvokePayload, + AgentPostInvokeResult, +) + + +class PassThroughAgentPlugin(Plugin): + """A simple pass-through agent plugin that doesn't modify anything.""" + + async def agent_pre_invoke( + self, payload: AgentPreInvokePayload, context: PluginContext + ) -> AgentPreInvokeResult: + """Pass through without modification. + + Args: + payload: The agent pre-invoke payload. + context: Contextual information about the hook call. + + Returns: + The result allowing processing to continue. + """ + return AgentPreInvokeResult(continue_processing=True) + + async def agent_post_invoke( + self, payload: AgentPostInvokePayload, context: PluginContext + ) -> AgentPostInvokeResult: + """Pass through without modification. + + Args: + payload: The agent post-invoke payload. + context: Contextual information about the hook call. + + Returns: + The result allowing processing to continue. + """ + return AgentPostInvokeResult(continue_processing=True) + + +class MessageFilterAgentPlugin(Plugin): + """An agent plugin that filters messages containing blocked words.""" + + async def agent_pre_invoke( + self, payload: AgentPreInvokePayload, context: PluginContext + ) -> AgentPreInvokeResult: + """Filter messages containing blocked words. + + Args: + payload: The agent pre-invoke payload. + context: Contextual information about the hook call. + + Returns: + The result with filtered messages or violation. + """ + blocked_words = self.config.config.get("blocked_words", []) + + # Filter messages + filtered_messages = [] + for msg in payload.messages: + if isinstance(msg.content, TextContent): + text_lower = msg.content.text.lower() + if any(word in text_lower for word in blocked_words): + # Skip this message + continue + filtered_messages.append(msg) + + # If all messages were blocked, return violation + if not filtered_messages and payload.messages: + from mcpgateway.plugins.framework import PluginViolation + return AgentPreInvokeResult( + continue_processing=False, + violation=PluginViolation( + code="BLOCKED_CONTENT", + reason="All messages contained blocked content", + description="This is a test of content blocking" + ) + ) + + # Return modified payload if messages were filtered + if len(filtered_messages) != len(payload.messages): + modified_payload = AgentPreInvokePayload( + agent_id=payload.agent_id, + messages=filtered_messages, + tools=payload.tools, + headers=payload.headers, + model=payload.model, + system_prompt=payload.system_prompt, + parameters=payload.parameters + ) + return AgentPreInvokeResult(modified_payload=modified_payload) + + return AgentPreInvokeResult(continue_processing=True) + + async def agent_post_invoke( + self, payload: AgentPostInvokePayload, context: PluginContext + ) -> AgentPostInvokeResult: + """Filter response messages containing blocked words. + + Args: + payload: The agent post-invoke payload. + context: Contextual information about the hook call. + + Returns: + The result with filtered messages or violation. + """ + blocked_words = self.config.config.get("blocked_words", []) + + # Filter messages + filtered_messages = [] + for msg in payload.messages: + if isinstance(msg.content, TextContent): + text_lower = msg.content.text.lower() + if any(word in text_lower for word in blocked_words): + # Skip this message + continue + filtered_messages.append(msg) + + # If all messages were blocked, return violation + if not filtered_messages and payload.messages: + from mcpgateway.plugins.framework import PluginViolation + return AgentPostInvokeResult( + continue_processing=False, + violation=PluginViolation( + code="BLOCKED_CONTENT", + reason="All response messages contained blocked content", + description="This is a test of content blocking" + ) + ) + + # Return modified payload if messages were filtered + if len(filtered_messages) != len(payload.messages): + modified_payload = AgentPostInvokePayload( + agent_id=payload.agent_id, + messages=filtered_messages, + tool_calls=payload.tool_calls + ) + return AgentPostInvokeResult(modified_payload=modified_payload) + + return AgentPostInvokeResult(continue_processing=True) + + +class ContextTrackingAgentPlugin(Plugin): + """An agent plugin that tracks state in local context.""" + + async def agent_pre_invoke( + self, payload: AgentPreInvokePayload, context: PluginContext + ) -> AgentPreInvokeResult: + """Track invocation count in local context. + + Args: + payload: The agent pre-invoke payload. + context: Contextual information about the hook call. + + Returns: + The result with updated local context. + """ + # Increment counter in local context + counter = context.metadata.get("invocation_count", 0) + context.metadata["invocation_count"] = counter + 1 + context.metadata["agent_id"] = payload.agent_id + + return AgentPreInvokeResult(continue_processing=True) + + async def agent_post_invoke( + self, payload: AgentPostInvokePayload, context: PluginContext + ) -> AgentPostInvokeResult: + """Verify context persists from pre-invoke. + + Args: + payload: The agent post-invoke payload. + context: Contextual information about the hook call. + + Returns: + The result after verifying context. + """ + # Verify context persisted + counter = context.metadata.get("invocation_count", 0) + agent_id = context.metadata.get("agent_id", "") + + # Add metadata about the context + context.metadata["context_verified"] = counter > 0 and agent_id == payload.agent_id + + return AgentPostInvokeResult(continue_processing=True) diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/context.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/context.py index eef673450..e8e251ebb 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/plugins/context.py +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/context.py @@ -9,8 +9,8 @@ """ from mcpgateway.plugins.framework import ( - Plugin, PluginContext, + Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/error.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/error.py index d15f110c1..32279ad2d 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/plugins/error.py +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/error.py @@ -9,8 +9,8 @@ """ from mcpgateway.plugins.framework import ( - Plugin, PluginContext, + Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/headers.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/headers.py index 1ba97649d..00b95faa0 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/plugins/headers.py +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/headers.py @@ -13,9 +13,9 @@ from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA from mcpgateway.plugins.framework import ( - HttpHeaderPayload, - Plugin, PluginContext, + Plugin, + HttpHeaderPayload, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/passthrough.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/passthrough.py index 8a6db5869..9f6c4b3d2 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/plugins/passthrough.py +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/passthrough.py @@ -9,8 +9,8 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginContext, + Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/simple.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/simple.py new file mode 100644 index 000000000..287fc3ab5 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/simple.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/plugins/fixtures/plugins/simple.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Test Suite + +Simple minimal plugins for testing the plugin framework. +These plugins provide basic passthrough implementations for testing +registration, priority sorting, hook filtering, etc. +""" + +# First-Party +from mcpgateway.plugins.framework import ( + Plugin, + PluginContext, + PromptPosthookPayload, + PromptPosthookResult, + PromptPrehookPayload, + PromptPrehookResult, + ToolPostInvokePayload, + ToolPostInvokeResult, + ToolPreInvokePayload, + ToolPreInvokeResult, +) + + +class SimplePromptPlugin(Plugin): + """Minimal plugin with prompt hooks for testing.""" + + async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: + """Passthrough prompt pre-fetch hook.""" + return PromptPrehookResult(continue_processing=True) + + async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: + """Passthrough prompt post-fetch hook.""" + return PromptPosthookResult(continue_processing=True) + + +class SimpleToolPlugin(Plugin): + """Minimal plugin with tool hooks for testing.""" + + async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: + """Passthrough tool pre-invoke hook.""" + return ToolPreInvokeResult(continue_processing=True) + + async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: + """Passthrough tool post-invoke hook.""" + return ToolPostInvokeResult(continue_processing=True) diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py index 288275e8f..1d675a70f 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py @@ -14,7 +14,7 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework import ( GlobalContext, PluginContext, diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py index 6c960ce51..5c6267ebf 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py @@ -17,11 +17,14 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, ResourceContent, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, ResourceContent, Role, TextContent from mcpgateway.plugins.framework import ( ConfigLoader, GlobalContext, PluginContext, + PromptHookType, + ResourceHookType, + ToolHookType, PromptPosthookPayload, PromptPrehookPayload, ResourcePostFetchPayload, @@ -121,35 +124,35 @@ async def test_hook_methods_empty_content(): # Test prompt_pre_fetch with empty content - should raise PluginError payload = PromptPrehookPayload(prompt_id="1", args={}) with pytest.raises(PluginError): - await plugin.prompt_pre_fetch(payload, context) + await plugin.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, context) # Test prompt_post_fetch with empty content - should raise PluginError message = Message(content=TextContent(type="text", text="test"), role=Role.USER) prompt_result = PromptResult(messages=[message]) payload = PromptPosthookPayload(prompt_id="1", result=prompt_result) with pytest.raises(PluginError): - await plugin.prompt_post_fetch(payload, context) + await plugin.invoke_hook(PromptHookType.PROMPT_POST_FETCH, payload, context) # Test tool_pre_invoke with empty content - should raise PluginError payload = ToolPreInvokePayload(name="test", args={}) with pytest.raises(PluginError): - await plugin.tool_pre_invoke(payload, context) + await plugin.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, payload, context) # Test tool_post_invoke with empty content - should raise PluginError payload = ToolPostInvokePayload(name="test", result={}) with pytest.raises(PluginError): - await plugin.tool_post_invoke(payload, context) + await plugin.invoke_hook(ToolHookType.TOOL_POST_INVOKE, payload, context) # Test resource_pre_fetch with empty content - should raise PluginError payload = ResourcePreFetchPayload(uri="file://test.txt") with pytest.raises(PluginError): - await plugin.resource_pre_fetch(payload, context) + await plugin.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, payload, context) # Test resource_post_fetch with empty content - should raise PluginError resource_content = ResourceContent(type="resource", id="123",uri="file://test.txt", text="content") payload = ResourcePostFetchPayload(uri="file://test.txt", content=resource_content) with pytest.raises(PluginError): - await plugin.resource_post_fetch(payload, context) + await plugin.invoke_hook(ResourceHookType.RESOURCE_POST_FETCH, payload, context) await plugin.shutdown() diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py index 53f5f8e2b..5b3ea2538 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py @@ -20,7 +20,7 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, ResourceContent, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, ResourceContent, Role, TextContent from mcpgateway.plugins.framework import ( ConfigLoader, GlobalContext, @@ -29,6 +29,9 @@ PluginContext, PluginLoader, PluginManager, + PromptHookType, + ResourceHookType, + ToolHookType, PromptPosthookPayload, PromptPrehookPayload, ResourcePostFetchPayload, @@ -48,7 +51,7 @@ async def test_client_load_stdio(): loader = PluginLoader() plugin = await loader.load_and_instantiate_plugin(config.plugins[0]) prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"text": "That was innovative!"}) - result = await plugin.prompt_pre_fetch(prompt, PluginContext(global_context=GlobalContext(request_id="1", server_id="2"))) + result = await plugin.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, PluginContext(global_context=GlobalContext(request_id="1", server_id="2"))) assert result.violation assert result.violation.reason == "Prompt not allowed" assert result.violation.description == "A deny word was found in the prompt" @@ -72,7 +75,7 @@ async def test_client_load_stdio_overrides(): loader = PluginLoader() plugin = await loader.load_and_instantiate_plugin(config.plugins[0]) prompt = PromptPrehookPayload(prompt_id="test_prompt", args = {"text": "That was innovative!"}) - result = await plugin.prompt_pre_fetch(prompt, PluginContext(global_context=GlobalContext(request_id="1", server_id="2"))) + result = await plugin.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, PluginContext(global_context=GlobalContext(request_id="1", server_id="2"))) assert result.violation assert result.violation.reason == "Prompt not allowed" assert result.violation.description == "A deny word was found in the prompt" @@ -98,7 +101,7 @@ async def test_client_load_stdio_post_prompt(): plugin = await loader.load_and_instantiate_plugin(config.plugins[0]) prompt = PromptPrehookPayload(prompt_id="test_prompt", args = {"user": "What a crapshow!"}) context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) - result = await plugin.prompt_pre_fetch(prompt, context) + result = await plugin.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, context) assert result.modified_payload.args["user"] == "What a yikesshow!" config = plugin.config assert config.name == "ReplaceBadWordsPlugin" @@ -111,7 +114,7 @@ async def test_client_load_stdio_post_prompt(): payload_result = PromptPosthookPayload(prompt_id="test_prompt", result=prompt_result) - result = await plugin.prompt_post_fetch(payload_result, context=context) + result = await plugin.invoke_hook(PromptHookType.PROMPT_POST_FETCH, payload_result, context=context) assert len(result.modified_payload.result.messages) == 1 assert result.modified_payload.result.messages[0].content.text == "What the yikes?" await plugin.shutdown() @@ -185,7 +188,7 @@ async def test_hooks(): await plugin_manager.initialize() payload = PromptPrehookPayload(prompt_id="test_prompt", name="test_prompt", args={"arg0": "This is a crap argument"}) global_context = GlobalContext(request_id="1") - result, _ = await plugin_manager.prompt_pre_fetch(payload, global_context) + result, _ = await plugin_manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing """Test prompt post hook across all registered plugins.""" @@ -193,31 +196,31 @@ async def test_hooks(): message = Message(content=TextContent(type="text", text="prompt"), role=Role.USER) prompt_result = PromptResult(messages=[message]) payload = PromptPosthookPayload(prompt_id="test_prompt", result=prompt_result) - result, _ = await plugin_manager.prompt_post_fetch(payload, global_context) + result, _ = await plugin_manager.invoke_hook(PromptHookType.PROMPT_POST_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing """Test tool pre hook across all registered plugins.""" # Customize payload for testing payload = ToolPreInvokePayload(name="test_prompt", args={"arg0": "This is an argument"}) - result, _ = await plugin_manager.tool_pre_invoke(payload, global_context) + result, _ = await plugin_manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, payload, global_context) # Assert expected behaviors assert result.continue_processing """Test tool post hook across all registered plugins.""" # Customize payload for testing payload = ToolPostInvokePayload(name="test_tool", result={"output0": "output value"}) - result, _ = await plugin_manager.tool_post_invoke(payload, global_context) + result, _ = await plugin_manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, payload, global_context) # Assert expected behaviors assert result.continue_processing payload = ResourcePreFetchPayload(uri="file:///data.txt") - result, _ = await plugin_manager.resource_pre_fetch(payload, global_context) + result, _ = await plugin_manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing content = ResourceContent(type="resource", id="123", uri="file:///data.txt", text="Hello World") payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) - result, _ = await plugin_manager.resource_post_fetch(payload, global_context) + result, _ = await plugin_manager.invoke_hook(ResourceHookType.RESOURCE_POST_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing await plugin_manager.shutdown() @@ -233,7 +236,7 @@ async def test_errors(): global_context = GlobalContext(request_id="1") escaped_regex = re.escape("ValueError('Sadly! Prompt prefetch is broken!')") with pytest.raises(PluginError, match=escaped_regex): - await plugin_manager.prompt_pre_fetch(payload, global_context) + await plugin_manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) await plugin_manager.shutdown() @@ -250,7 +253,7 @@ async def test_shared_context_across_pre_post_hooks_multi_plugins(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) assert len(contexts) == 2 ctxs = [contexts[key] for key in contexts.keys()] @@ -279,7 +282,7 @@ async def test_shared_context_across_pre_post_hooks_multi_plugins(): assert result.modified_payload is None # Test tool post-invoke with transformation tool_result_payload = ToolPostInvokePayload(name="test_tool", result={"output": "Result was bad", "status": "wrong format"}) - result, contexts = await manager.tool_post_invoke(tool_result_payload, global_context=global_context, local_contexts=contexts) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) ctxs = [contexts[key] for key in contexts.keys()] assert len(ctxs) == 2 diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py index 72fdf82f6..05dcbfbd4 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py @@ -17,7 +17,7 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework import ConfigLoader, GlobalContext, PluginContext, PluginLoader, PromptPosthookPayload, PromptPrehookPayload diff --git a/tests/unit/mcpgateway/plugins/framework/hooks/test_hook_patterns.py b/tests/unit/mcpgateway/plugins/framework/hooks/test_hook_patterns.py new file mode 100644 index 000000000..11291fdae --- /dev/null +++ b/tests/unit/mcpgateway/plugins/framework/hooks/test_hook_patterns.py @@ -0,0 +1,312 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/plugins/framework/hooks/test_hook_patterns.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests demonstrating three hook patterns in the plugin framework: +1. Convention-based: method name matches hook type +2. Decorator-based: @hook decorator with custom method name +3. Custom hook: @hook decorator with new hook type + payload/result types +""" + +# Third-Party +import pytest + +# First-Party +from mcpgateway.plugins.framework import ( + Plugin, + PluginContext, + GlobalContext, + PluginManager, + PluginPayload, + PluginResult, + ToolHookType, + ToolPreInvokePayload, + ToolPreInvokeResult, + ToolPostInvokePayload, + ToolPostInvokeResult, +) +from mcpgateway.plugins.framework.decorator import hook + + +# ========== Custom Hook Definition ========== +class EmailPayload(PluginPayload): + """Payload for email hook.""" + + recipient: str + subject: str + body: str + + +class EmailResult(PluginResult[EmailPayload]): + """Result for email hook.""" + + pass + + +# ========== Demo Plugin with All Three Patterns ========== +class DemoPlugin(Plugin): + """Demo plugin showing all three hook patterns.""" + + # Pattern 1: Convention-based (method name matches hook type) + async def tool_pre_invoke( + self, payload: ToolPreInvokePayload, context: PluginContext + ) -> ToolPreInvokeResult: + """Pattern 1: Convention-based hook. + + This method is found automatically because its name matches + the hook type 'tool_pre_invoke'. + """ + # Modify the payload + modified_payload = ToolPreInvokePayload( + name=payload.name, + args={**payload.args, "pattern": "convention"}, + headers=payload.headers, + ) + + return ToolPreInvokeResult( + modified_payload=modified_payload, + metadata={"pattern": "convention", "hook": "tool_pre_invoke"} + ) + + # Pattern 2: Decorator-based with custom method name + @hook(ToolHookType.TOOL_POST_INVOKE) + async def my_custom_tool_post_handler( + self, payload: ToolPostInvokePayload, context: PluginContext + ) -> ToolPostInvokeResult: + """Pattern 2: Decorator-based hook with custom method name. + + This method is found via the @hook decorator even though + the method name doesn't match the hook type. + """ + # Modify the result + modified_result = {**payload.result, "pattern": "decorator"} if isinstance(payload.result, dict) else payload.result + + modified_payload = ToolPostInvokePayload( + name=payload.name, + result=modified_result, + ) + + return ToolPostInvokeResult( + modified_payload=modified_payload, + metadata={"pattern": "decorator", "hook": "tool_post_invoke"} + ) + + # Pattern 3: Custom hook with payload and result types + @hook("email_pre_send", EmailPayload, EmailResult) + async def validate_email( + self, payload: EmailPayload, context: PluginContext + ) -> EmailResult: + """Pattern 3: Custom hook with new hook type. + + This registers a completely new hook type 'email_pre_send' + with its own payload and result types. + """ + # Validate email + if "@" not in payload.recipient: + modified_payload = EmailPayload( + recipient=f"{payload.recipient}@example.com", + subject=payload.subject, + body=payload.body, + ) + return EmailResult( + modified_payload=modified_payload, + metadata={"pattern": "custom", "hook": "email_pre_send", "fixed_email": True} + ) + + return EmailResult( + continue_processing=True, + metadata={"pattern": "custom", "hook": "email_pre_send"} + ) + + +# ========== Pytest Tests ========== +@pytest.mark.asyncio +async def test_pattern_1_convention_based_hook(): + """Test Pattern 1: Convention-based hook (method name matches hook type).""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml") + await manager.initialize() + + # Create payload for tool_pre_invoke + payload = ToolPreInvokePayload( + name="my_calculator", + args={"operation": "add", "a": 5, "b": 3} + ) + + global_context = GlobalContext(request_id="test-1") + + # Invoke the hook + result, contexts = await manager.invoke_hook( + ToolHookType.TOOL_PRE_INVOKE, + payload, + global_context=global_context + ) + + # Assertions + assert result is not None + assert result.continue_processing is True + assert result.modified_payload is not None + assert result.modified_payload.name == "my_calculator" + assert result.modified_payload.args["operation"] == "add" + assert result.modified_payload.args["a"] == 5 + assert result.modified_payload.args["b"] == 3 + assert result.modified_payload.args["pattern"] == "convention" # Added by hook + assert result.metadata is not None + assert result.metadata["pattern"] == "convention" + assert result.metadata["hook"] == "tool_pre_invoke" + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_pattern_2_decorator_based_hook(): + """Test Pattern 2: Decorator-based hook with custom method name.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml") + await manager.initialize() + + # Create payload for tool_post_invoke + payload = ToolPostInvokePayload( + name="my_calculator", + result={"sum": 8, "status": "success"} + ) + + global_context = GlobalContext(request_id="test-2") + + # Invoke the hook + result, contexts = await manager.invoke_hook( + ToolHookType.TOOL_POST_INVOKE, + payload, + global_context=global_context + ) + + # Assertions + assert result is not None + assert result.continue_processing is True + assert result.modified_payload is not None + assert result.modified_payload.name == "my_calculator" + assert result.modified_payload.result["sum"] == 8 + assert result.modified_payload.result["status"] == "success" + assert result.modified_payload.result["pattern"] == "decorator" # Added by hook + assert result.metadata is not None + assert result.metadata["pattern"] == "decorator" + assert result.metadata["hook"] == "tool_post_invoke" + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_pattern_3_custom_hook_valid_email(): + """Test Pattern 3: Custom hook with new hook type (valid email).""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml") + await manager.initialize() + + # Test with valid email + payload = EmailPayload( + recipient="user@example.com", + subject="Test Email", + body="This is a test." + ) + + global_context = GlobalContext(request_id="test-3a") + + result, contexts = await manager.invoke_hook( + "email_pre_send", + payload, + global_context=global_context + ) + + # Assertions + assert result is not None + assert result.continue_processing is True + assert result.modified_payload is None # No modification needed for valid email + assert result.metadata is not None + assert result.metadata["pattern"] == "custom" + assert result.metadata["hook"] == "email_pre_send" + assert "fixed_email" not in result.metadata # Email was already valid + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_pattern_3_custom_hook_invalid_email(): + """Test Pattern 3: Custom hook with new hook type (invalid email gets fixed).""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml") + await manager.initialize() + + # Test with invalid email (missing @) + payload = EmailPayload( + recipient="invalid-email", + subject="Test Email 2", + body="This email address needs fixing." + ) + + global_context = GlobalContext(request_id="test-3b") + + result, contexts = await manager.invoke_hook( + "email_pre_send", + payload, + global_context=global_context + ) + + # Assertions + assert result is not None + assert result.continue_processing is True + assert result.modified_payload is not None + assert result.modified_payload.recipient == "invalid-email@example.com" # Fixed by hook + assert result.modified_payload.subject == "Test Email 2" + assert result.modified_payload.body == "This email address needs fixing." + assert result.metadata is not None + assert result.metadata["pattern"] == "custom" + assert result.metadata["hook"] == "email_pre_send" + assert result.metadata["fixed_email"] is True # Hook fixed the email + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_all_three_patterns_in_sequence(): + """Test all three patterns work together in the same plugin manager.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml") + await manager.initialize() + + global_context = GlobalContext(request_id="test-all") + + # Test Pattern 1: Convention-based + payload1 = ToolPreInvokePayload( + name="test_tool", + args={"param": "value"} + ) + result1, _ = await manager.invoke_hook( + ToolHookType.TOOL_PRE_INVOKE, + payload1, + global_context=global_context + ) + assert result1.modified_payload.args["pattern"] == "convention" + + # Test Pattern 2: Decorator-based + payload2 = ToolPostInvokePayload( + name="test_tool", + result={"data": "output"} + ) + result2, _ = await manager.invoke_hook( + ToolHookType.TOOL_POST_INVOKE, + payload2, + global_context=global_context + ) + assert result2.modified_payload.result["pattern"] == "decorator" + + # Test Pattern 3: Custom hook + payload3 = EmailPayload( + recipient="test", + subject="Test", + body="Test" + ) + result3, _ = await manager.invoke_hook( + "email_pre_send", + payload3, + global_context=global_context + ) + assert result3.modified_payload.recipient == "test@example.com" + + await manager.shutdown() diff --git a/tests/unit/mcpgateway/plugins/framework/hooks/test_hook_registry.py b/tests/unit/mcpgateway/plugins/framework/hooks/test_hook_registry.py new file mode 100644 index 000000000..c54a05770 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/framework/hooks/test_hook_registry.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 © IBM Corporation +SPDX-License-Identifier: Apache-2.0 + +Test suite for hook registry functionality. +""" + +# Third-Party +import pytest + +# First-Party +from mcpgateway.plugins.framework import ( + get_hook_registry, + AgentHookType, + PromptHookType, + ResourceHookType, + ToolHookType, + PromptPrehookPayload, + PromptPrehookResult, + PromptPosthookPayload, + PromptPosthookResult, + ToolPreInvokePayload, + ToolPreInvokeResult, +) + + +class TestHookRegistry: + """Test cases for the HookRegistry class.""" + + @pytest.fixture + def registry(self): + """Provide a hook registry instance.""" + return get_hook_registry() + + def test_mcp_hooks_are_registered(self, registry): + """Test that all MCP hooks are registered.""" + assert registry.is_registered(PromptHookType.PROMPT_PRE_FETCH) + assert registry.is_registered(PromptHookType.PROMPT_POST_FETCH) + assert registry.is_registered(ToolHookType.TOOL_PRE_INVOKE) + assert registry.is_registered(ToolHookType.TOOL_POST_INVOKE) + assert registry.is_registered(ResourceHookType.RESOURCE_PRE_FETCH) + assert registry.is_registered(ResourceHookType.RESOURCE_POST_FETCH) + + def test_get_payload_type(self, registry): + """Test retrieving payload types from registry.""" + payload_type = registry.get_payload_type(PromptHookType.PROMPT_PRE_FETCH) + assert payload_type == PromptPrehookPayload + + payload_type = registry.get_payload_type(PromptHookType.PROMPT_POST_FETCH) + assert payload_type == PromptPosthookPayload + + payload_type = registry.get_payload_type(ToolHookType.TOOL_PRE_INVOKE) + assert payload_type == ToolPreInvokePayload + + def test_get_result_type(self, registry): + """Test retrieving result types from registry.""" + result_type = registry.get_result_type(PromptHookType.PROMPT_PRE_FETCH) + assert result_type == PromptPrehookResult + + result_type = registry.get_result_type(PromptHookType.PROMPT_POST_FETCH) + assert result_type == PromptPosthookResult + + result_type = registry.get_result_type(ToolHookType.TOOL_PRE_INVOKE) + assert result_type == ToolPreInvokeResult + + def test_get_unregistered_hook_returns_none(self, registry): + """Test that unregistered hooks return None.""" + assert registry.get_payload_type("unknown_hook") is None + assert registry.get_result_type("unknown_hook") is None + assert not registry.is_registered("unknown_hook") + + def test_json_to_payload_with_dict(self, registry): + """Test converting dictionary to payload.""" + payload_dict = {"prompt_id": "test", "args": {"key": "value"}} + payload = registry.json_to_payload(PromptHookType.PROMPT_PRE_FETCH, payload_dict) + + assert isinstance(payload, PromptPrehookPayload) + assert payload.prompt_id == "test" + assert payload.args["key"] == "value" + + def test_json_to_payload_with_json_string(self, registry): + """Test converting JSON string to payload.""" + payload_json = '{"prompt_id": "test", "args": {"key": "value"}}' + payload = registry.json_to_payload(PromptHookType.PROMPT_PRE_FETCH, payload_json) + + assert isinstance(payload, PromptPrehookPayload) + assert payload.prompt_id == "test" + assert payload.args["key"] == "value" + + def test_json_to_result_with_dict(self, registry): + """Test converting dictionary to result.""" + result_dict = {"continue_processing": True, "modified_payload": None} + result = registry.json_to_result(PromptHookType.PROMPT_PRE_FETCH, result_dict) + + assert isinstance(result, PromptPrehookResult) + assert result.continue_processing is True + + def test_json_to_result_with_json_string(self, registry): + """Test converting JSON string to result.""" + result_json = '{"continue_processing": false, "modified_payload": null}' + result = registry.json_to_result(PromptHookType.PROMPT_PRE_FETCH, result_json) + + assert isinstance(result, PromptPrehookResult) + assert result.continue_processing is False + + def test_json_to_payload_unregistered_hook_raises_error(self, registry): + """Test that converting payload for unregistered hook raises ValueError.""" + with pytest.raises(ValueError, match="No payload type registered for hook"): + registry.json_to_payload("unknown_hook", {}) + + def test_json_to_result_unregistered_hook_raises_error(self, registry): + """Test that converting result for unregistered hook raises ValueError.""" + with pytest.raises(ValueError, match="No result type registered for hook"): + registry.json_to_result("unknown_hook", {}) + + def test_get_registered_hooks(self, registry): + """Test retrieving all registered hook types.""" + hooks = registry.get_registered_hooks() + + assert isinstance(hooks, list) + assert len(hooks) >= 8 # At least the 6 MCP hooks + assert PromptHookType.PROMPT_PRE_FETCH in hooks + assert PromptHookType.PROMPT_POST_FETCH in hooks + assert ToolHookType.TOOL_PRE_INVOKE in hooks + assert ToolHookType.TOOL_POST_INVOKE in hooks + assert ResourceHookType.RESOURCE_PRE_FETCH in hooks + assert ResourceHookType.RESOURCE_POST_FETCH in hooks + assert AgentHookType.AGENT_POST_INVOKE in hooks + assert AgentHookType.AGENT_PRE_INVOKE in hooks + + def test_registry_is_singleton(self): + """Test that get_hook_registry returns the same instance.""" + registry1 = get_hook_registry() + registry2 = get_hook_registry() + + assert registry1 is registry2 diff --git a/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py b/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py index 9c7f15174..a0d54bf40 100644 --- a/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py +++ b/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py @@ -14,13 +14,11 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader -from mcpgateway.plugins.framework.models import GlobalContext, PluginContext, PluginMode, PromptPosthookPayload, PromptPrehookPayload +from mcpgateway.plugins.framework import GlobalContext, PluginContext, PluginMode, PromptPosthookPayload, PromptPrehookPayload from plugins.regex_filter.search_replace import SearchReplaceConfig, SearchReplacePlugin -from unittest.mock import patch - def test_config_loader_load(): """pytest for testing the config loader.""" diff --git a/tests/unit/mcpgateway/plugins/framework/test_context.py b/tests/unit/mcpgateway/plugins/framework/test_context.py index f84a94fde..74983f325 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_context.py +++ b/tests/unit/mcpgateway/plugins/framework/test_context.py @@ -11,6 +11,7 @@ from mcpgateway.plugins.framework import ( GlobalContext, PluginManager, + ToolHookType, ToolPreInvokePayload, ToolPostInvokePayload, ) @@ -25,7 +26,7 @@ async def test_shared_context_across_pre_post_hooks(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) assert len(contexts) == 1 context = next(iter(contexts.values())) @@ -42,7 +43,7 @@ async def test_shared_context_across_pre_post_hooks(): # Test tool post-invoke with transformation tool_result_payload = ToolPostInvokePayload(name="test_tool", result={"output": "Result was bad", "status": "wrong format"}) - result, contexts = await manager.tool_post_invoke(tool_result_payload, global_context=global_context, local_contexts=contexts) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) assert len(contexts) == 1 context = next(iter(contexts.values())) @@ -71,7 +72,7 @@ async def test_shared_context_across_pre_post_hooks_multi_plugins(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) assert len(contexts) == 2 ctxs = [contexts[key] for key in contexts.keys()] @@ -100,7 +101,7 @@ async def test_shared_context_across_pre_post_hooks_multi_plugins(): assert result.modified_payload is None # Test tool post-invoke with transformation tool_result_payload = ToolPostInvokePayload(name="test_tool", result={"output": "Result was bad", "status": "wrong format"}) - result, contexts = await manager.tool_post_invoke(tool_result_payload, global_context=global_context, local_contexts=contexts) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) ctxs = [contexts[key] for key in contexts.keys()] assert len(ctxs) == 2 diff --git a/tests/unit/mcpgateway/plugins/framework/test_errors.py b/tests/unit/mcpgateway/plugins/framework/test_errors.py index 9dccc1706..738113453 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_errors.py +++ b/tests/unit/mcpgateway/plugins/framework/test_errors.py @@ -16,7 +16,8 @@ PluginError, PluginMode, PluginManager, - PromptPrehookPayload, + PromptHookType, + PromptPrehookPayload ) @@ -40,7 +41,7 @@ async def test_error_plugin(): global_context = GlobalContext(request_id="1") escaped_regex = re.escape("ValueError('Sadly! Prompt prefetch is broken!')") with pytest.raises(PluginError, match=escaped_regex): - await plugin_manager.prompt_pre_fetch(payload, global_context) + await plugin_manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) await plugin_manager.shutdown() @@ -51,14 +52,14 @@ async def test_error_plugin_raise_error_false(): payload = PromptPrehookPayload(prompt_id="test_prompt", args={"arg0": "This is a crap argument"}) global_context = GlobalContext(request_id="1") with pytest.raises(PluginError): - result, _ = await plugin_manager.prompt_pre_fetch(payload, global_context) + result, _ = await plugin_manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) # assert result.continue_processing # assert not result.modified_payload await plugin_manager.shutdown() plugin_manager.config.plugins[0].mode = PluginMode.ENFORCE_IGNORE_ERROR await plugin_manager.initialize() - result, _ = await plugin_manager.prompt_pre_fetch(payload, global_context) + result, _ = await plugin_manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) assert result.continue_processing assert not result.modified_payload await plugin_manager.shutdown() diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager.py b/tests/unit/mcpgateway/plugins/framework/test_manager.py index 7c58772c1..87144d266 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_manager.py +++ b/tests/unit/mcpgateway/plugins/framework/test_manager.py @@ -11,8 +11,9 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent -from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, PluginManager, PluginViolationError, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload +from mcpgateway.common.models import Message, PromptResult, Role, TextContent +from mcpgateway.plugins.framework import GlobalContext, PluginManager, PluginViolationError +from mcpgateway.plugins.framework import PromptHookType, ToolHookType, HttpHeaderPayload, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload from plugins.regex_filter.search_replace import SearchReplaceConfig @@ -34,7 +35,7 @@ async def test_manager_single_transformer_prompt_plugin(): assert srconfig.words[0].replace == "crud" prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"user": "What a crapshow!"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, contexts = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert len(result.modified_payload.args) == 1 assert result.modified_payload.args["user"] == "What a yikesshow!" @@ -44,7 +45,7 @@ async def test_manager_single_transformer_prompt_plugin(): payload_result = PromptPosthookPayload(prompt_id="test_prompt", result=prompt_result) - result, _ = await manager.prompt_post_fetch(payload_result, global_context=global_context, local_contexts=contexts) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_POST_FETCH, payload_result, global_context=global_context, local_contexts=contexts) assert len(result.modified_payload.result.messages) == 1 assert result.modified_payload.result.messages[0].content.text == "What a yikesshow!" await manager.shutdown() @@ -82,7 +83,7 @@ async def test_manager_multiple_transformer_preprompt_plugin(): prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"user": "It's always happy at the crapshow."}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, contexts = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert len(result.modified_payload.args) == 1 assert result.modified_payload.args["user"] == "It's always gleeful at the yikesshow." @@ -92,7 +93,7 @@ async def test_manager_multiple_transformer_preprompt_plugin(): payload_result = PromptPosthookPayload(prompt_id="test_prompt", result=prompt_result) - result, _ = await manager.prompt_post_fetch(payload_result, global_context=global_context, local_contexts=contexts) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_POST_FETCH, payload_result, global_context=global_context, local_contexts=contexts) assert len(result.modified_payload.result.messages) == 1 assert result.modified_payload.result.messages[0].content.text == "It's sullen at the yikes bakery." await manager.shutdown() @@ -105,7 +106,7 @@ async def test_manager_no_plugins(): assert manager.initialized prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"user": "It's always happy at the crapshow."}) global_context = GlobalContext(request_id="1", server_id="2") - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert result.continue_processing assert not result.modified_payload await manager.shutdown() @@ -118,12 +119,12 @@ async def test_manager_filter_plugins(): assert manager.initialized prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"user": "innovative"}) global_context = GlobalContext(request_id="1", server_id="2") - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert not result.continue_processing assert result.violation with pytest.raises(PluginViolationError) as ve: - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context, violations_as_exceptions=True) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context, violations_as_exceptions=True) assert ve.value.violation assert ve.value.violation.reason == "Prompt not allowed" await manager.shutdown() @@ -136,11 +137,11 @@ async def test_manager_multi_filter_plugins(): assert manager.initialized prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"user": "innovative crapshow."}) global_context = GlobalContext(request_id="1", server_id="2") - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert not result.continue_processing assert result.violation with pytest.raises(PluginViolationError) as ve: - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context, violations_as_exceptions=True) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context, violations_as_exceptions=True) assert ve.value.violation await manager.shutdown() @@ -155,7 +156,7 @@ async def test_manager_tool_hooks_empty(): # Test tool pre-invoke with no plugins tool_payload = ToolPreInvokePayload(name="calculator", args={"operation": "add", "a": 5, "b": 3}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) # Should continue processing with no modifications assert result.continue_processing @@ -165,7 +166,7 @@ async def test_manager_tool_hooks_empty(): # Test tool post-invoke with no plugins tool_result_payload = ToolPostInvokePayload(name="calculator", result={"result": 8, "status": "success"}) - result, contexts = await manager.tool_post_invoke(tool_result_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context) # Should continue processing with no modifications assert result.continue_processing @@ -186,7 +187,7 @@ async def test_manager_tool_hooks_with_transformer_plugin(): # Test tool pre-invoke - no plugins configured for tool hooks tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is crap data"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) # Should continue processing with no modifications (no plugins for tool hooks) assert result.continue_processing @@ -196,7 +197,7 @@ async def test_manager_tool_hooks_with_transformer_plugin(): # Test tool post-invoke - no plugins configured for tool hooks tool_result_payload = ToolPostInvokePayload(name="test_tool", result={"output": "Result with crap in it"}) - result, _ = await manager.tool_post_invoke(tool_result_payload, global_context=global_context, local_contexts=contexts) + result, _ = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) # Should continue processing with no modifications (no plugins for tool hooks) assert result.continue_processing @@ -216,7 +217,7 @@ async def test_manager_tool_hooks_with_actual_plugin(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) # Should continue processing with transformations applied assert result.continue_processing @@ -228,7 +229,7 @@ async def test_manager_tool_hooks_with_actual_plugin(): # Test tool post-invoke with transformation tool_result_payload = ToolPostInvokePayload(name="test_tool", result={"output": "Result was bad", "status": "wrong format"}) - result, _ = await manager.tool_post_invoke(tool_result_payload, global_context=global_context, local_contexts=contexts) + result, _ = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) # Should continue processing with transformations applied assert result.continue_processing @@ -251,7 +252,7 @@ async def test_manager_tool_hooks_with_header_mods(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}, headers=None) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) # Should continue processing with transformations applied assert result.continue_processing @@ -267,7 +268,7 @@ async def test_manager_tool_hooks_with_header_mods(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}, headers=HttpHeaderPayload({"Content-Type": "application/json"})) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) # Should continue processing with transformations applied assert result.continue_processing diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py b/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py index e8e1d8968..0a7bd317f 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py +++ b/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py @@ -16,12 +16,11 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent -from mcpgateway.plugins.framework.base import Plugin +from mcpgateway.common.models import Message, PromptResult, Role, TextContent +from mcpgateway.plugins.framework.base import HookRef, Plugin from mcpgateway.plugins.framework.models import Config from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginCondition, PluginConfig, PluginContext, @@ -31,6 +30,9 @@ PluginResult, PluginViolation, PluginViolationError, + PromptHookType, + ToolHookType, + Plugin, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, @@ -52,7 +54,7 @@ async def prompt_pre_fetch(self, payload, context): # Test with enforce mode manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") await manager.initialize() - manager._pre_prompt_executor.timeout = 0.01 # Set very short timeout + manager._executor.timeout = 0.01 # Set very short timeout # Mock plugin registry plugin_config = PluginConfig( @@ -60,16 +62,16 @@ async def prompt_pre_fetch(self, payload, context): ) timeout_plugin = TimeoutPlugin(plugin_config) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(timeout_plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(timeout_plugin)) + mock_get.return_value = [hook_ref] prompt = PromptPrehookPayload(prompt_id="test", args={}) global_context = GlobalContext(request_id="1") escaped_regex = re.escape("Plugin TimeoutPlugin exceeded 0.01s timeout") with pytest.raises(PluginError, match=escaped_regex): - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should pass since fail_on_plugin_error: false # assert result.continue_processing @@ -79,11 +81,11 @@ async def prompt_pre_fetch(self, payload, context): # Test with permissive mode plugin_config.mode = PluginMode.PERMISSIVE - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(timeout_plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(timeout_plugin)) + mock_get.return_value = [hook_ref] - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in permissive mode assert result.continue_processing @@ -110,16 +112,16 @@ async def prompt_pre_fetch(self, payload, context): error_plugin = ErrorPlugin(plugin_config) # Test with enforce mode - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(error_plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + mock_get.return_value = [hook_ref] prompt = PromptPrehookPayload(prompt_id="test", args={}) global_context = GlobalContext(request_id="1") escaped_regex = re.escape("RuntimeError('Plugin error!')") with pytest.raises(PluginError, match=escaped_regex): - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should block in enforce mode # assert result.continue_processing @@ -129,44 +131,44 @@ async def prompt_pre_fetch(self, payload, context): # Test with permissive mode plugin_config.mode = PluginMode.PERMISSIVE - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(error_plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + mock_get.return_value = [hook_ref] - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in permissive mode assert result.continue_processing assert result.violation is None plugin_config.mode = PluginMode.ENFORCE_IGNORE_ERROR - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(error_plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + mock_get.return_value = [hook_ref] - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in enforce_ignore_error mode assert result.continue_processing assert result.violation is None plugin_config.mode = PluginMode.ENFORCE_IGNORE_ERROR - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(error_plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + mock_get.return_value = [hook_ref] - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in enforce_ignore_error mode assert result.continue_processing assert result.violation is None plugin_config.mode = PluginMode.ENFORCE_IGNORE_ERROR - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(error_plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + mock_get.return_value = [hook_ref] - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in enforce_ignore_error mode assert result.continue_processing @@ -177,17 +179,23 @@ async def prompt_pre_fetch(self, payload, context): @pytest.mark.asyncio async def test_manager_condition_filtering(): - """Test that plugins are filtered based on conditions.""" + """Test that plugins are filtered based on conditions across all hook types.""" + from mcpgateway.plugins.framework import ( + ResourceHookType, + ResourcePreFetchPayload, + AgentHookType, + AgentPreInvokePayload, + ) + + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") + await manager.initialize() + # ========== Test 1: Server ID condition (GlobalContext) ========== class ConditionalPlugin(Plugin): async def prompt_pre_fetch(self, payload, context): payload.args["modified"] = "yes" return PluginResult(continue_processing=True, modified_payload=payload) - manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") - await manager.initialize() - - # Plugin with server_id condition plugin_config = PluginConfig( name="ConditionalPlugin", description="Test conditional plugin", @@ -201,15 +209,16 @@ async def prompt_pre_fetch(self, payload, context): ) plugin = ConditionalPlugin(plugin_config) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: plugin_ref = PluginRef(plugin) - mock_get.return_value = [plugin_ref] + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, plugin_ref) + mock_get.return_value = [hook_ref] prompt = PromptPrehookPayload(prompt_id="test", args={}) # Test with matching server_id global_context = GlobalContext(request_id="1", server_id="server1") - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Plugin should execute assert result.continue_processing @@ -219,12 +228,215 @@ async def prompt_pre_fetch(self, payload, context): # Test with non-matching server_id prompt2 = PromptPrehookPayload(prompt_id="test", args={}) global_context2 = GlobalContext(request_id="2", server_id="server2") - result2, _ = await manager.prompt_pre_fetch(prompt2, global_context=global_context2) + result2, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt2, global_context=global_context2) # Plugin should be skipped assert result2.continue_processing assert result2.modified_payload is None # No modification + # ========== Test 2: Prompt-specific filtering ========== + class PromptFilterPlugin(Plugin): + async def prompt_pre_fetch(self, payload, context): + payload.args["prompt_filtered"] = "yes" + return PluginResult(continue_processing=True, modified_payload=payload) + + prompt_plugin_config = PluginConfig( + name="PromptFilterPlugin", + description="Test prompt filtering", + author="Test", + version="1.0", + tags=["test"], + kind="PromptFilterPlugin", + hooks=["prompt_pre_fetch"], + config={}, + conditions=[PluginCondition(prompts={"greeting", "welcome"})], + ) + prompt_plugin = PromptFilterPlugin(prompt_plugin_config) + + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(prompt_plugin)) + mock_get.return_value = [hook_ref] + + # Test with matching prompt + prompt_match = PromptPrehookPayload(prompt_id="greeting", args={}) + global_context = GlobalContext(request_id="3") + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt_match, global_context=global_context) + + assert result.continue_processing + assert result.modified_payload is not None + assert result.modified_payload.args.get("prompt_filtered") == "yes" + + # Test with non-matching prompt + prompt_no_match = PromptPrehookPayload(prompt_id="other", args={}) + result2, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt_no_match, global_context=global_context) + + assert result2.continue_processing + assert result2.modified_payload is None # Plugin skipped + + # ========== Test 3: Tool filtering ========== + class ToolFilterPlugin(Plugin): + async def tool_pre_invoke(self, payload, context): + payload.args["tool_filtered"] = "yes" + return PluginResult(continue_processing=True, modified_payload=payload) + + tool_plugin_config = PluginConfig( + name="ToolFilterPlugin", + description="Test tool filtering", + author="Test", + version="1.0", + tags=["test"], + kind="ToolFilterPlugin", + hooks=["tool_pre_invoke"], + config={}, + conditions=[PluginCondition(tools={"calculator", "converter"})], + ) + tool_plugin = ToolFilterPlugin(tool_plugin_config) + + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(ToolHookType.TOOL_PRE_INVOKE, PluginRef(tool_plugin)) + mock_get.return_value = [hook_ref] + + # Test with matching tool + tool_match = ToolPreInvokePayload(name="calculator", args={}) + global_context = GlobalContext(request_id="4") + result, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_match, global_context=global_context) + + assert result.continue_processing + assert result.modified_payload is not None + assert result.modified_payload.args.get("tool_filtered") == "yes" + + # Test with non-matching tool + tool_no_match = ToolPreInvokePayload(name="other_tool", args={}) + result2, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_no_match, global_context=global_context) + + assert result2.continue_processing + assert result2.modified_payload is None # Plugin skipped + + # ========== Test 4: Resource filtering ========== + class ResourceFilterPlugin(Plugin): + async def resource_pre_fetch(self, payload, context): + payload.metadata["resource_filtered"] = "yes" + return PluginResult(continue_processing=True, modified_payload=payload) + + resource_plugin_config = PluginConfig( + name="ResourceFilterPlugin", + description="Test resource filtering", + author="Test", + version="1.0", + tags=["test"], + kind="ResourceFilterPlugin", + hooks=["resource_pre_fetch"], + config={}, + conditions=[PluginCondition(resources={"file:///data.txt", "file:///config.json"})], + ) + resource_plugin = ResourceFilterPlugin(resource_plugin_config) + + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(ResourceHookType.RESOURCE_PRE_FETCH, PluginRef(resource_plugin)) + mock_get.return_value = [hook_ref] + + # Test with matching resource + resource_match = ResourcePreFetchPayload(uri="file:///data.txt", metadata={}) + global_context = GlobalContext(request_id="5") + result, _ = await manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, resource_match, global_context=global_context) + + assert result.continue_processing + assert result.modified_payload is not None + assert result.modified_payload.metadata.get("resource_filtered") == "yes" + + # Test with non-matching resource + resource_no_match = ResourcePreFetchPayload(uri="file:///other.txt", metadata={}) + result2, _ = await manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, resource_no_match, global_context=global_context) + + assert result2.continue_processing + assert result2.modified_payload is None # Plugin skipped + + # ========== Test 5: Agent filtering ========== + class AgentFilterPlugin(Plugin): + async def agent_pre_invoke(self, payload, context): + payload.parameters["agent_filtered"] = "yes" + return PluginResult(continue_processing=True, modified_payload=payload) + + agent_plugin_config = PluginConfig( + name="AgentFilterPlugin", + description="Test agent filtering", + author="Test", + version="1.0", + tags=["test"], + kind="AgentFilterPlugin", + hooks=["agent_pre_invoke"], + config={}, + conditions=[PluginCondition(agents={"agent1", "agent2"})], + ) + agent_plugin = AgentFilterPlugin(agent_plugin_config) + + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(AgentHookType.AGENT_PRE_INVOKE, PluginRef(agent_plugin)) + mock_get.return_value = [hook_ref] + + # Test with matching agent + agent_match = AgentPreInvokePayload(agent_id="agent1", messages=[], parameters={}) + global_context = GlobalContext(request_id="6") + result, _ = await manager.invoke_hook(AgentHookType.AGENT_PRE_INVOKE, agent_match, global_context=global_context) + + assert result.continue_processing + assert result.modified_payload is not None + assert result.modified_payload.parameters.get("agent_filtered") == "yes" + + # Test with non-matching agent + agent_no_match = AgentPreInvokePayload(agent_id="agent3", messages=[], parameters={}) + result2, _ = await manager.invoke_hook(AgentHookType.AGENT_PRE_INVOKE, agent_no_match, global_context=global_context) + + assert result2.continue_processing + assert result2.modified_payload is None # Plugin skipped + + # ========== Test 6: Combined conditions (server_id + tool name) ========== + class CombinedFilterPlugin(Plugin): + async def tool_pre_invoke(self, payload, context): + payload.args["combined_filtered"] = "yes" + return PluginResult(continue_processing=True, modified_payload=payload) + + combined_plugin_config = PluginConfig( + name="CombinedFilterPlugin", + description="Test combined filtering", + author="Test", + version="1.0", + tags=["test"], + kind="CombinedFilterPlugin", + hooks=["tool_pre_invoke"], + config={}, + conditions=[PluginCondition(server_ids={"server1"}, tools={"calculator"})], + ) + combined_plugin = CombinedFilterPlugin(combined_plugin_config) + + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(ToolHookType.TOOL_PRE_INVOKE, PluginRef(combined_plugin)) + mock_get.return_value = [hook_ref] + + # Test with both conditions matching + tool_payload = ToolPreInvokePayload(name="calculator", args={}) + global_context = GlobalContext(request_id="7", server_id="server1") + result, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) + + assert result.continue_processing + assert result.modified_payload is not None + assert result.modified_payload.args.get("combined_filtered") == "yes" + + # Test with server_id mismatch + global_context2 = GlobalContext(request_id="8", server_id="server2") + result2, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context2) + + assert result2.continue_processing + assert result2.modified_payload is None # Plugin skipped + + # Test with tool name mismatch + tool_payload2 = ToolPreInvokePayload(name="other_tool", args={}) + global_context3 = GlobalContext(request_id="9", server_id="server1") + result3, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload2, global_context=global_context3) + + assert result3.continue_processing + assert result3.modified_payload is None # Plugin skipped + await manager.shutdown() @@ -251,14 +463,14 @@ async def prompt_pre_fetch(self, payload, context): plugin1 = MetadataPlugin1(config1) plugin2 = MetadataPlugin2(config2) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - refs = [PluginRef(plugin1), PluginRef(plugin2)] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + refs = [HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(plugin1)), HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(plugin2))] mock_get.return_value = refs prompt = PromptPrehookPayload(prompt_id="test", args={}) global_context = GlobalContext(request_id="1") - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should aggregate metadata assert result.continue_processing @@ -292,17 +504,25 @@ async def prompt_post_fetch(self, payload, context: PluginContext): ) plugin = StatefulPlugin(config) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_pre, patch.object(manager._registry, "get_plugins_for_hook") as mock_post: - plugin_ref = PluginRef(plugin) + # Create a single PluginRef to ensure the same UUID is used for both hooks + plugin_ref = PluginRef(plugin) + hook_ref_pre = HookRef(PromptHookType.PROMPT_PRE_FETCH, plugin_ref) + hook_ref_post = HookRef(PromptHookType.PROMPT_POST_FETCH, plugin_ref) - mock_pre.return_value = [plugin_ref] - mock_post.return_value = [plugin_ref] + def get_hook_refs_side_effect(hook_type): + if hook_type == PromptHookType.PROMPT_PRE_FETCH: + return [hook_ref_pre] + elif hook_type == PromptHookType.PROMPT_POST_FETCH: + return [hook_ref_post] + return [] + + with patch.object(manager._registry, "get_hook_refs_for_hook", side_effect=get_hook_refs_side_effect): # First call to pre_fetch prompt = PromptPrehookPayload(prompt_id="test", args={}) global_context = GlobalContext(request_id="1") - result_pre, contexts = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result_pre, contexts = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert result_pre.continue_processing # Call to post_fetch with same contexts @@ -310,7 +530,7 @@ async def prompt_post_fetch(self, payload, context: PluginContext): prompt_result = PromptResult(messages=[message]) post_payload = PromptPosthookPayload(prompt_id="test", result=prompt_result) - result_post, _ = await manager.prompt_post_fetch(post_payload, global_context=global_context, local_contexts=contexts) + result_post, _ = await manager.invoke_hook(PromptHookType.PROMPT_POST_FETCH, post_payload, global_context=global_context, local_contexts=contexts) # Should have modified with persisted state assert result_post.continue_processing @@ -337,14 +557,14 @@ async def prompt_pre_fetch(self, payload, context): ) plugin = BlockingPlugin(config) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(plugin)) + mock_get.return_value = [hook_ref] prompt = PromptPrehookPayload(prompt_id="test", args={"text": "bad content"}) global_context = GlobalContext(request_id="1") - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should block the request assert not result.continue_processing @@ -353,7 +573,7 @@ async def prompt_pre_fetch(self, payload, context): assert result.violation.plugin_name == "BlockingPlugin" with pytest.raises(PluginViolationError) as pve: - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context, violations_as_exceptions=True) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context, violations_as_exceptions=True) assert pve.value.violation assert pve.value.message assert pve.value.violation.code == "CONTENT_BLOCKED" @@ -387,14 +607,14 @@ async def prompt_pre_fetch(self, payload, context): plugin = BlockingPlugin(config) # Test permissive mode blocking (covers lines 194-195) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(plugin)) + mock_get.return_value = [hook_ref] prompt = PromptPrehookPayload(prompt_id="test", args={"text": "content"}) global_context = GlobalContext(request_id="1") - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in permissive mode - the permissive logic continues without blocking assert result.continue_processing @@ -434,10 +654,10 @@ async def test_manager_payload_size_validation(): """Test payload size validation functionality.""" # First-Party from mcpgateway.plugins.framework.manager import MAX_PAYLOAD_SIZE, PayloadSizeError, PluginExecutor - from mcpgateway.plugins.framework.models import PromptPosthookPayload, PromptPrehookPayload + from mcpgateway.plugins.framework import PromptPosthookPayload, PromptPrehookPayload # Test payload size validation directly on executor (covers lines 252, 258) - executor = PluginExecutor[PromptPrehookPayload]() + executor = PluginExecutor() # Test large args payload (covers line 252) large_data = "x" * (MAX_PAYLOAD_SIZE + 1) @@ -449,7 +669,7 @@ async def test_manager_payload_size_validation(): # Test large result payload (covers line 258) # First-Party - from mcpgateway.models import Message, PromptResult, Role, TextContent + from mcpgateway.common.models import Message, PromptResult, Role, TextContent large_text = "y" * (MAX_PAYLOAD_SIZE + 1) message = Message(role=Role.USER, content=TextContent(type="text", text=large_text)) @@ -457,7 +677,7 @@ async def test_manager_payload_size_validation(): large_post_payload = PromptPosthookPayload(prompt_id="test", result=large_result) # Should raise PayloadSizeError for large result - executor2 = PluginExecutor[PromptPosthookPayload]() + executor2 = PluginExecutor() with pytest.raises(PayloadSizeError, match="Result size .* exceeds limit"): executor2._validate_payload_size(large_post_payload) @@ -492,7 +712,7 @@ async def test_manager_initialization_edge_cases(): tags=["test"], kind="nonexistent.Plugin", mode=PluginMode.ENFORCE, - hooks=[HookType.PROMPT_PRE_FETCH], + hooks=[PromptHookType.PROMPT_PRE_FETCH], config={}, ) ], @@ -516,7 +736,7 @@ async def test_manager_initialization_edge_cases(): tags=["test"], kind="test.Plugin", mode=PluginMode.DISABLED, # Disabled mode - hooks=[HookType.PROMPT_PRE_FETCH], + hooks=[PromptHookType.PROMPT_PRE_FETCH], config={}, ) ], @@ -527,72 +747,19 @@ async def test_manager_initialization_edge_cases(): await manager2.shutdown() -@pytest.mark.asyncio -async def test_manager_context_cleanup(): - """Test context cleanup functionality.""" - # Standard - import time - - # First-Party - from mcpgateway.plugins.framework.manager import CONTEXT_MAX_AGE - - manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") - await manager.initialize() - - # Add some old contexts to the store - old_time = time.time() - CONTEXT_MAX_AGE - 1 # Older than max age - manager._context_store["old_request"] = ({}, old_time) - manager._context_store["new_request"] = ({}, time.time()) - - # Force cleanup by setting last cleanup time to 0 - manager._last_cleanup = 0 - - with patch("mcpgateway.plugins.framework.manager.logger") as mock_logger: - # Run cleanup (covers lines 551, 554) - await manager._cleanup_old_contexts() - - # Should have removed old context - assert "old_request" not in manager._context_store - assert "new_request" in manager._context_store - - # Should log cleanup message - mock_logger.info.assert_called_with("Cleaned up 1 expired plugin contexts") - - await manager.shutdown() - - -@pytest.mark.asyncio -async def test_manager_constructor_context_init(): - """Test manager constructor context initialization.""" - - # Test that managers share state and context store exists (covers lines 432-433) - manager1 = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") - manager2 = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") - - # Both managers should share the same state - assert hasattr(manager1, "_context_store") - assert hasattr(manager2, "_context_store") - assert hasattr(manager1, "_last_cleanup") - assert hasattr(manager2, "_last_cleanup") - - # They should be the same instance due to shared state - assert manager1._context_store is manager2._context_store - await manager1.shutdown() - await manager2.shutdown() - - @pytest.mark.asyncio async def test_base_plugin_coverage(): """Test base plugin functionality for complete coverage.""" # First-Party - from mcpgateway.models import Message, PromptResult, Role, TextContent - from mcpgateway.plugins.framework.base import Plugin, PluginRef - from mcpgateway.plugins.framework.models import ( + from mcpgateway.common.models import Message, PromptResult, Role, TextContent + from mcpgateway.plugins.framework.base import PluginRef + from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, PluginMode, + PromptHookType, + ToolHookType, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, @@ -607,7 +774,7 @@ async def test_base_plugin_coverage(): version="1.0", tags=["test", "coverage"], # Tags to be accessed kind="test.Plugin", - hooks=[HookType.PROMPT_PRE_FETCH], + hooks=[PromptHookType.PROMPT_PRE_FETCH], config={}, ) @@ -627,7 +794,7 @@ async def test_base_plugin_coverage(): context = PluginContext(global_context=GlobalContext(request_id="test")) payload = PromptPrehookPayload(prompt_id="test", args={}) - with pytest.raises(NotImplementedError, match="'prompt_pre_fetch' not implemented"): + with pytest.raises(AttributeError, match="'Plugin' object has no attribute 'prompt_pre_fetch'"): await plugin.prompt_pre_fetch(payload, context) # Test NotImplementedError for prompt_post_fetch (covers lines 167-171) @@ -635,17 +802,17 @@ async def test_base_plugin_coverage(): result = PromptResult(messages=[message]) post_payload = PromptPosthookPayload(prompt_id="test", result=result) - with pytest.raises(NotImplementedError, match="'prompt_post_fetch' not implemented"): + with pytest.raises(AttributeError, match="'Plugin' object has no attribute 'prompt_post_fetch'"): await plugin.prompt_post_fetch(post_payload, context) # Test default tool_pre_invoke implementation (covers line 191) tool_payload = ToolPreInvokePayload(name="test_tool", args={"key": "value"}) - with pytest.raises(NotImplementedError, match="'tool_pre_invoke' not implemented"): + with pytest.raises(AttributeError, match="'Plugin' object has no attribute 'tool_pre_invoke'"): await plugin.tool_pre_invoke(tool_payload, context) # Test default tool_post_invoke implementation (covers line 211) tool_post_payload = ToolPostInvokePayload(name="test_tool", result={"result": "success"}) - with pytest.raises(NotImplementedError, match="'tool_post_invoke' not implemented"): + with pytest.raises(AttributeError, match="'Plugin' object has no attribute 'tool_post_invoke'"): await plugin.tool_post_invoke(tool_post_payload, context) @@ -690,12 +857,12 @@ async def test_plugin_loader_return_none(): """Test plugin loader return None case.""" # First-Party from mcpgateway.plugins.framework.loader.plugin import PluginLoader - from mcpgateway.plugins.framework.models import HookType, PluginConfig + from mcpgateway.plugins.framework import PluginConfig loader = PluginLoader() # Test return None when plugin_type is None (covers line 90) - config = PluginConfig(name="TestPlugin", description="Test", author="Test", version="1.0", tags=["test"], kind="test.plugin.TestPlugin", hooks=[HookType.PROMPT_PRE_FETCH], config={}) + config = PluginConfig(name="TestPlugin", description="Test", author="Test", version="1.0", tags=["test"], kind="test.plugin.TestPlugin", hooks=[PromptHookType.PROMPT_PRE_FETCH], config={}) # Mock the plugin_types dict to contain None for this kind loader._plugin_types[config.kind] = None @@ -753,20 +920,20 @@ async def tool_pre_invoke(self, payload, context): ) plugin = TestPlugin(config) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(ToolHookType.TOOL_PRE_INVOKE, PluginRef(plugin)) + mock_get.return_value = [hook_ref] # Test with matching tool tool_payload = ToolPreInvokePayload(name="calculator", args={}) global_context = GlobalContext(request_id="1") - result, _ = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) assert result.continue_processing # Test with non-matching tool tool_payload2 = ToolPreInvokePayload(name="other_tool", args={}) - result2, _ = await manager.tool_pre_invoke(tool_payload2, global_context=global_context) + result2, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload2, global_context=global_context) assert result2.continue_processing await manager.shutdown() @@ -786,14 +953,14 @@ async def tool_post_invoke(self, payload, context): config = PluginConfig(name="ModifyingPlugin", description="Test modifying plugin", author="Test", version="1.0", tags=["test"], kind="ModifyingPlugin", hooks=["tool_post_invoke"], config={}) plugin = ModifyingPlugin(config) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(ToolHookType.TOOL_POST_INVOKE, PluginRef(plugin)) + mock_get.return_value = [hook_ref] tool_payload = ToolPostInvokePayload(name="test_tool", result={"original": "data"}) global_context = GlobalContext(request_id="1") - result, _ = await manager.tool_post_invoke(tool_payload, global_context=global_context) + result, _ = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, tool_payload, global_context=global_context) assert result.continue_processing assert result.modified_payload is not None diff --git a/tests/unit/mcpgateway/plugins/framework/test_registry.py b/tests/unit/mcpgateway/plugins/framework/test_registry.py index 7f62b694f..64fa9e009 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_registry.py +++ b/tests/unit/mcpgateway/plugins/framework/test_registry.py @@ -14,11 +14,11 @@ import pytest # First-Party -from mcpgateway.plugins.framework.base import Plugin from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader -from mcpgateway.plugins.framework.models import HookType, PluginConfig +from mcpgateway.plugins.framework import PluginConfig, Plugin, PromptHookType, ToolHookType from mcpgateway.plugins.framework.registry import PluginInstanceRegistry +from tests.unit.mcpgateway.plugins.fixtures.plugins.simple import SimplePromptPlugin @pytest.mark.asyncio @@ -78,7 +78,7 @@ async def test_registry_priority_sorting(): version="1.0", tags=["test"], kind="test.Plugin", - hooks=[HookType.PROMPT_PRE_FETCH], + hooks=[PromptHookType.PROMPT_PRE_FETCH], priority=300, # High number = low priority config={}, ) @@ -90,27 +90,27 @@ async def test_registry_priority_sorting(): version="1.0", tags=["test"], kind="test.Plugin", - hooks=[HookType.PROMPT_PRE_FETCH], + hooks=[PromptHookType.PROMPT_PRE_FETCH], priority=50, # Low number = high priority config={}, ) # Create plugin instances - low_priority_plugin = Plugin(low_priority_config) - high_priority_plugin = Plugin(high_priority_config) + low_priority_plugin = SimplePromptPlugin(low_priority_config) + high_priority_plugin = SimplePromptPlugin(high_priority_config) # Register plugins in reverse priority order registry.register(low_priority_plugin) registry.register(high_priority_plugin) # Get plugins for hook - should be sorted by priority (lines 131-134) - hook_plugins = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) + hook_plugins = registry.get_hook_refs_for_hook(PromptHookType.PROMPT_PRE_FETCH) assert len(hook_plugins) == 2 - assert hook_plugins[0].name == "HighPriority" # Lower number = higher priority - assert hook_plugins[1].name == "LowPriority" + assert hook_plugins[0].plugin_ref.name == "HighPriority" # Lower number = higher priority + assert hook_plugins[1].plugin_ref.name == "LowPriority" # Test priority cache - calling again should use cached result - cached_plugins = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) + cached_plugins = registry.get_hook_refs_for_hook(PromptHookType.PROMPT_PRE_FETCH) assert cached_plugins == hook_plugins # Clean up @@ -126,29 +126,29 @@ async def test_registry_hook_filtering(): # Create plugin with specific hooks pre_fetch_config = PluginConfig( - name="PreFetchPlugin", description="Pre-fetch plugin", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_PRE_FETCH], config={} + name="PreFetchPlugin", description="Pre-fetch plugin", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[PromptHookType.PROMPT_PRE_FETCH], config={} ) post_fetch_config = PluginConfig( - name="PostFetchPlugin", description="Post-fetch plugin", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_POST_FETCH], config={} + name="PostFetchPlugin", description="Post-fetch plugin", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[PromptHookType.PROMPT_POST_FETCH], config={} ) - pre_fetch_plugin = Plugin(pre_fetch_config) - post_fetch_plugin = Plugin(post_fetch_config) + pre_fetch_plugin = SimplePromptPlugin(pre_fetch_config) + post_fetch_plugin = SimplePromptPlugin(post_fetch_config) registry.register(pre_fetch_plugin) registry.register(post_fetch_plugin) # Test hook filtering - pre_plugins = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) - post_plugins = registry.get_plugins_for_hook(HookType.PROMPT_POST_FETCH) - tool_plugins = registry.get_plugins_for_hook(HookType.TOOL_PRE_INVOKE) + pre_plugins = registry.get_hook_refs_for_hook(PromptHookType.PROMPT_PRE_FETCH) + post_plugins = registry.get_hook_refs_for_hook(PromptHookType.PROMPT_POST_FETCH) + tool_plugins = registry.get_hook_refs_for_hook(ToolHookType.TOOL_PRE_INVOKE) assert len(pre_plugins) == 1 - assert pre_plugins[0].name == "PreFetchPlugin" + assert pre_plugins[0].plugin_ref.name == "PreFetchPlugin" assert len(post_plugins) == 1 - assert post_plugins[0].name == "PostFetchPlugin" + assert post_plugins[0].plugin_ref.name == "PostFetchPlugin" assert len(tool_plugins) == 0 # No plugins for this hook @@ -163,9 +163,9 @@ async def test_registry_shutdown(): registry = PluginInstanceRegistry() # Create mock plugins with shutdown methods - mock_plugin1 = Plugin(PluginConfig(name="Plugin1", description="Test plugin 1", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_PRE_FETCH], config={})) + mock_plugin1 = SimplePromptPlugin(PluginConfig(name="Plugin1", description="Test plugin 1", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[PromptHookType.PROMPT_PRE_FETCH], config={})) - mock_plugin2 = Plugin(PluginConfig(name="Plugin2", description="Test plugin 2", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_POST_FETCH], config={})) + mock_plugin2 = SimplePromptPlugin(PluginConfig(name="Plugin2", description="Test plugin 2", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[PromptHookType.PROMPT_POST_FETCH], config={})) # Mock the shutdown methods mock_plugin1.shutdown = AsyncMock() @@ -196,8 +196,8 @@ async def test_registry_shutdown_with_error(): registry = PluginInstanceRegistry() # Create mock plugin that fails during shutdown - failing_plugin = Plugin( - PluginConfig(name="FailingPlugin", description="Plugin that fails shutdown", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_PRE_FETCH], config={}) + failing_plugin = SimplePromptPlugin( + PluginConfig(name="FailingPlugin", description="Plugin that fails shutdown", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[PromptHookType.PROMPT_PRE_FETCH], config={}) ) # Mock shutdown to raise an exception @@ -232,7 +232,7 @@ async def test_registry_edge_cases(): assert registry.plugin_count == 0 # Test getting hooks for empty registry - empty_hooks = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) + empty_hooks = registry.get_hook_refs_for_hook(PromptHookType.PROMPT_PRE_FETCH) assert len(empty_hooks) == 0 # Test get_all_plugins when empty @@ -244,23 +244,23 @@ async def test_registry_cache_invalidation(): """Test that priority cache is invalidated correctly.""" registry = PluginInstanceRegistry() - plugin_config = PluginConfig(name="TestPlugin", description="Test plugin", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_PRE_FETCH], config={}) + plugin_config = PluginConfig(name="TestPlugin", description="Test plugin", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[PromptHookType.PROMPT_PRE_FETCH], config={}) - plugin = Plugin(plugin_config) + plugin = SimplePromptPlugin(plugin_config) # Register plugin registry.register(plugin) # Get plugins for hook (populates cache) - hooks1 = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) + hooks1 = registry.get_hook_refs_for_hook(PromptHookType.PROMPT_PRE_FETCH) assert len(hooks1) == 1 # Cache should be populated - assert HookType.PROMPT_PRE_FETCH in registry._priority_cache + assert PromptHookType.PROMPT_PRE_FETCH in registry._priority_cache # Unregister plugin (should invalidate cache) registry.unregister("TestPlugin") # Cache should be cleared for this hook type - hooks2 = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) + hooks2 = registry.get_hook_refs_for_hook(PromptHookType.PROMPT_PRE_FETCH) assert len(hooks2) == 0 diff --git a/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py b/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py index 1a3dbcb67..b783ec45f 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py +++ b/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py @@ -14,13 +14,12 @@ import pytest # First-Party -from mcpgateway.models import ResourceContent -from mcpgateway.plugins.framework.base import Plugin, PluginRef +from mcpgateway.common.models import ResourceContent +from mcpgateway.plugins.framework.base import PluginRef # Registry is imported for mocking from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginCondition, PluginConfig, PluginContext, @@ -28,6 +27,8 @@ PluginManager, PluginMode, PluginViolation, + ResourceHookType, + Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, @@ -61,14 +62,14 @@ async def test_plugin_resource_pre_fetch_default(self): author="test", kind="test.Plugin", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["test"], ) plugin = Plugin(config) payload = ResourcePreFetchPayload(uri="file:///test.txt", metadata={}) context = PluginContext(global_context=GlobalContext(request_id="test-123")) - with pytest.raises(NotImplementedError, match="'resource_pre_fetch' not implemented"): + with pytest.raises(AttributeError, match="'Plugin' object has no attribute 'resource_pre_fetch'"): await plugin.resource_pre_fetch(payload, context) @pytest.mark.asyncio @@ -80,7 +81,7 @@ async def test_plugin_resource_post_fetch_default(self): author="test", kind="test.Plugin", version="1.0.0", - hooks=[HookType.RESOURCE_POST_FETCH], + hooks=[ResourceHookType.RESOURCE_POST_FETCH], tags=["test"], ) plugin = Plugin(config) @@ -88,7 +89,7 @@ async def test_plugin_resource_post_fetch_default(self): payload = ResourcePostFetchPayload(uri="file:///test.txt", content=content) context = PluginContext(global_context=GlobalContext(request_id="test-123")) - with pytest.raises(NotImplementedError, match="'resource_post_fetch' not implemented"): + with pytest.raises(AttributeError, match="'Plugin' object has no attribute 'resource_post_fetch'"): await plugin.resource_post_fetch(payload, context) @pytest.mark.asyncio @@ -113,7 +114,7 @@ async def resource_pre_fetch(self, payload, context): author="test", kind="test.BlockingPlugin", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["test"], mode=PluginMode.ENFORCE, ) @@ -157,7 +158,7 @@ async def resource_post_fetch(self, payload, context): author="test", kind="test.FilterPlugin", version="1.0.0", - hooks=[HookType.RESOURCE_POST_FETCH], + hooks=[ResourceHookType.RESOURCE_POST_FETCH], tags=["filter"], ) plugin = ContentFilterPlugin(config) @@ -198,7 +199,7 @@ async def resource_pre_fetch(self, payload, context): author="test", kind="test.ConditionalPlugin", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["conditional"], conditions=[ PluginCondition( @@ -273,58 +274,52 @@ async def test_manager_resource_pre_fetch(self): payload = ResourcePreFetchPayload(uri="test://resource", metadata={}) global_context = GlobalContext(request_id="test-123") - result, contexts = await manager.resource_pre_fetch(payload, global_context) + result, contexts = await manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, payload, global_context) assert result.continue_processing is True - MockRegistry.return_value.get_plugins_for_hook.assert_called_with(HookType.RESOURCE_PRE_FETCH) + MockRegistry.return_value.get_hook_refs_for_hook.assert_called_with(hook_type=ResourceHookType.RESOURCE_PRE_FETCH) @pytest.mark.asyncio async def test_manager_resource_post_fetch(self): """Test plugin manager resource_post_fetch execution.""" - with patch("mcpgateway.plugins.framework.manager.PluginInstanceRegistry") as MockRegistry: - with patch("mcpgateway.plugins.framework.loader.config.ConfigLoader.load_config") as MockConfig: - # Create a proper mock plugin with all required attributes - mock_plugin_obj = MagicMock() - mock_plugin_obj.name = "test_plugin" - mock_plugin_obj.priority = 50 - mock_plugin_obj.mode = PluginMode.ENFORCE - mock_plugin_obj.conditions = [] - mock_plugin_obj.resource_post_fetch = AsyncMock( - return_value=ResourcePostFetchResult( - continue_processing=True, - modified_payload=None, - ) - ) + # First-Party + from mcpgateway.plugins.framework.base import HookRef - # Create a PluginRef-like mock - mock_ref = MagicMock() - mock_ref._plugin = mock_plugin_obj - mock_ref.plugin = mock_plugin_obj - mock_ref.name = "test_plugin" - mock_ref.priority = 50 - mock_ref.mode = PluginMode.ENFORCE - mock_ref.conditions = [] - mock_ref.uuid = "test-uuid" + class TestResourcePlugin(Plugin): + async def resource_post_fetch(self, payload, context): + return ResourcePostFetchResult( + continue_processing=True, + modified_payload=None, + ) - MockRegistry.return_value.get_plugins_for_hook.return_value = [mock_ref] + config = PluginConfig( + name="test_plugin", + description="Test resource plugin", + author="test", + kind="test.Plugin", + version="1.0.0", + hooks=[ResourceHookType.RESOURCE_POST_FETCH], + tags=["test"], + mode=PluginMode.ENFORCE, + ) + plugin = TestResourcePlugin(config) + plugin_ref = PluginRef(plugin) + hook_ref = HookRef(ResourceHookType.RESOURCE_POST_FETCH, plugin_ref) - # Mock config - mock_config = MagicMock() - mock_config.plugin_settings = MagicMock() - MockConfig.return_value = mock_config + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") + await manager.initialize() - manager = PluginManager("test_config.yaml") - manager._registry = MockRegistry.return_value - manager._initialized = True + with patch.object(manager._registry, "get_hook_refs_for_hook", return_value=[hook_ref]): + content = ResourceContent(type="resource", id="123", uri="test://resource", text="Test") + payload = ResourcePostFetchPayload(uri="test://resource", content=content) + global_context = GlobalContext(request_id="test-123") - content = ResourceContent(type="resource", id="123", uri="test://resource", text="Test") - payload = ResourcePostFetchPayload(uri="test://resource", content=content) - global_context = GlobalContext(request_id="test-123") + result, contexts = await manager.invoke_hook(ResourceHookType.RESOURCE_POST_FETCH, payload, global_context, {}) - result, contexts = await manager.resource_post_fetch(payload, global_context, {}) + assert result.continue_processing is True + manager._registry.get_hook_refs_for_hook.assert_called_with(hook_type=ResourceHookType.RESOURCE_POST_FETCH) - assert result.continue_processing is True - MockRegistry.return_value.get_plugins_for_hook.assert_called_with(HookType.RESOURCE_POST_FETCH) + await manager.shutdown() @pytest.mark.asyncio async def test_resource_hook_chain_execution(self): @@ -355,7 +350,7 @@ async def resource_pre_fetch(self, payload, context): author="test", kind="test.First", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["test"], priority=10, # Higher priority ) @@ -365,7 +360,7 @@ async def resource_pre_fetch(self, payload, context): author="test", kind="test.Second", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["test"], priority=20, # Lower priority ) @@ -383,6 +378,8 @@ async def resource_pre_fetch(self, payload, context): @pytest.mark.asyncio async def test_resource_hook_error_handling(self): """Test resource hook error handling.""" + # First-Party + from mcpgateway.plugins.framework.base import HookRef class ErrorPlugin(Plugin): async def resource_pre_fetch(self, payload, context): @@ -394,46 +391,32 @@ async def resource_pre_fetch(self, payload, context): author="test", kind="test.ErrorPlugin", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["test"], mode=PluginMode.PERMISSIVE, # Should continue on error ) plugin = ErrorPlugin(config) + plugin_ref = PluginRef(plugin) + hook_ref = HookRef(ResourceHookType.RESOURCE_PRE_FETCH, plugin_ref) - with patch("mcpgateway.plugins.framework.manager.PluginInstanceRegistry") as MockRegistry: - with patch("mcpgateway.plugins.framework.loader.config.ConfigLoader.load_config") as MockConfig: - # Create a proper mock ref - mock_ref = MagicMock() - mock_ref._plugin = plugin - mock_ref.plugin = plugin - mock_ref.name = "error_plugin" - mock_ref.priority = 100 - mock_ref.mode = PluginMode.PERMISSIVE - mock_ref.conditions = [] - mock_ref.uuid = "test-uuid" - - MockRegistry.return_value.get_plugins_for_hook.return_value = [mock_ref] + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") + await manager.initialize() - # Mock config - mock_config = MagicMock() - mock_config.plugin_settings = MagicMock() - mock_config.plugin_settings.fail_on_plugin_error = False - MockConfig.return_value = mock_config + payload = ResourcePreFetchPayload(uri="test://resource", metadata={}) + global_context = GlobalContext(request_id="test-123") - manager = PluginManager("test_config.yaml") - manager._registry = MockRegistry.return_value - manager._initialized = True + # Test with permissive mode - should handle error gracefully + with patch.object(manager._registry, "get_hook_refs_for_hook", return_value=[hook_ref]): + result, contexts = await manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, payload, global_context) + assert result.continue_processing is True # Continues despite error - payload = ResourcePreFetchPayload(uri="test://resource", metadata={}) - global_context = GlobalContext(request_id="test-123") - # Should handle error gracefully when fail_on_plugin_error = False - result, contexts = await manager.resource_pre_fetch(payload, global_context) - assert result.continue_processing is True # Continues despite error + # Test with enforce mode - should raise PluginError + config.mode = PluginMode.ENFORCE + with patch.object(manager._registry, "get_hook_refs_for_hook", return_value=[hook_ref]): + with pytest.raises(PluginError): + result, contexts = await manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, payload, global_context) - mock_config.plugin_settings.fail_on_plugin_error = True - # Should throw a plugin error since fail_on_plugin_error = True - with pytest.raises(PluginError): - result, contexts = await manager.resource_pre_fetch(payload, global_context) + await manager.shutdown() @pytest.mark.asyncio async def test_resource_uri_modification(self): @@ -457,7 +440,7 @@ async def resource_pre_fetch(self, payload, context): author="test", kind="test.URIModifier", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["modifier"], ) plugin = URIModifierPlugin(config) @@ -491,7 +474,7 @@ async def resource_pre_fetch(self, payload, context): author="test", kind="test.Enricher", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["enricher"], ) plugin = MetadataEnricherPlugin(config) diff --git a/tests/unit/mcpgateway/plugins/framework/test_utils.py b/tests/unit/mcpgateway/plugins/framework/test_utils.py index 82b303417..7b27626b0 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_utils.py +++ b/tests/unit/mcpgateway/plugins/framework/test_utils.py @@ -11,50 +11,58 @@ import sys # First-Party -from mcpgateway.plugins.framework.models import GlobalContext, PluginCondition, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload -from mcpgateway.plugins.framework.utils import import_module, matches, parse_class_name, post_prompt_matches, post_tool_matches, pre_prompt_matches, pre_tool_matches +from mcpgateway.plugins.framework import ( + GlobalContext, + PluginCondition, + PromptPrehookPayload, + PromptPosthookPayload, + ToolPreInvokePayload, + ToolPostInvokePayload, +) +from mcpgateway.plugins.framework.utils import import_module, matches, parse_class_name, payload_matches def test_server_ids(): + """Test conditional matching with server IDs, tenant IDs, and user patterns.""" condition1 = PluginCondition(server_ids={"1", "2"}) context1 = GlobalContext(server_id="1", tenant_id="4", request_id="5") payload1 = PromptPrehookPayload(prompt_id="test_prompt", args={}) assert matches(condition=condition1, context=context1) - assert pre_prompt_matches(payload1, [condition1], context1) + assert payload_matches(payload1, "prompt_pre_fetch", [condition1], context1) context2 = GlobalContext(server_id="3", tenant_id="6", request_id="1") assert not matches(condition=condition1, context=context2) - assert not pre_prompt_matches(payload1, conditions=[condition1], context=context2) + assert not payload_matches(payload1, "prompt_pre_fetch", [condition1], context2) condition2 = PluginCondition(server_ids={"1"}, tenant_ids={"4"}) context2 = GlobalContext(server_id="1", tenant_id="4", request_id="1") assert matches(condition2, context2) - assert pre_prompt_matches(payload1, conditions=[condition2], context=context2) + assert payload_matches(payload1, "prompt_pre_fetch", [condition2], context2) context3 = GlobalContext(server_id="1", tenant_id="5", request_id="1") assert not matches(condition2, context3) - assert not pre_prompt_matches(payload1, conditions=[condition2], context=context3) + assert not payload_matches(payload1, "prompt_pre_fetch", [condition2], context3) condition4 = PluginCondition(user_patterns=["blah", "barker", "bobby"]) context4 = GlobalContext(user="blah", request_id="1") assert matches(condition4, context4) - assert pre_prompt_matches(payload1, conditions=[condition4], context=context4) + assert payload_matches(payload1, "prompt_pre_fetch", [condition4], context4) context5 = GlobalContext(user="barney", request_id="1") assert not matches(condition4, context5) - assert not pre_prompt_matches(payload1, conditions=[condition4], context=context5) + assert not payload_matches(payload1, "prompt_pre_fetch", [condition4], context5) condition5 = PluginCondition(server_ids={"1", "2"}, prompts={"test_prompt"}) - assert pre_prompt_matches(payload1, [condition5], context1) + assert payload_matches(payload1, "prompt_pre_fetch", [condition5], context1) condition6 = PluginCondition(server_ids={"1", "2"}, prompts={"test_prompt2"}) - assert not pre_prompt_matches(payload1, [condition6], context1) + assert not payload_matches(payload1, "prompt_pre_fetch", [condition6], context1) # ============================================================================ @@ -106,98 +114,87 @@ def test_parse_class_name(): # ============================================================================ -# Test post_prompt_matches function +# Test payload_matches for prompt hooks # ============================================================================ -def test_post_prompt_matches(): - """Test the post_prompt_matches function.""" - # Import required models - # First-Party - from mcpgateway.models import Message, PromptResult, TextContent - +def test_payload_matches_prompt_post_fetch(): + """Test payload_matches for prompt_post_fetch hook.""" # Test basic matching - msg = Message(role="assistant", content=TextContent(type="text", text="Hello")) - result = PromptResult(messages=[msg]) - payload = PromptPosthookPayload(prompt_id="greeting", result=result) + payload = PromptPosthookPayload(prompt_id="greeting", result={"messages": []}) condition = PluginCondition(prompts={"greeting"}) context = GlobalContext(request_id="req1") - assert post_prompt_matches(payload, [condition], context) is True + assert payload_matches(payload, "prompt_post_fetch", [condition], context) is True # Test no match - payload2 = PromptPosthookPayload(prompt_id ="other", result=result) - assert post_prompt_matches(payload2, [condition], context) is False + payload2 = PromptPosthookPayload(prompt_id="other", result={"messages": []}) + assert payload_matches(payload2, "prompt_post_fetch", [condition], context) is False # Test with server_id condition condition_with_server = PluginCondition(server_ids={"srv1"}, prompts={"greeting"}) context_with_server = GlobalContext(request_id="req1", server_id="srv1") - assert post_prompt_matches(payload, [condition_with_server], context_with_server) is True + assert payload_matches(payload, "prompt_post_fetch", [condition_with_server], context_with_server) is True # Test with mismatched server_id context_wrong_server = GlobalContext(request_id="req1", server_id="srv2") - assert post_prompt_matches(payload, [condition_with_server], context_wrong_server) is False - + assert payload_matches(payload, "prompt_post_fetch", [condition_with_server], context_wrong_server) is False -def test_post_prompt_matches_multiple_conditions(): - """Test post_prompt_matches with multiple conditions (OR logic).""" - # First-Party - from mcpgateway.models import Message, PromptResult, TextContent +def test_payload_matches_prompt_multiple_conditions(): + """Test payload_matches for prompts with multiple conditions (OR logic).""" # Create the payload - msg = Message(role="assistant", content=TextContent(type="text", text="Hello")) - result = PromptResult(messages=[msg]) - payload = PromptPosthookPayload(prompt_id="greeting", result=result) + payload = PromptPosthookPayload(prompt_id="greeting", result={"messages": []}) # First condition fails, second condition succeeds condition1 = PluginCondition(server_ids={"srv1"}, prompts={"greeting"}) condition2 = PluginCondition(server_ids={"srv2"}, prompts={"greeting"}) context = GlobalContext(request_id="req1", server_id="srv2") - assert post_prompt_matches(payload, [condition1, condition2], context) is True + assert payload_matches(payload, "prompt_post_fetch", [condition1, condition2], context) is True # Both conditions fail context_no_match = GlobalContext(request_id="req1", server_id="srv3") - assert post_prompt_matches(payload, [condition1, condition2], context_no_match) is False + assert payload_matches(payload, "prompt_post_fetch", [condition1, condition2], context_no_match) is False # Test reset logic between conditions condition3 = PluginCondition(server_ids={"srv3"}, prompts={"other"}) condition4 = PluginCondition(prompts={"greeting"}) - assert post_prompt_matches(payload, [condition3, condition4], context_no_match) is True + assert payload_matches(payload, "prompt_post_fetch", [condition3, condition4], context_no_match) is True # ============================================================================ -# Test pre_tool_matches function +# Test payload_matches for tool hooks # ============================================================================ -def test_pre_tool_matches(): - """Test the pre_tool_matches function.""" +def test_payload_matches_tool_pre_invoke(): + """Test payload_matches for tool_pre_invoke hook.""" # Test basic matching payload = ToolPreInvokePayload(name="calculator", args={"operation": "add"}) condition = PluginCondition(tools={"calculator"}) context = GlobalContext(request_id="req1") - assert pre_tool_matches(payload, [condition], context) is True + assert payload_matches(payload, "tool_pre_invoke", [condition], context) is True # Test no match payload2 = ToolPreInvokePayload(name="other_tool", args={}) - assert pre_tool_matches(payload2, [condition], context) is False + assert payload_matches(payload2, "tool_pre_invoke", [condition], context) is False # Test with server_id condition condition_with_server = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) context_with_server = GlobalContext(request_id="req1", server_id="srv1") - assert pre_tool_matches(payload, [condition_with_server], context_with_server) is True + assert payload_matches(payload, "tool_pre_invoke", [condition_with_server], context_with_server) is True # Test with mismatched server_id context_wrong_server = GlobalContext(request_id="req1", server_id="srv2") - assert pre_tool_matches(payload, [condition_with_server], context_wrong_server) is False + assert payload_matches(payload, "tool_pre_invoke", [condition_with_server], context_wrong_server) is False -def test_pre_tool_matches_multiple_conditions(): - """Test pre_tool_matches with multiple conditions (OR logic).""" +def test_payload_matches_tool_pre_invoke_multiple_conditions(): + """Test payload_matches for tool_pre_invoke with multiple conditions (OR logic).""" payload = ToolPreInvokePayload(name="calculator", args={"operation": "add"}) # First condition fails, second condition succeeds @@ -205,49 +202,49 @@ def test_pre_tool_matches_multiple_conditions(): condition2 = PluginCondition(server_ids={"srv2"}, tools={"calculator"}) context = GlobalContext(request_id="req1", server_id="srv2") - assert pre_tool_matches(payload, [condition1, condition2], context) is True + assert payload_matches(payload, "tool_pre_invoke", [condition1, condition2], context) is True # Both conditions fail context_no_match = GlobalContext(request_id="req1", server_id="srv3") - assert pre_tool_matches(payload, [condition1, condition2], context_no_match) is False + assert payload_matches(payload, "tool_pre_invoke", [condition1, condition2], context_no_match) is False # Test reset logic between conditions condition3 = PluginCondition(server_ids={"srv3"}, tools={"other"}) condition4 = PluginCondition(tools={"calculator"}) - assert pre_tool_matches(payload, [condition3, condition4], context_no_match) is True + assert payload_matches(payload, "tool_pre_invoke", [condition3, condition4], context_no_match) is True # ============================================================================ -# Test post_tool_matches function +# Test payload_matches for tool_post_invoke # ============================================================================ -def test_post_tool_matches(): - """Test the post_tool_matches function.""" +def test_payload_matches_tool_post_invoke(): + """Test payload_matches for tool_post_invoke hook.""" # Test basic matching payload = ToolPostInvokePayload(name="calculator", result={"value": 42}) condition = PluginCondition(tools={"calculator"}) context = GlobalContext(request_id="req1") - assert post_tool_matches(payload, [condition], context) is True + assert payload_matches(payload, "tool_post_invoke", [condition], context) is True # Test no match payload2 = ToolPostInvokePayload(name="other_tool", result={}) - assert post_tool_matches(payload2, [condition], context) is False + assert payload_matches(payload2, "tool_post_invoke", [condition], context) is False # Test with server_id condition condition_with_server = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) context_with_server = GlobalContext(request_id="req1", server_id="srv1") - assert post_tool_matches(payload, [condition_with_server], context_with_server) is True + assert payload_matches(payload, "tool_post_invoke", [condition_with_server], context_with_server) is True # Test with mismatched server_id context_wrong_server = GlobalContext(request_id="req1", server_id="srv2") - assert post_tool_matches(payload, [condition_with_server], context_wrong_server) is False + assert payload_matches(payload, "tool_post_invoke", [condition_with_server], context_wrong_server) is False -def test_post_tool_matches_multiple_conditions(): - """Test post_tool_matches with multiple conditions (OR logic).""" +def test_payload_matches_tool_post_invoke_multiple_conditions(): + """Test payload_matches for tool_post_invoke with multiple conditions (OR logic).""" payload = ToolPostInvokePayload(name="calculator", result={"value": 42}) # First condition fails, second condition succeeds @@ -255,25 +252,25 @@ def test_post_tool_matches_multiple_conditions(): condition2 = PluginCondition(server_ids={"srv2"}, tools={"calculator"}) context = GlobalContext(request_id="req1", server_id="srv2") - assert post_tool_matches(payload, [condition1, condition2], context) is True + assert payload_matches(payload, "tool_post_invoke", [condition1, condition2], context) is True # Both conditions fail context_no_match = GlobalContext(request_id="req1", server_id="srv3") - assert post_tool_matches(payload, [condition1, condition2], context_no_match) is False + assert payload_matches(payload, "tool_post_invoke", [condition1, condition2], context_no_match) is False # Test reset logic between conditions condition3 = PluginCondition(server_ids={"srv3"}, tools={"other"}) condition4 = PluginCondition(tools={"calculator"}) - assert post_tool_matches(payload, [condition3, condition4], context_no_match) is True + assert payload_matches(payload, "tool_post_invoke", [condition3, condition4], context_no_match) is True # ============================================================================ -# Test enhanced pre_prompt_matches scenarios +# Test payload_matches for prompt_pre_fetch with multiple conditions # ============================================================================ -def test_pre_prompt_matches_multiple_conditions(): - """Test pre_prompt_matches with multiple conditions to cover OR logic paths.""" +def test_payload_matches_prompt_pre_fetch_multiple_conditions(): + """Test payload_matches for prompt_pre_fetch with multiple conditions to cover OR logic paths.""" payload = PromptPrehookPayload(prompt_id="greeting", args={}) # First condition fails, second condition succeeds @@ -281,16 +278,16 @@ def test_pre_prompt_matches_multiple_conditions(): condition2 = PluginCondition(server_ids={"srv2"}, prompts={"greeting"}) context = GlobalContext(request_id="req1", server_id="srv2") - assert pre_prompt_matches(payload, [condition1, condition2], context) is True + assert payload_matches(payload, "prompt_pre_fetch", [condition1, condition2], context) is True # Both conditions fail context_no_match = GlobalContext(request_id="req1", server_id="srv3") - assert pre_prompt_matches(payload, [condition1, condition2], context_no_match) is False + assert payload_matches(payload, "prompt_pre_fetch", [condition1, condition2], context_no_match) is False - # Test reset logic between conditions (line 140) + # Test reset logic between conditions (OR logic) condition3 = PluginCondition(server_ids={"srv3"}, prompts={"other"}) condition4 = PluginCondition(prompts={"greeting"}) - assert pre_prompt_matches(payload, [condition3, condition4], context_no_match) is True + assert payload_matches(payload, "prompt_pre_fetch", [condition3, condition4], context_no_match) is True # ============================================================================ diff --git a/tests/unit/mcpgateway/plugins/plugins/altk_json_processor/test_json_processor.py b/tests/unit/mcpgateway/plugins/plugins/altk_json_processor/test_json_processor.py index 8b1f0be30..c230550ad 100644 --- a/tests/unit/mcpgateway/plugins/plugins/altk_json_processor/test_json_processor.py +++ b/tests/unit/mcpgateway/plugins/plugins/altk_json_processor/test_json_processor.py @@ -14,11 +14,11 @@ import pytest # First-Party -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, + ToolHookType, ToolPostInvokePayload, ) @@ -39,7 +39,7 @@ async def test_threshold(): plugin = ALTKJsonProcessor( # type: ignore PluginConfig( - name="jsonprocessor", kind="plugins.altk_json_processor.json_processor.ALTKJsonProcessor", hooks=[HookType.TOOL_POST_INVOKE], config={"llm_provider": "pytestmock", "length_threshold": 50} + name="jsonprocessor", kind="plugins.altk_json_processor.json_processor.ALTKJsonProcessor", hooks=[ToolHookType.TOOL_POST_INVOKE], config={"llm_provider": "pytestmock", "length_threshold": 50} ) ) ctx = PluginContext(global_context=GlobalContext(request_id="r1")) diff --git a/tests/unit/mcpgateway/plugins/plugins/argument_normalizer/test_argument_normalizer.py b/tests/unit/mcpgateway/plugins/plugins/argument_normalizer/test_argument_normalizer.py index 8368fb5dd..1f9d1db6d 100644 --- a/tests/unit/mcpgateway/plugins/plugins/argument_normalizer/test_argument_normalizer.py +++ b/tests/unit/mcpgateway/plugins/plugins/argument_normalizer/test_argument_normalizer.py @@ -11,11 +11,12 @@ import pytest # First-Party -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, + PromptHookType, + ToolHookType, PromptPrehookPayload, ToolPreInvokePayload, ) @@ -30,7 +31,7 @@ def _mk_plugin(config: dict | None = None) -> ArgumentNormalizerPlugin: cfg = PluginConfig( name="arg_norm", kind="plugins.argument_normalizer.argument_normalizer.ArgumentNormalizerPlugin", - hooks=[HookType.PROMPT_PRE_FETCH, HookType.TOOL_PRE_INVOKE], + hooks=[PromptHookType.PROMPT_PRE_FETCH, ToolHookType.TOOL_PRE_INVOKE], priority=30, config=config or {}, ) diff --git a/tests/unit/mcpgateway/plugins/plugins/cached_tool_result/test_cached_tool_result.py b/tests/unit/mcpgateway/plugins/plugins/cached_tool_result/test_cached_tool_result.py index 10f2f16f7..6025a302b 100644 --- a/tests/unit/mcpgateway/plugins/plugins/cached_tool_result/test_cached_tool_result.py +++ b/tests/unit/mcpgateway/plugins/plugins/cached_tool_result/test_cached_tool_result.py @@ -9,11 +9,11 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, + ToolHookType, ToolPreInvokePayload, ToolPostInvokePayload, ) @@ -26,7 +26,7 @@ async def test_cache_store_and_hit(): PluginConfig( name="cache", kind="plugins.cached_tool_result.cached_tool_result.CachedToolResultPlugin", - hooks=[HookType.TOOL_PRE_INVOKE, HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_PRE_INVOKE, ToolHookType.TOOL_POST_INVOKE], config={"cacheable_tools": ["echo"], "ttl": 60}, ) ) diff --git a/tests/unit/mcpgateway/plugins/plugins/code_safety_linter/test_code_safety_linter.py b/tests/unit/mcpgateway/plugins/plugins/code_safety_linter/test_code_safety_linter.py index 1de4ff24a..8429d587d 100644 --- a/tests/unit/mcpgateway/plugins/plugins/code_safety_linter/test_code_safety_linter.py +++ b/tests/unit/mcpgateway/plugins/plugins/code_safety_linter/test_code_safety_linter.py @@ -9,11 +9,11 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, + ToolHookType, ToolPostInvokePayload, ) from plugins.code_safety_linter.code_safety_linter import CodeSafetyLinterPlugin @@ -25,7 +25,7 @@ async def test_detects_eval_pattern(): PluginConfig( name="csl", kind="plugins.code_safety_linter.code_safety_linter.CodeSafetyLinterPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], ) ) ctx = PluginContext(global_context=GlobalContext(request_id="r1")) diff --git a/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation.py b/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation.py index e7ec89ada..6cb5a349a 100644 --- a/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation.py +++ b/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation.py @@ -11,12 +11,13 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, PluginViolation, + PromptHookType, + ToolHookType, PromptPrehookPayload, ToolPreInvokePayload, ToolPostInvokePayload, @@ -63,7 +64,7 @@ def _create_plugin(config_dict=None) -> ContentModerationPlugin: PluginConfig( name="content_moderation_test", kind="plugins.content_moderation.content_moderation.ContentModerationPlugin", - hooks=[HookType.PROMPT_PRE_FETCH, HookType.TOOL_PRE_INVOKE], + hooks=[PromptHookType.PROMPT_PRE_FETCH, ToolHookType.TOOL_PRE_INVOKE], config=default_config, ) ) diff --git a/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation_integration.py b/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation_integration.py index b443876bc..8c5202b3a 100644 --- a/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation_integration.py +++ b/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation_integration.py @@ -13,11 +13,13 @@ import pytest from mcpgateway.plugins.framework.manager import PluginManager -from mcpgateway.plugins.framework.models import ( - GlobalContext, +from mcpgateway.plugins.framework import GlobalContext + +from mcpgateway.plugins.framework import ( + PromptHookType, + ToolHookType, PromptPrehookPayload, ToolPreInvokePayload, - ToolPostInvokePayload, ) @@ -111,7 +113,7 @@ async def test_content_moderation_with_manager(): args={"query": "What is the weather like today?"} ) - result, final_context = await manager.prompt_pre_fetch(payload, context) + result, final_context = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, context) # Verify result assert result.continue_processing is True @@ -194,7 +196,7 @@ async def test_content_moderation_blocking_harmful_content(): args={"query": "I hate all those people and want them gone"} ) - result, final_context = await manager.prompt_pre_fetch(payload, context) + result, final_context = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, context) # Should be blocked due to high hate score assert result.continue_processing is False @@ -270,7 +272,7 @@ async def test_content_moderation_with_granite_fallback(): args={"query": "How to resolve conflicts peacefully"} ) - result, final_context = await manager.tool_pre_invoke(payload, context) + result, final_context = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, payload, context) # Should continue processing (fallback succeeded) assert result.continue_processing is True @@ -351,7 +353,7 @@ async def test_content_moderation_redaction(): args={"query": "This damn thing is not working"} ) - result, final_context = await manager.prompt_pre_fetch(payload, context) + result, final_context = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, context) # Should continue processing but with modified content assert result.continue_processing is True @@ -442,7 +444,7 @@ async def test_content_moderation_multiple_providers(): args={"query": "What is machine learning?"} ) - prompt_result, _ = await manager.prompt_pre_fetch(prompt_payload, context) + prompt_result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt_payload, context) assert prompt_result.continue_processing is True # Test tool (goes to Granite) @@ -451,7 +453,7 @@ async def test_content_moderation_multiple_providers(): args={"query": "How to build AI models"} ) - tool_result, _ = await manager.tool_pre_invoke(tool_payload, context) + tool_result, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, context) assert tool_result.continue_processing is True # Verify both providers were called diff --git a/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py b/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py index f19dfe214..baadb334d 100644 --- a/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py +++ b/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py @@ -9,15 +9,16 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, + ResourceHookType, ResourcePostFetchPayload, ResourcePreFetchPayload, ) -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from plugins.external.clamav_server.clamav_plugin import ClamAVRemotePlugin @@ -29,7 +30,7 @@ def _mk_plugin(block_on_positive: bool = True) -> ClamAVRemotePlugin: cfg = PluginConfig( name="clamav", kind="plugins.external.clamav_server.clamav_plugin.ClamAVRemotePlugin", - hooks=[HookType.RESOURCE_PRE_FETCH, HookType.RESOURCE_POST_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH, ResourceHookType.RESOURCE_POST_FETCH], config={ "mode": "eicar_only", "block_on_positive": block_on_positive, @@ -77,13 +78,13 @@ async def test_non_blocking_mode_reports_metadata(tmp_path): @pytest.mark.asyncio async def test_prompt_post_fetch_blocks_on_eicar_text(): plugin = _mk_plugin(True) - from mcpgateway.plugins.framework.models import PromptPosthookPayload + from mcpgateway.plugins.framework import PromptPosthookPayload - pr = __import__("mcpgateway.models").models.PromptResult( + pr = PromptResult( messages=[ - __import__("mcpgateway.models").models.Message( + Message( role="assistant", - content=__import__("mcpgateway.models").models.TextContent(type="text", text=EICAR), + content=TextContent(type="text", text=EICAR), ) ] ) @@ -97,7 +98,7 @@ async def test_prompt_post_fetch_blocks_on_eicar_text(): @pytest.mark.asyncio async def test_tool_post_invoke_blocks_on_eicar_string(): plugin = _mk_plugin(True) - from mcpgateway.plugins.framework.models import ToolPostInvokePayload + from mcpgateway.plugins.framework import ToolPostInvokePayload ctx = PluginContext(global_context=GlobalContext(request_id="r5")) payload = ToolPostInvokePayload(name="t", result={"text": EICAR}) @@ -118,13 +119,13 @@ async def test_health_stats_counters(): await plugin.resource_post_fetch(payload_r, ctx) # 2) prompt_post_fetch with EICAR -> attempted +1, infected +1 (total attempted=2, infected=2) - from mcpgateway.plugins.framework.models import PromptPosthookPayload + from mcpgateway.plugins.framework import PromptPosthookPayload - pr = __import__("mcpgateway.models").models.PromptResult( + pr = PromptResult( messages=[ - __import__("mcpgateway.models").models.Message( + Message( role="assistant", - content=__import__("mcpgateway.models").models.TextContent(type="text", text=EICAR), + content=TextContent(type="text", text=EICAR), ) ] ) @@ -132,7 +133,7 @@ async def test_health_stats_counters(): await plugin.prompt_post_fetch(payload_p, ctx) # 3) tool_post_invoke with one EICAR and one clean string -> attempted +2, infected +1 - from mcpgateway.plugins.framework.models import ToolPostInvokePayload + from mcpgateway.plugins.framework import ToolPostInvokePayload payload_t = ToolPostInvokePayload(name="t", result={"a": EICAR, "b": "clean"}) await plugin.tool_post_invoke(payload_t, ctx) diff --git a/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py b/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py index e58430b9b..82d809c4d 100644 --- a/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py +++ b/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py @@ -9,15 +9,15 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, + ResourceHookType, ResourcePreFetchPayload, ResourcePostFetchPayload, ) -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent from plugins.file_type_allowlist.file_type_allowlist import FileTypeAllowlistPlugin @@ -27,7 +27,7 @@ async def test_blocks_disallowed_extension_and_mime(): PluginConfig( name="fta", kind="plugins.file_type_allowlist.file_type_allowlist.FileTypeAllowlistPlugin", - hooks=[HookType.RESOURCE_PRE_FETCH, HookType.RESOURCE_POST_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH, ResourceHookType.RESOURCE_POST_FETCH], config={"allowed_extensions": [".md"], "allowed_mime_types": ["text/markdown"]}, ) ) diff --git a/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py b/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py index a25d54fd8..165ea9c67 100644 --- a/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py +++ b/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py @@ -9,14 +9,14 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, + ResourceHookType, ResourcePostFetchPayload, ) -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent from plugins.html_to_markdown.html_to_markdown import HTMLToMarkdownPlugin @@ -26,7 +26,7 @@ async def test_html_to_markdown_transforms_basic_html(): PluginConfig( name="html2md", kind="plugins.html_to_markdown.html_to_markdown.HTMLToMarkdownPlugin", - hooks=[HookType.RESOURCE_POST_FETCH], + hooks=[ResourceHookType.RESOURCE_POST_FETCH], ) ) html = "

Title

Hello link

print('x')
" diff --git a/tests/unit/mcpgateway/plugins/plugins/json_repair/test_json_repair.py b/tests/unit/mcpgateway/plugins/plugins/json_repair/test_json_repair.py index d6ca40917..07e089d24 100644 --- a/tests/unit/mcpgateway/plugins/plugins/json_repair/test_json_repair.py +++ b/tests/unit/mcpgateway/plugins/plugins/json_repair/test_json_repair.py @@ -10,11 +10,11 @@ import json import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, + ToolHookType, ToolPostInvokePayload, ) from plugins.json_repair.json_repair import JSONRepairPlugin @@ -26,7 +26,7 @@ async def test_repairs_trailing_commas_and_single_quotes(): PluginConfig( name="jsonr", kind="plugins.json_repair.json_repair.JSONRepairPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], ) ) ctx = PluginContext(global_context=GlobalContext(request_id="r1")) diff --git a/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py b/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py index e2b4c0df1..9f469f0ec 100644 --- a/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py +++ b/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py @@ -9,12 +9,12 @@ import pytest -from mcpgateway.models import Message, PromptResult, TextContent -from mcpgateway.plugins.framework.models import ( +from mcpgateway.common.models import Message, PromptResult, TextContent +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, + PromptHookType, PromptPosthookPayload, ) from plugins.markdown_cleaner.markdown_cleaner import MarkdownCleanerPlugin @@ -26,7 +26,7 @@ async def test_cleans_markdown_prompt(): PluginConfig( name="mdclean", kind="plugins.markdown_cleaner.markdown_cleaner.MarkdownCleanerPlugin", - hooks=[HookType.PROMPT_POST_FETCH], + hooks=[PromptHookType.PROMPT_POST_FETCH], ) ) txt = "#Heading\n\n\n* item\n\n```\n\n```\n" diff --git a/tests/unit/mcpgateway/plugins/plugins/output_length_guard/test_output_length_guard.py b/tests/unit/mcpgateway/plugins/plugins/output_length_guard/test_output_length_guard.py index 621d98cc9..37e0796e9 100644 --- a/tests/unit/mcpgateway/plugins/plugins/output_length_guard/test_output_length_guard.py +++ b/tests/unit/mcpgateway/plugins/plugins/output_length_guard/test_output_length_guard.py @@ -8,11 +8,11 @@ """ # First-Party -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, + ToolHookType, ToolPostInvokePayload, ) @@ -27,7 +27,7 @@ def _mk_plugin(config: dict | None = None) -> OutputLengthGuardPlugin: cfg = PluginConfig( name="out_len_guard", kind="plugins.output_length_guard.output_length_guard.OutputLengthGuardPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], priority=90, config=config or {}, ) diff --git a/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py b/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py index 23440ea33..bd4979abd 100644 --- a/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py +++ b/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py @@ -11,13 +11,13 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent -from mcpgateway.plugins.framework.models import ( +from mcpgateway.common.models import Message, PromptResult, Role, TextContent +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, PluginMode, + PromptHookType, PromptPosthookPayload, PromptPrehookPayload, ) @@ -229,7 +229,7 @@ def plugin_config(self) -> PluginConfig: author="Test", kind="plugins.pii_filter.pii_filter.PIIFilterPlugin", version="1.0", - hooks=[HookType.PROMPT_PRE_FETCH, HookType.PROMPT_POST_FETCH], + hooks=[PromptHookType.PROMPT_PRE_FETCH, PromptHookType.PROMPT_POST_FETCH], tags=["test", "pii"], mode=PluginMode.ENFORCE, priority=10, @@ -414,7 +414,7 @@ async def test_integration_with_manager(): payload = PromptPrehookPayload(prompt_id="test_prompt", args={"input": "Email: test@example.com, SSN: 123-45-6789"}) global_context = GlobalContext(request_id="test-manager") - result, contexts = await manager.prompt_pre_fetch(payload, global_context) + result, contexts = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) # Verify PII was masked assert result.modified_payload is not None diff --git a/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py b/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py index 4e1bad235..2ee6d0db3 100644 --- a/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py +++ b/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py @@ -9,12 +9,13 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, + PromptHookType, PromptPrehookPayload, + ToolHookType ) from plugins.rate_limiter.rate_limiter import RateLimiterPlugin @@ -24,7 +25,7 @@ def _mk(rate: str) -> RateLimiterPlugin: PluginConfig( name="rl", kind="plugins.rate_limiter.rate_limiter.RateLimiterPlugin", - hooks=[HookType.PROMPT_PRE_FETCH, HookType.TOOL_PRE_INVOKE], + hooks=[PromptHookType.PROMPT_PRE_FETCH, ToolHookType.TOOL_PRE_INVOKE], config={"by_user": rate}, ) ) diff --git a/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py b/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py index 08f12cf72..bbe2032f2 100644 --- a/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py +++ b/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py @@ -11,13 +11,13 @@ import pytest # First-Party -from mcpgateway.models import ResourceContent -from mcpgateway.plugins.framework.models import ( +from mcpgateway.common.models import ResourceContent +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, PluginMode, + ResourceHookType, ResourcePostFetchPayload, ResourcePreFetchPayload, ) @@ -36,7 +36,7 @@ def plugin_config(self): author="test", kind="plugins.resource_filter.resource_filter.ResourceFilterPlugin", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH, HookType.RESOURCE_POST_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH, ResourceHookType.RESOURCE_POST_FETCH], tags=["test", "filter"], mode=PluginMode.ENFORCE, config={ diff --git a/tests/unit/mcpgateway/plugins/plugins/schema_guard/test_schema_guard.py b/tests/unit/mcpgateway/plugins/plugins/schema_guard/test_schema_guard.py index 1f04cc08a..0dd8cf008 100644 --- a/tests/unit/mcpgateway/plugins/plugins/schema_guard/test_schema_guard.py +++ b/tests/unit/mcpgateway/plugins/plugins/schema_guard/test_schema_guard.py @@ -9,11 +9,11 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, + ToolHookType, ToolPreInvokePayload, ToolPostInvokePayload, ) @@ -37,7 +37,7 @@ async def test_schema_guard_valid_and_invalid(): PluginConfig( name="sg", kind="plugins.schema_guard.schema_guard.SchemaGuardPlugin", - hooks=[HookType.TOOL_PRE_INVOKE, HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_PRE_INVOKE, ToolHookType.TOOL_POST_INVOKE], config=cfg, ) ) diff --git a/tests/unit/mcpgateway/plugins/plugins/url_reputation/test_url_reputation.py b/tests/unit/mcpgateway/plugins/plugins/url_reputation/test_url_reputation.py index 649efe5e6..a8eb15a83 100644 --- a/tests/unit/mcpgateway/plugins/plugins/url_reputation/test_url_reputation.py +++ b/tests/unit/mcpgateway/plugins/plugins/url_reputation/test_url_reputation.py @@ -9,11 +9,11 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, + ResourceHookType, ResourcePreFetchPayload, ) from plugins.url_reputation.url_reputation import URLReputationPlugin @@ -25,7 +25,7 @@ async def test_blocks_blocklisted_domain(): PluginConfig( name="urlrep", kind="plugins.url_reputation.url_reputation.URLReputationPlugin", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], config={"blocked_domains": ["bad.example"]}, ) ) diff --git a/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py b/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py index 01eddc28a..2e9a04395 100644 --- a/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py +++ b/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py @@ -13,16 +13,18 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, + PromptHookType, + ResourceHookType, + ToolHookType, ResourcePreFetchPayload, ) from plugins.virus_total_checker.virus_total_checker import VirusTotalURLCheckerPlugin -from mcpgateway.models import Message, PromptResult, TextContent +from mcpgateway.common.models import Message, PromptResult, TextContent class _Resp: @@ -68,7 +70,7 @@ async def test_url_block_on_malicious(tmp_path, monkeypatch): cfg = PluginConfig( name="vt", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], config={ "enabled": True, "check_url": True, @@ -134,7 +136,7 @@ async def test_local_allow_and_deny_overrides(): cfg = PluginConfig( name="vt", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], config={ "enabled": True, "scan_tool_outputs": True, @@ -144,7 +146,7 @@ async def test_local_allow_and_deny_overrides(): plugin = VirusTotalURLCheckerPlugin(cfg) plugin._client_factory = lambda c, h: _StubClient(routes) # type: ignore os.environ["VT_API_KEY"] = "dummy" - from mcpgateway.plugins.framework.models import ToolPostInvokePayload + from mcpgateway.plugins.framework import ToolPostInvokePayload payload = ToolPostInvokePayload(name="writer", result=f"See {url}") ctx = PluginContext(global_context=GlobalContext(request_id="r7")) res = await plugin.tool_post_invoke(payload, ctx) @@ -155,7 +157,7 @@ async def test_local_allow_and_deny_overrides(): cfg2 = PluginConfig( name="vt2", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], config={ "enabled": True, "scan_tool_outputs": True, @@ -178,7 +180,7 @@ async def test_override_precedence_allow_over_deny_vs_deny_over_allow(): cfg_allow = PluginConfig( name="vt-allow", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], config={ "enabled": True, "scan_tool_outputs": True, @@ -190,7 +192,7 @@ async def test_override_precedence_allow_over_deny_vs_deny_over_allow(): plugin_allow = VirusTotalURLCheckerPlugin(cfg_allow) plugin_allow._client_factory = lambda c, h: _StubClient({}) # type: ignore os.environ["VT_API_KEY"] = "dummy" - from mcpgateway.plugins.framework.models import ToolPostInvokePayload + from mcpgateway.plugins.framework import ToolPostInvokePayload payload = ToolPostInvokePayload(name="writer", result=f"visit {url}") ctx = PluginContext(global_context=GlobalContext(request_id="r8")) res_allow = await plugin_allow.tool_post_invoke(payload, ctx) @@ -200,7 +202,7 @@ async def test_override_precedence_allow_over_deny_vs_deny_over_allow(): cfg_deny = PluginConfig( name="vt-deny", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], config={ "enabled": True, "scan_tool_outputs": True, @@ -221,7 +223,7 @@ async def test_prompt_scan_blocks_on_url(): cfg = PluginConfig( name="vt", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.PROMPT_POST_FETCH], + hooks=[PromptHookType.PROMPT_POST_FETCH], config={ "enabled": True, "scan_prompt_outputs": True, @@ -249,7 +251,7 @@ async def test_prompt_scan_blocks_on_url(): os.environ["VT_API_KEY"] = "dummy" pr = PromptResult(messages=[Message(role="assistant", content=TextContent(type="text", text=f"see {url}"))]) - from mcpgateway.plugins.framework.models import PromptPosthookPayload + from mcpgateway.plugins.framework import PromptPosthookPayload payload = PromptPosthookPayload(prompt_id="p", result=pr) ctx = PluginContext(global_context=GlobalContext(request_id="r5")) res = await plugin.prompt_post_fetch(payload, ctx) @@ -262,7 +264,7 @@ async def test_resource_scan_blocks_on_url(): cfg = PluginConfig( name="vt", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.RESOURCE_POST_FETCH], + hooks=[ResourceHookType.RESOURCE_POST_FETCH], config={ "enabled": True, "scan_resource_contents": True, @@ -289,9 +291,9 @@ async def test_resource_scan_blocks_on_url(): plugin._client_factory = lambda c, h: _StubClient(routes) # type: ignore os.environ["VT_API_KEY"] = "dummy" - from mcpgateway.models import ResourceContent + from mcpgateway.common.models import ResourceContent rc = ResourceContent(type="resource", id="345",uri="test://x", mime_type="text/plain", text=f"{url} is fishy") - from mcpgateway.plugins.framework.models import ResourcePostFetchPayload + from mcpgateway.plugins.framework import ResourcePostFetchPayload payload = ResourcePostFetchPayload(uri="test://x", content=rc) ctx = PluginContext(global_context=GlobalContext(request_id="r6")) res = await plugin.resource_post_fetch(payload, ctx) @@ -309,7 +311,7 @@ async def test_file_hash_lookup_blocks(tmp_path, monkeypatch): cfg = PluginConfig( name="vt", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], config={ "enabled": True, "enable_file_checks": True, @@ -353,7 +355,7 @@ async def test_unknown_file_then_upload_wait_allows_when_clean(tmp_path): cfg = PluginConfig( name="vt", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], config={ "enabled": True, "enable_file_checks": True, @@ -402,7 +404,7 @@ async def test_tool_output_url_block_and_ratio(): cfg = PluginConfig( name="vt", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], config={ "enabled": True, "scan_tool_outputs": True, @@ -433,7 +435,7 @@ async def test_tool_output_url_block_and_ratio(): plugin._client_factory = lambda c, h: _StubClient(routes) # type: ignore os.environ["VT_API_KEY"] = "dummy" - from mcpgateway.plugins.framework.models import ToolPostInvokePayload + from mcpgateway.plugins.framework import ToolPostInvokePayload payload = ToolPostInvokePayload(name="writer", result=f"See {url} for details") ctx = PluginContext(global_context=GlobalContext(request_id="r4")) diff --git a/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_integration.py b/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_integration.py index 6307f651a..22a353c19 100644 --- a/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_integration.py +++ b/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_integration.py @@ -14,10 +14,11 @@ import pytest from mcpgateway.plugins.framework.manager import PluginManager -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - ToolPostInvokePayload, - PluginViolation, + PromptHookType, + ToolHookType, + ToolPostInvokePayload ) @@ -81,7 +82,7 @@ async def test_webhook_plugin_with_manager(): ) # Execute tool post-invoke hook - result, final_context = await manager.tool_post_invoke(payload, context) + result, final_context = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, payload, context) # Verify result assert result.continue_processing is True @@ -164,14 +165,14 @@ async def test_webhook_plugin_violation_handling(): context = GlobalContext(request_id="violation-test", user="testuser") # Create payload with forbidden word that will trigger deny filter - from mcpgateway.plugins.framework.models import PromptPrehookPayload + from mcpgateway.plugins.framework import PromptPrehookPayload payload = PromptPrehookPayload( prompt_id="test_prompt", args={"query": "this contains forbidden word"} ) # Execute - should be blocked by deny filter - result, final_context = await manager.prompt_pre_fetch(payload, context) + result, final_context = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, context) # Verify the request was blocked assert result.continue_processing is False @@ -248,7 +249,7 @@ async def test_webhook_plugin_multiple_webhooks(): ) # Execute hook - result, final_context = await manager.tool_post_invoke(payload, context) + result, final_context = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, payload, context) assert result.continue_processing is True @@ -341,7 +342,7 @@ async def test_webhook_plugin_template_customization(): result={"data": "test"} ) - await manager.tool_post_invoke(payload, context) + await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, payload, context) # Verify webhook was called with custom template mock_client.post.assert_called_once() diff --git a/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_notification.py b/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_notification.py index 23319275a..a05c41b93 100644 --- a/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_notification.py +++ b/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_notification.py @@ -11,12 +11,12 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, PluginViolation, + ToolHookType, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload, @@ -53,7 +53,7 @@ def _create_plugin(config_dict=None) -> WebhookNotificationPlugin: PluginConfig( name="webhook_test", kind="plugins.webhook_notification.webhook_notification.WebhookNotificationPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], config=default_config, ) ) @@ -463,7 +463,8 @@ async def test_prompt_pre_and_post_hooks_return_success(self): # Test post-hook with mock notification plugin._notify_webhooks = AsyncMock() - from mcpgateway.plugins.framework.models import PromptPosthookPayload, PromptResult + from mcpgateway.plugins.framework import PromptPosthookPayload + from mcpgateway.common.models import PromptResult post_payload = PromptPosthookPayload( prompt_id="test_prompt", result=PromptResult(messages=[]) diff --git a/tests/unit/mcpgateway/services/test_completion_service.py b/tests/unit/mcpgateway/services/test_completion_service.py index e7fe866e2..f46a65d1a 100644 --- a/tests/unit/mcpgateway/services/test_completion_service.py +++ b/tests/unit/mcpgateway/services/test_completion_service.py @@ -9,7 +9,7 @@ import pytest # First-Party -from mcpgateway.models import ( +from mcpgateway.common.models import ( CompleteResult, ) from mcpgateway.services.completion_service import ( diff --git a/tests/unit/mcpgateway/services/test_export_service.py b/tests/unit/mcpgateway/services/test_export_service.py index b01889b26..48d0d3652 100644 --- a/tests/unit/mcpgateway/services/test_export_service.py +++ b/tests/unit/mcpgateway/services/test_export_service.py @@ -15,7 +15,7 @@ import pytest # First-Party -from mcpgateway.models import Root +from mcpgateway.common.models import Root from mcpgateway.schemas import GatewayRead, PromptMetrics, PromptRead, ResourceMetrics, ResourceRead, ServerMetrics, ServerRead, ToolMetrics, ToolRead from mcpgateway.services.export_service import ExportError, ExportService, ExportValidationError from mcpgateway.utils.services_auth import encode_auth @@ -971,7 +971,7 @@ async def test_export_selective_all_entity_types(export_service, mock_db): export_service.resource_service.list_resources.return_value = ([sample_resource], None) # First-Party - from mcpgateway.models import Root + from mcpgateway.common.models import Root mock_roots = [Root(uri="file:///workspace", name="Workspace")] export_service.root_service.list_roots.return_value = mock_roots diff --git a/tests/unit/mcpgateway/services/test_log_storage_service.py b/tests/unit/mcpgateway/services/test_log_storage_service.py index 15c1742be..414e02ebc 100644 --- a/tests/unit/mcpgateway/services/test_log_storage_service.py +++ b/tests/unit/mcpgateway/services/test_log_storage_service.py @@ -16,7 +16,7 @@ import pytest # First-Party -from mcpgateway.models import LogLevel +from mcpgateway.common.models import LogLevel from mcpgateway.services.log_storage_service import LogEntry, LogStorageService diff --git a/tests/unit/mcpgateway/services/test_logging_service.py b/tests/unit/mcpgateway/services/test_logging_service.py index e8ae79b27..933852577 100644 --- a/tests/unit/mcpgateway/services/test_logging_service.py +++ b/tests/unit/mcpgateway/services/test_logging_service.py @@ -26,7 +26,7 @@ import pytest # First-Party -from mcpgateway.models import LogLevel +from mcpgateway.common.models import LogLevel from mcpgateway.services.logging_service import LoggingService # --------------------------------------------------------------------------- diff --git a/tests/unit/mcpgateway/services/test_logging_service_comprehensive.py b/tests/unit/mcpgateway/services/test_logging_service_comprehensive.py index 34bd98d67..5962b487a 100644 --- a/tests/unit/mcpgateway/services/test_logging_service_comprehensive.py +++ b/tests/unit/mcpgateway/services/test_logging_service_comprehensive.py @@ -18,7 +18,7 @@ import pytest # First-Party -from mcpgateway.models import LogLevel +from mcpgateway.common.models import LogLevel from mcpgateway.services.logging_service import _get_file_handler, _get_text_handler, LoggingService # --------------------------------------------------------------------------- diff --git a/tests/unit/mcpgateway/services/test_prompt_service.py b/tests/unit/mcpgateway/services/test_prompt_service.py index 56d8e38c9..a69650fc1 100644 --- a/tests/unit/mcpgateway/services/test_prompt_service.py +++ b/tests/unit/mcpgateway/services/test_prompt_service.py @@ -29,7 +29,7 @@ # First-Party from mcpgateway.db import Prompt as DbPrompt from mcpgateway.db import PromptMetric -from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate from mcpgateway.services.prompt_service import ( diff --git a/tests/unit/mcpgateway/services/test_resource_service_plugins.py b/tests/unit/mcpgateway/services/test_resource_service_plugins.py index 05c966816..fd3fdf513 100644 --- a/tests/unit/mcpgateway/services/test_resource_service_plugins.py +++ b/tests/unit/mcpgateway/services/test_resource_service_plugins.py @@ -16,7 +16,7 @@ from sqlalchemy.orm import Session # First-Party -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent from mcpgateway.services.resource_service import ResourceNotFoundError, ResourceService from mcpgateway.plugins.framework import PluginError, PluginErrorModel, PluginViolation, PluginViolationError @@ -39,11 +39,21 @@ def resource_service(self): @pytest.fixture def resource_service_with_plugins(self): """Create a ResourceService instance with plugins enabled.""" + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + with patch.dict(os.environ, {"PLUGINS_ENABLED": "true", "PLUGIN_CONFIG_FILE": "test_config.yaml"}): with patch("mcpgateway.services.resource_service.PluginManager") as MockPluginManager: mock_manager = MagicMock() mock_manager._initialized = False mock_manager.initialize = AsyncMock() + # Add default invoke_hook mock that returns success + mock_manager.invoke_hook = AsyncMock( + return_value=( + PluginResult(continue_processing=True, modified_payload=None), + None # contexts + ) + ) MockPluginManager.return_value = mock_manager service = ResourceService() service._plugin_manager = mock_manager @@ -70,6 +80,9 @@ async def test_read_resource_without_plugins(self, resource_service, mock_db): @pytest.mark.asyncio async def test_read_resource_with_pre_fetch_hook(self, resource_service_with_plugins, mock_db): """Test read_resource with pre-fetch hook execution.""" + # First-Party + from mcpgateway.plugins.framework import ResourceHookType + import mcpgateway.services.resource_service as resource_service_mod resource_service_mod.PLUGINS_AVAILABLE = True service = resource_service_with_plugins @@ -87,33 +100,6 @@ async def test_read_resource_with_pre_fetch_hook(self, resource_service_with_plu mock_db.execute.return_value.scalar_one_or_none.return_value = mock_resource mock_db.get.return_value = mock_resource # Ensure resource_db is not None - # Setup pre-fetch hook response - mock_manager.resource_pre_fetch = AsyncMock( - return_value=( - MagicMock( - continue_processing=True, - modified_payload=None, - violation=None, - ), - {"context": "data"}, # contexts - ) - ) - - # Setup post-fetch hook response - mock_manager.resource_post_fetch = AsyncMock( - return_value=( - MagicMock( - continue_processing=True, - modified_payload=None, - ), - None, - ) - ) - - # Explicitly call initialize if not already called - if hasattr(mock_manager.initialize, 'await_count') and mock_manager.initialize.await_count == 0: - await mock_manager.initialize() - result = await service.read_resource( mock_db, "test://resource", @@ -123,14 +109,14 @@ async def test_read_resource_with_pre_fetch_hook(self, resource_service_with_plu # Verify hooks were called mock_manager.initialize.assert_called() - mock_manager.resource_pre_fetch.assert_called_once() - mock_manager.resource_post_fetch.assert_called_once() + assert mock_manager.invoke_hook.call_count >= 2 # Pre and post fetch - # Verify context was passed correctly - call_args = mock_manager.resource_pre_fetch.call_args - assert call_args[0][0].uri == "test://resource" # payload - assert call_args[0][1].request_id == "test-123" # global_context - assert call_args[0][1].user == "testuser" + # Verify context was passed correctly - check first call (pre-fetch) + first_call = mock_manager.invoke_hook.call_args_list[0] + assert first_call[0][0] == ResourceHookType.RESOURCE_PRE_FETCH # hook_type + assert first_call[0][1].uri == "test://resource" # payload + assert first_call[0][2].request_id == "test-123" # global_context + assert first_call[0][2].user == "testuser" @pytest.mark.asyncio async def test_read_resource_blocked_by_plugin(self, resource_service_with_plugins, mock_db): @@ -152,8 +138,8 @@ async def test_read_resource_blocked_by_plugin(self, resource_service_with_plugi mock_db.execute.return_value.scalar_one_or_none.return_value = mock_resource mock_db.get.return_value = mock_resource # Ensure resource_db is not None - # Setup pre-fetch hook to block - mock_manager.resource_pre_fetch = AsyncMock( + # Setup invoke_hook to raise PluginViolationError + mock_manager.invoke_hook = AsyncMock( side_effect=PluginViolationError(message="Protocol not allowed", violation=PluginViolation( reason="Protocol not allowed", @@ -168,13 +154,15 @@ async def test_read_resource_blocked_by_plugin(self, resource_service_with_plugi await service.read_resource(mock_db, "file:///etc/passwd") assert "Protocol not allowed" in str(exc_info.value) - mock_manager.resource_pre_fetch.assert_called_once() - # Post-fetch should not be called if pre-fetch blocks - mock_manager.resource_post_fetch.assert_not_called() + mock_manager.invoke_hook.assert_called() @pytest.mark.asyncio async def test_read_resource_uri_modified_by_plugin(self, resource_service_with_plugins, mock_db): """Test read_resource with URI modification by plugin.""" + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.framework import ResourceHookType + service = resource_service_with_plugins mock_manager = service._plugin_manager @@ -193,26 +181,27 @@ async def test_read_resource_uri_modified_by_plugin(self, resource_service_with_ # Setup pre-fetch hook to modify URI modified_payload = MagicMock() modified_payload.uri = "cached://test://resource" - mock_manager.resource_pre_fetch = AsyncMock( - return_value=( - MagicMock( - continue_processing=True, - modified_payload=modified_payload, - ), - {"context": "data"}, - ) - ) - # Setup post-fetch hook - mock_manager.resource_post_fetch = AsyncMock( - return_value=( - MagicMock( + # Use side_effect to return different results based on hook type + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == ResourceHookType.RESOURCE_PRE_FETCH: + return ( + PluginResult( + continue_processing=True, + modified_payload=modified_payload, + ), + {"context": "data"}, + ) + # POST_FETCH + return ( + PluginResult( continue_processing=True, modified_payload=None, ), None, ) - ) + + mock_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) result = await service.read_resource(mock_db, "test://resource") @@ -223,6 +212,10 @@ async def test_read_resource_uri_modified_by_plugin(self, resource_service_with_ @pytest.mark.asyncio async def test_read_resource_content_filtered_by_plugin(self, resource_service_with_plugins, mock_db): """Test read_resource with content filtering by post-fetch hook.""" + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.framework import ResourceHookType + import mcpgateway.services.resource_service as resource_service_mod resource_service_mod.PLUGINS_AVAILABLE = True service = resource_service_with_plugins @@ -244,14 +237,6 @@ def scalar_one_or_none_side_effect(*args, **kwargs): mock_db.execute.return_value.scalar_one_or_none.side_effect = scalar_one_or_none_side_effect mock_db.get.return_value = mock_resource - # Setup pre-fetch hook - mock_manager.resource_pre_fetch = AsyncMock( - return_value=( - MagicMock(continue_processing=True), - {"context": "data"}, - ) - ) - # Setup post-fetch hook to filter content filtered_content = ResourceContent( type="resource", @@ -260,17 +245,26 @@ def scalar_one_or_none_side_effect(*args, **kwargs): text="password: [REDACTED]\napi_key: [REDACTED]", ) resource_id = filtered_content.id - modified_payload = MagicMock() - modified_payload.content = filtered_content - mock_manager.resource_post_fetch = AsyncMock( - return_value=( - MagicMock( + modified_post_payload = MagicMock() + modified_post_payload.content = filtered_content + + # Use side_effect to return different results based on hook type + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == ResourceHookType.RESOURCE_PRE_FETCH: + return ( + PluginResult(continue_processing=True), + {"context": "data"}, + ) + # POST_FETCH + return ( + PluginResult( continue_processing=True, - modified_payload=modified_payload, + modified_payload=modified_post_payload, ), None, ) - ) + + mock_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) result = await service.read_resource(mock_db, resource_id) @@ -303,17 +297,21 @@ async def test_read_resource_plugin_error_handling(self, resource_service_with_p mock_db.get.return_value = mock_resource # Ensure resource_db is not None # Setup pre-fetch hook to raise an error - mock_manager.resource_pre_fetch = AsyncMock(side_effect=PluginError(error=PluginErrorModel(message="Plugin error", plugin_name="mock_plugin"))) + mock_manager.invoke_hook = AsyncMock(side_effect=PluginError(error=PluginErrorModel(message="Plugin error", plugin_name="mock_plugin"))) with pytest.raises(PluginError) as exc_info: await service.read_resource(mock_db, resource_id) - mock_manager.resource_pre_fetch.assert_called_once() + mock_manager.invoke_hook.assert_called_once() @pytest.mark.asyncio async def test_read_resource_post_fetch_blocking(self, resource_service_with_plugins, mock_db): """Test read_resource blocked by post-fetch hook.""" + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.framework import ResourceHookType + import mcpgateway.services.resource_service as resource_service_mod resource_service_mod.PLUGINS_AVAILABLE = True service = resource_service_with_plugins @@ -331,30 +329,32 @@ async def test_read_resource_post_fetch_blocking(self, resource_service_with_plu mock_db.execute.return_value.scalar_one_or_none.return_value = mock_resource mock_db.get.return_value = mock_resource # Ensure resource_db is not None - # Setup pre-fetch hook - mock_manager.resource_pre_fetch = AsyncMock( - return_value=( - MagicMock(continue_processing=True), - {"context": "data"}, + # Use side_effect to allow pre-fetch but block on post-fetch + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == ResourceHookType.RESOURCE_PRE_FETCH: + return ( + PluginResult(continue_processing=True), + {"context": "data"}, + ) + # POST_FETCH - raise error + raise PluginViolationError( + message="Content contains sensitive data", + violation=PluginViolation( + reason="Content contains sensitive data", + description="The resource content was flagged as containing sensitive information", + code="SENSITIVE_CONTENT", + details={"uri": "test://resource"} + ) ) - ) - # Setup post-fetch hook to block - mock_manager.resource_post_fetch = AsyncMock( - side_effect=PluginViolationError(message="Content contains sensitive data", - violation=PluginViolation( - reason="Content contains sensitive data", - description="The resource content was flagged as containing sensitive information", - code="SENSITIVE_CONTENT", - details={"uri": "test://resource"} - )) - ) + mock_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) with pytest.raises(PluginViolationError) as exc_info: await service.read_resource(mock_db, "test://resource") assert "Content contains sensitive data" in str(exc_info.value) - mock_manager.resource_post_fetch.assert_called_once() + # Verify invoke_hook was called at least twice (pre and post) + assert mock_manager.invoke_hook.call_count == 2 @pytest.mark.asyncio async def test_read_resource_with_template(self, resource_service_with_plugins, mock_db): @@ -377,32 +377,23 @@ async def test_read_resource_with_template(self, resource_service_with_plugins, mock_db.execute.return_value.scalar_one_or_none.return_value = mock_resource mock_db.get.return_value = mock_resource # Ensure resource_db is not None - # Setup hooks - mock_manager.resource_pre_fetch = AsyncMock( - return_value=( - MagicMock(continue_processing=True), - {"context": "data"}, - ) - ) - # Create a mock result with modified_payload explicitly set to None - mock_post_result = MagicMock() - mock_post_result.continue_processing = True - mock_post_result.modified_payload = None - - mock_manager.resource_post_fetch = AsyncMock( - return_value=(mock_post_result, None) - ) + # The default invoke_hook from fixture will work fine for this test + # since it just returns success with no modifications # Use the correct resource id for lookup result = await service.read_resource(mock_db, mock_resource.uri) assert result == mock_template_content - mock_manager.resource_pre_fetch.assert_called_once() - mock_manager.resource_post_fetch.assert_called_once() + # Verify hooks were called + assert mock_manager.invoke_hook.call_count >= 2 # Pre and post fetch @pytest.mark.asyncio async def test_read_resource_context_propagation(self, resource_service_with_plugins, mock_db): """Test context propagation from pre-fetch to post-fetch.""" + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.framework import ResourceHookType + import mcpgateway.services.resource_service as resource_service_mod resource_service_mod.PLUGINS_AVAILABLE = True service = resource_service_with_plugins @@ -422,28 +413,31 @@ async def test_read_resource_context_propagation(self, resource_service_with_plu # Capture contexts from pre-fetch test_contexts = {"plugin1": {"validated": True}} - mock_manager.resource_pre_fetch = AsyncMock( - return_value=( - MagicMock(continue_processing=True), - test_contexts, - ) - ) - # Verify contexts passed to post-fetch - mock_manager.resource_post_fetch = AsyncMock( - return_value=( - MagicMock(continue_processing=True), + # Use side_effect to return contexts from pre-fetch + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == ResourceHookType.RESOURCE_PRE_FETCH: + return ( + PluginResult(continue_processing=True), + test_contexts, + ) + # POST_FETCH + return ( + PluginResult(continue_processing=True), None, ) - ) + + mock_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) # The resource id must match the lookup for plugin logic to trigger await service.read_resource(mock_db, mock_resource.content.id) # Verify contexts were passed from pre to post - post_call_args = mock_manager.resource_post_fetch.call_args - assert post_call_args is not None, "resource_post_fetch was not called" - assert post_call_args[0][2] == test_contexts # Third argument is contexts + assert mock_manager.invoke_hook.call_count == 2 + # Check second call (post-fetch) to verify contexts were passed + post_call_args = mock_manager.invoke_hook.call_args_list[1] + # The contexts dict should be passed as the 4th positional arg (local_contexts) + assert post_call_args[0][3] == test_contexts # Fourth argument is local_contexts @pytest.mark.asyncio async def test_read_resource_inactive_resource(self, resource_service, mock_db): @@ -496,19 +490,13 @@ async def test_read_resource_no_request_id(self, resource_service_with_plugins, mock_db.execute.return_value.scalar_one_or_none.return_value = mock_resource mock_db.get.return_value = mock_resource # Ensure resource_db is not None - # Setup hooks - mock_manager.resource_pre_fetch = AsyncMock( - return_value=(MagicMock(continue_processing=True), None) - ) - mock_manager.resource_post_fetch = AsyncMock( - return_value=(MagicMock(continue_processing=True), None) - ) + # The default invoke_hook from fixture will work fine await service.read_resource(mock_db, "test://resource") - # Verify request_id was generated - call_args = mock_manager.resource_pre_fetch.call_args - assert call_args is not None, "resource_pre_fetch was not called" - global_context = call_args[0][1] + # Verify request_id was generated - check first call (pre-fetch) + assert mock_manager.invoke_hook.call_count >= 1, "invoke_hook was not called" + first_call = mock_manager.invoke_hook.call_args_list[0] + global_context = first_call[0][2] # Third positional arg is global_context assert global_context.request_id is not None assert len(global_context.request_id) > 0 diff --git a/tests/unit/mcpgateway/services/test_tool_service.py b/tests/unit/mcpgateway/services/test_tool_service.py index 00887bc46..dec3854b7 100644 --- a/tests/unit/mcpgateway/services/test_tool_service.py +++ b/tests/unit/mcpgateway/services/test_tool_service.py @@ -2234,6 +2234,10 @@ def mock_passthrough(req_headers, tool_headers, db_session, gateway=None): async def test_invoke_tool_with_plugin_post_invoke_success(self, tool_service, mock_tool, test_db): """Test invoking tool with successful plugin post-invoke hook.""" + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.framework import ToolHookType + # Configure tool as REST mock_tool.integration_type = "REST" mock_tool.request_type = "POST" @@ -2251,15 +2255,21 @@ async def test_invoke_tool_with_plugin_post_invoke_success(self, tool_service, m mock_response.json = Mock(return_value={"result": "original response"}) tool_service._http_client.request.return_value = mock_response - # Mock plugin manager and post-invoke hook + # Mock plugin manager with invoke_hook mock_post_result = Mock() mock_post_result.continue_processing = True mock_post_result.violation = None mock_post_result.modified_payload = None tool_service._plugin_manager = Mock() - tool_service._plugin_manager.tool_pre_invoke = AsyncMock(return_value=(Mock(continue_processing=True, violation=None, modified_payload=None), None)) - tool_service._plugin_manager.tool_post_invoke = AsyncMock(return_value=(mock_post_result, None)) + + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == ToolHookType.TOOL_PRE_INVOKE: + return (PluginResult(continue_processing=True, violation=None, modified_payload=None), None) + # POST_INVOKE + return (mock_post_result, None) + + tool_service._plugin_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) with ( patch("mcpgateway.services.tool_service.decode_auth", return_value={}), @@ -2267,8 +2277,8 @@ async def test_invoke_tool_with_plugin_post_invoke_success(self, tool_service, m ): result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) - # Verify plugin post-invoke was called - tool_service._plugin_manager.tool_post_invoke.assert_called_once() + # Verify plugin hooks were called + assert tool_service._plugin_manager.invoke_hook.call_count == 2 # Pre and post invoke # Verify result assert result.content[0].text == '{\n "result": "original response"\n}' @@ -2301,9 +2311,18 @@ async def test_invoke_tool_with_plugin_post_invoke_modified_payload(self, tool_s mock_post_result.violation = None mock_post_result.modified_payload = mock_modified_payload + # First-Party + from mcpgateway.plugins.framework import PluginResult, ToolHookType + tool_service._plugin_manager = Mock() - tool_service._plugin_manager.tool_pre_invoke = AsyncMock(return_value=(Mock(continue_processing=True, violation=None, modified_payload=None), None)) - tool_service._plugin_manager.tool_post_invoke = AsyncMock(return_value=(mock_post_result, None)) + + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == ToolHookType.TOOL_PRE_INVOKE: + return (PluginResult(continue_processing=True, violation=None, modified_payload=None), None) + # POST_INVOKE + return (mock_post_result, None) + + tool_service._plugin_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) with ( patch("mcpgateway.services.tool_service.decode_auth", return_value={}), @@ -2311,8 +2330,8 @@ async def test_invoke_tool_with_plugin_post_invoke_modified_payload(self, tool_s ): result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) - # Verify plugin post-invoke was called - tool_service._plugin_manager.tool_post_invoke.assert_called_once() + # Verify plugin hooks were called + assert tool_service._plugin_manager.invoke_hook.call_count == 2 # Pre and post invoke # Verify result was modified by plugin assert result.content[0].text == "Modified by plugin" @@ -2345,9 +2364,19 @@ async def test_invoke_tool_with_plugin_post_invoke_invalid_modified_payload(self mock_post_result.violation = None mock_post_result.modified_payload = mock_modified_payload + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.framework import ToolHookType + tool_service._plugin_manager = Mock() - tool_service._plugin_manager.tool_pre_invoke = AsyncMock(return_value=(Mock(continue_processing=True, violation=None, modified_payload=None), None)) - tool_service._plugin_manager.tool_post_invoke = AsyncMock(return_value=(mock_post_result, None)) + + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == ToolHookType.TOOL_PRE_INVOKE: + return (PluginResult(continue_processing=True, violation=None, modified_payload=None), None) + # POST_INVOKE + return (mock_post_result, None) + + tool_service._plugin_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) with ( patch("mcpgateway.services.tool_service.decode_auth", return_value={}), @@ -2355,8 +2384,8 @@ async def test_invoke_tool_with_plugin_post_invoke_invalid_modified_payload(self ): result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) - # Verify plugin post-invoke was called - tool_service._plugin_manager.tool_post_invoke.assert_called_once() + # Verify plugin hooks were called + assert tool_service._plugin_manager.invoke_hook.call_count == 2 # Pre and post invoke # Verify result was converted to string since format was invalid assert result.content[0].text == "Invalid format - not a dict" @@ -2380,10 +2409,20 @@ async def test_invoke_tool_with_plugin_post_invoke_error_fail_on_error(self, too mock_response.json = Mock(return_value={"result": "original response"}) tool_service._http_client.request.return_value = mock_response - # Mock plugin manager and post-invoke hook with error + # Mock plugin manager with invoke_hook that raises error on POST_INVOKE + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.framework import ToolHookType + tool_service._plugin_manager = Mock() - tool_service._plugin_manager.tool_pre_invoke = AsyncMock(return_value=(Mock(continue_processing=True, violation=None, modified_payload=None), None)) - tool_service._plugin_manager.tool_post_invoke = AsyncMock(side_effect=Exception("Plugin error")) + + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == ToolHookType.TOOL_PRE_INVOKE: + return (PluginResult(continue_processing=True, violation=None, modified_payload=None), None) + # POST_INVOKE - raise error + raise Exception("Plugin error") + + tool_service._plugin_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) # Mock plugin config to fail on errors mock_plugin_settings = Mock() diff --git a/tests/unit/mcpgateway/test_discovery.py b/tests/unit/mcpgateway/test_discovery.py index 188360081..398e9f7f4 100644 --- a/tests/unit/mcpgateway/test_discovery.py +++ b/tests/unit/mcpgateway/test_discovery.py @@ -37,7 +37,7 @@ async def discovery(): async def _fake_gateway_info(url: str): # noqa: D401, ANN001 # Return an *empty* capabilities object - structure is unimportant here. # First-Party - from mcpgateway.models import ServerCapabilities + from mcpgateway.common.models import ServerCapabilities return ServerCapabilities() diff --git a/tests/unit/mcpgateway/test_final_coverage_push.py b/tests/unit/mcpgateway/test_final_coverage_push.py index 4620ed14c..7b970137a 100644 --- a/tests/unit/mcpgateway/test_final_coverage_push.py +++ b/tests/unit/mcpgateway/test_final_coverage_push.py @@ -16,7 +16,7 @@ import pytest # First-Party -from mcpgateway.models import ImageContent, LogLevel, ResourceContent, Role, TextContent +from mcpgateway.common.models import ImageContent, LogLevel, ResourceContent, Role, TextContent from mcpgateway.schemas import BaseModelWithConfigDict diff --git a/tests/unit/mcpgateway/test_main.py b/tests/unit/mcpgateway/test_main.py index 0dd2bddb6..3a6d8a295 100644 --- a/tests/unit/mcpgateway/test_main.py +++ b/tests/unit/mcpgateway/test_main.py @@ -24,7 +24,7 @@ # First-Party from mcpgateway.config import settings -from mcpgateway.models import InitializeResult, ResourceContent, ServerCapabilities +from mcpgateway.common.models import InitializeResult, ResourceContent, ServerCapabilities from mcpgateway.schemas import ( PromptRead, ResourceRead, @@ -1034,7 +1034,7 @@ class TestRootEndpoints: def test_list_roots_endpoint(self, mock_list, test_client, auth_headers): """Test listing all registered roots.""" # First-Party - from mcpgateway.models import Root + from mcpgateway.common.models import Root mock_list.return_value = [Root(uri="file:///test", name="Test Root")] # valid URI response = test_client.get("/roots/", headers=auth_headers) @@ -1048,7 +1048,7 @@ def test_list_roots_endpoint(self, mock_list, test_client, auth_headers): def test_add_root_endpoint(self, mock_add, test_client, auth_headers): """Test adding a new root directory.""" # First-Party - from mcpgateway.models import Root + from mcpgateway.common.models import Root mock_add.return_value = Root(uri="file:///test", name="Test Root") # valid URI diff --git a/tests/unit/mcpgateway/test_models.py b/tests/unit/mcpgateway/test_models.py index f0279d6a4..50f6989f0 100644 --- a/tests/unit/mcpgateway/test_models.py +++ b/tests/unit/mcpgateway/test_models.py @@ -18,7 +18,7 @@ import pytest # First-Party -from mcpgateway.models import ( +from mcpgateway.common.models import ( ClientCapabilities, CreateMessageResult, ImageContent, diff --git a/tests/unit/mcpgateway/test_rpc_tool_invocation.py b/tests/unit/mcpgateway/test_rpc_tool_invocation.py index 59378c518..1834f6e7b 100644 --- a/tests/unit/mcpgateway/test_rpc_tool_invocation.py +++ b/tests/unit/mcpgateway/test_rpc_tool_invocation.py @@ -17,7 +17,7 @@ # First-Party from mcpgateway.main import app -from mcpgateway.models import Tool +from mcpgateway.common.models import Tool from mcpgateway.services.tool_service import ToolService diff --git a/tests/unit/mcpgateway/test_schemas.py b/tests/unit/mcpgateway/test_schemas.py index bd357c781..1d782d44e 100644 --- a/tests/unit/mcpgateway/test_schemas.py +++ b/tests/unit/mcpgateway/test_schemas.py @@ -20,7 +20,7 @@ import pytest # First-Party -from mcpgateway.models import ( +from mcpgateway.common.models import ( ClientCapabilities, CreateMessageResult, ImageContent, diff --git a/tests/unit/mcpgateway/validation/test_validators.py b/tests/unit/mcpgateway/validation/test_validators.py index ccb574db5..e2f930026 100644 --- a/tests/unit/mcpgateway/validation/test_validators.py +++ b/tests/unit/mcpgateway/validation/test_validators.py @@ -15,7 +15,7 @@ import pytest # First-Party -from mcpgateway.validators import SecurityValidator +from mcpgateway.common.validators import SecurityValidator class DummySettings: @@ -48,7 +48,7 @@ def logfn(*args, **kwargs): return logfn - monkeypatch.setattr("mcpgateway.validators.logger", DummyLogger()) + monkeypatch.setattr("mcpgateway.common.validators.logger", DummyLogger()) yield logs diff --git a/tests/unit/mcpgateway/validation/test_validators_advanced.py b/tests/unit/mcpgateway/validation/test_validators_advanced.py index 82eaf75f6..e29830c45 100644 --- a/tests/unit/mcpgateway/validation/test_validators_advanced.py +++ b/tests/unit/mcpgateway/validation/test_validators_advanced.py @@ -27,7 +27,7 @@ import pytest # First-Party -from mcpgateway.validators import SecurityValidator +from mcpgateway.common.validators import SecurityValidator class DummySettings: @@ -84,7 +84,7 @@ def logfn(*args, **kwargs): return logfn - monkeypatch.setattr("mcpgateway.validators.logger", DummyLogger()) + monkeypatch.setattr("mcpgateway.common.validators.logger", DummyLogger()) yield logs diff --git a/uv.lock b/uv.lock index 89e4f8a8e..a361a4d4a 100644 --- a/uv.lock +++ b/uv.lock @@ -4651,8 +4651,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/89/3fdb5902bdab8868bbedc1c6e6023a4e08112ceac5db97fc2012060e0c9a/psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2e164359396576a3cc701ba8af4751ae68a07235d7a380c631184a611220d9a4", size = 4410955, upload-time = "2025-10-10T11:11:21.21Z" }, { url = "https://files.pythonhosted.org/packages/ce/24/e18339c407a13c72b336e0d9013fbbbde77b6fd13e853979019a1269519c/psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:d57c9c387660b8893093459738b6abddbb30a7eab058b77b0d0d1c7d521ddfd7", size = 4468007, upload-time = "2025-10-10T11:11:24.831Z" }, { url = "https://files.pythonhosted.org/packages/91/7e/b8441e831a0f16c159b5381698f9f7f7ed54b77d57bc9c5f99144cc78232/psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2c226ef95eb2250974bf6fa7a842082b31f68385c4f3268370e3f3870e7859ee", size = 4165012, upload-time = "2025-10-10T11:11:29.51Z" }, + { url = "https://files.pythonhosted.org/packages/0d/61/4aa89eeb6d751f05178a13da95516c036e27468c5d4d2509bb1e15341c81/psycopg2_binary-2.9.11-cp311-cp311-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a311f1edc9967723d3511ea7d2708e2c3592e3405677bf53d5c7246753591fbb", size = 3981881, upload-time = "2025-10-30T02:55:07.332Z" }, { url = "https://files.pythonhosted.org/packages/76/a1/2f5841cae4c635a9459fe7aca8ed771336e9383b6429e05c01267b0774cf/psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ebb415404821b6d1c47353ebe9c8645967a5235e6d88f914147e7fd411419e6f", size = 3650985, upload-time = "2025-10-10T11:11:34.975Z" }, { url = "https://files.pythonhosted.org/packages/84/74/4defcac9d002bca5709951b975173c8c2fa968e1a95dc713f61b3a8d3b6a/psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f07c9c4a5093258a03b28fab9b4f151aa376989e7f35f855088234e656ee6a94", size = 3296039, upload-time = "2025-10-10T11:11:40.432Z" }, + { url = "https://files.pythonhosted.org/packages/6d/c2/782a3c64403d8ce35b5c50e1b684412cf94f171dc18111be8c976abd2de1/psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:00ce1830d971f43b667abe4a56e42c1e2d594b32da4802e44a73bacacb25535f", size = 3043477, upload-time = "2025-10-30T02:55:11.182Z" }, { url = "https://files.pythonhosted.org/packages/c8/31/36a1d8e702aa35c38fc117c2b8be3f182613faa25d794b8aeaab948d4c03/psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cffe9d7697ae7456649617e8bb8d7a45afb71cd13f7ab22af3e5c61f04840908", size = 3345842, upload-time = "2025-10-10T11:11:45.366Z" }, { url = "https://files.pythonhosted.org/packages/6e/b4/a5375cda5b54cb95ee9b836930fea30ae5a8f14aa97da7821722323d979b/psycopg2_binary-2.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:304fd7b7f97eef30e91b8f7e720b3db75fee010b520e434ea35ed1ff22501d03", size = 2713894, upload-time = "2025-10-10T11:11:48.775Z" }, { url = "https://files.pythonhosted.org/packages/d8/91/f870a02f51be4a65987b45a7de4c2e1897dd0d01051e2b559a38fa634e3e/psycopg2_binary-2.9.11-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:be9b840ac0525a283a96b556616f5b4820e0526addb8dcf6525a0fa162730be4", size = 3756603, upload-time = "2025-10-10T11:11:52.213Z" }, @@ -4660,8 +4662,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2d/75/364847b879eb630b3ac8293798e380e441a957c53657995053c5ec39a316/psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ab8905b5dcb05bf3fb22e0cf90e10f469563486ffb6a96569e51f897c750a76a", size = 4411159, upload-time = "2025-10-10T11:12:00.49Z" }, { url = "https://files.pythonhosted.org/packages/6f/a0/567f7ea38b6e1c62aafd58375665a547c00c608a471620c0edc364733e13/psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:bf940cd7e7fec19181fdbc29d76911741153d51cab52e5c21165f3262125685e", size = 4468234, upload-time = "2025-10-10T11:12:04.892Z" }, { url = "https://files.pythonhosted.org/packages/30/da/4e42788fb811bbbfd7b7f045570c062f49e350e1d1f3df056c3fb5763353/psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fa0f693d3c68ae925966f0b14b8edda71696608039f4ed61b1fe9ffa468d16db", size = 4166236, upload-time = "2025-10-10T11:12:11.674Z" }, + { url = "https://files.pythonhosted.org/packages/3c/94/c1777c355bc560992af848d98216148be5f1be001af06e06fc49cbded578/psycopg2_binary-2.9.11-cp312-cp312-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a1cf393f1cdaf6a9b57c0a719a1068ba1069f022a59b8b1fe44b006745b59757", size = 3983083, upload-time = "2025-10-30T02:55:15.73Z" }, { url = "https://files.pythonhosted.org/packages/bd/42/c9a21edf0e3daa7825ed04a4a8588686c6c14904344344a039556d78aa58/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ef7a6beb4beaa62f88592ccc65df20328029d721db309cb3250b0aae0fa146c3", size = 3652281, upload-time = "2025-10-10T11:12:17.713Z" }, { url = "https://files.pythonhosted.org/packages/12/22/dedfbcfa97917982301496b6b5e5e6c5531d1f35dd2b488b08d1ebc52482/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:31b32c457a6025e74d233957cc9736742ac5a6cb196c6b68499f6bb51390bd6a", size = 3298010, upload-time = "2025-10-10T11:12:22.671Z" }, + { url = "https://files.pythonhosted.org/packages/66/ea/d3390e6696276078bd01b2ece417deac954dfdd552d2edc3d03204416c0c/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:edcb3aeb11cb4bf13a2af3c53a15b3d612edeb6409047ea0b5d6a21a9d744b34", size = 3044641, upload-time = "2025-10-30T02:55:19.929Z" }, { url = "https://files.pythonhosted.org/packages/12/9a/0402ded6cbd321da0c0ba7d34dc12b29b14f5764c2fc10750daa38e825fc/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:62b6d93d7c0b61a1dd6197d208ab613eb7dcfdcca0a49c42ceb082257991de9d", size = 3347940, upload-time = "2025-10-10T11:12:26.529Z" }, { url = "https://files.pythonhosted.org/packages/b1/d2/99b55e85832ccde77b211738ff3925a5d73ad183c0b37bcbbe5a8ff04978/psycopg2_binary-2.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:b33fabeb1fde21180479b2d4667e994de7bbf0eec22832ba5d9b5e4cf65b6c6d", size = 2714147, upload-time = "2025-10-10T11:12:29.535Z" }, { url = "https://files.pythonhosted.org/packages/ff/a8/a2709681b3ac11b0b1786def10006b8995125ba268c9a54bea6f5ae8bd3e/psycopg2_binary-2.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b8fb3db325435d34235b044b199e56cdf9ff41223a4b9752e8576465170bb38c", size = 3756572, upload-time = "2025-10-10T11:12:32.873Z" }, @@ -4669,8 +4673,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/32/b2ffe8f3853c181e88f0a157c5fb4e383102238d73c52ac6d93a5c8bffe6/psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8c55b385daa2f92cb64b12ec4536c66954ac53654c7f15a203578da4e78105c0", size = 4411242, upload-time = "2025-10-10T11:12:42.388Z" }, { url = "https://files.pythonhosted.org/packages/10/04/6ca7477e6160ae258dc96f67c371157776564679aefd247b66f4661501a2/psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:c0377174bf1dd416993d16edc15357f6eb17ac998244cca19bc67cdc0e2e5766", size = 4468258, upload-time = "2025-10-10T11:12:48.654Z" }, { url = "https://files.pythonhosted.org/packages/3c/7e/6a1a38f86412df101435809f225d57c1a021307dd0689f7a5e7fe83588b1/psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5c6ff3335ce08c75afaed19e08699e8aacf95d4a260b495a4a8545244fe2ceb3", size = 4166295, upload-time = "2025-10-10T11:12:52.525Z" }, + { url = "https://files.pythonhosted.org/packages/f2/7d/c07374c501b45f3579a9eb761cbf2604ddef3d96ad48679112c2c5aa9c25/psycopg2_binary-2.9.11-cp313-cp313-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:84011ba3109e06ac412f95399b704d3d6950e386b7994475b231cf61eec2fc1f", size = 3983133, upload-time = "2025-10-30T02:55:24.329Z" }, { url = "https://files.pythonhosted.org/packages/82/56/993b7104cb8345ad7d4516538ccf8f0d0ac640b1ebd8c754a7b024e76878/psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ba34475ceb08cccbdd98f6b46916917ae6eeb92b5ae111df10b544c3a4621dc4", size = 3652383, upload-time = "2025-10-10T11:12:56.387Z" }, { url = "https://files.pythonhosted.org/packages/2d/ac/eaeb6029362fd8d454a27374d84c6866c82c33bfc24587b4face5a8e43ef/psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b31e90fdd0f968c2de3b26ab014314fe814225b6c324f770952f7d38abf17e3c", size = 3298168, upload-time = "2025-10-10T11:13:00.403Z" }, + { url = "https://files.pythonhosted.org/packages/2b/39/50c3facc66bded9ada5cbc0de867499a703dc6bca6be03070b4e3b65da6c/psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:d526864e0f67f74937a8fce859bd56c979f5e2ec57ca7c627f5f1071ef7fee60", size = 3044712, upload-time = "2025-10-30T02:55:27.975Z" }, { url = "https://files.pythonhosted.org/packages/9c/8e/b7de019a1f562f72ada81081a12823d3c1590bedc48d7d2559410a2763fe/psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04195548662fa544626c8ea0f06561eb6203f1984ba5b4562764fbeb4c3d14b1", size = 3347549, upload-time = "2025-10-10T11:13:03.971Z" }, { url = "https://files.pythonhosted.org/packages/80/2d/1bb683f64737bbb1f86c82b7359db1eb2be4e2c0c13b947f80efefa7d3e5/psycopg2_binary-2.9.11-cp313-cp313-win_amd64.whl", hash = "sha256:efff12b432179443f54e230fdf60de1f6cc726b6c832db8701227d089310e8aa", size = 2714215, upload-time = "2025-10-10T11:13:07.14Z" }, ]