diff --git a/.claude/implementations/phase3.5_http_transport.md b/.claude/implementations/phase3.5_http_transport.md new file mode 100644 index 0000000..92c79f8 --- /dev/null +++ b/.claude/implementations/phase3.5_http_transport.md @@ -0,0 +1,384 @@ +# Phase 3.5: HTTP Transport with SSE Implementation Plan + +## Overview + +Phase 3.5 adds HTTP transport alongside the existing stdio transport, enabling the MCP server to handle web-based clients with real-time streaming via Server-Sent Events (SSE). This provides a production-ready HTTP API while maintaining full compatibility with all existing tools. + +## Timeline: 3-4 days + +## Implementation Structure + +``` +contextframe/mcp/transports/ +├── __init__.py +├── stdio.py # Existing stdio adapter +└── http/ + ├── __init__.py + ├── adapter.py # HttpAdapter implementation + ├── server.py # FastAPI/Starlette server + ├── sse.py # Server-Sent Events + ├── auth.py # OAuth 2.1 authentication + ├── session.py # Session management + └── security.py # CORS, rate limiting +``` + +## Core Components + +### 1. HttpAdapter Class + +The `HttpAdapter` extends `TransportAdapter` to provide HTTP-specific functionality: + +```python +class HttpAdapter(TransportAdapter): + """HTTP transport adapter with SSE streaming support.""" + + def __init__(self, app: FastAPI): + super().__init__() + self.app = app + self._active_streams: Dict[str, SSEStream] = {} + + async def send_progress(self, progress: Progress) -> None: + """Send progress via SSE to all active streams.""" + # Stream progress updates in real-time + + async def handle_subscription(self, subscription: Subscription) -> AsyncIterator[Dict[str, Any]]: + """Stream changes via SSE.""" + # Real-time change streaming +``` + +### 2. 
FastAPI/Starlette Server + +Main HTTP server implementation: + +```python +# server.py +from fastapi import FastAPI, HTTPException, Depends +from fastapi.middleware.cors import CORSMiddleware +from sse_starlette.sse import EventSourceResponse + +app = FastAPI(title="ContextFrame MCP Server") + +# MCP endpoints +@app.post("/mcp/v1/initialize") +@app.post("/mcp/v1/tools/list") +@app.post("/mcp/v1/tools/call") +@app.post("/mcp/v1/resources/list") +@app.post("/mcp/v1/resources/read") + +# SSE endpoints +@app.get("/mcp/v1/sse/subscribe") +@app.get("/mcp/v1/sse/progress/{operation_id}") +``` + +### 3. SSE Implementation + +Real-time streaming for progress and subscriptions: + +```python +# sse.py +class SSEStream: + """Manages an SSE connection for streaming updates.""" + + async def send_event(self, event_type: str, data: Any): + """Send an SSE event.""" + + async def stream_progress(self, operation_id: str): + """Stream progress updates for an operation.""" + + async def stream_changes(self, subscription_id: str): + """Stream dataset changes for a subscription.""" +``` + +### 4. Authentication & Security + +OAuth 2.1 with PKCE for secure API access: + +```python +# auth.py +class OAuth2Handler: + """OAuth 2.1 authentication with PKCE.""" + + async def verify_token(self, token: str) -> Dict[str, Any]: + """Verify and decode JWT token.""" + + async def check_permissions(self, user: Dict, resource: str, action: str) -> bool: + """Check user permissions for resource access.""" +``` + +## Key Features + +### 1. Transport Selection + +Server can run with either or both transports: + +```bash +# Run with stdio only (default) +python -m contextframe.mcp dataset.lance + +# Run with HTTP only +python -m contextframe.mcp dataset.lance --transport http --port 8080 + +# Run with both transports +python -m contextframe.mcp dataset.lance --transport both --port 8080 +``` + +### 2. 
Streaming Capabilities + +- **Progress Updates**: Real-time operation progress via SSE +- **Subscriptions**: Live dataset changes streamed to clients +- **Batch Operations**: Stream results as they complete +- **Long-Running Operations**: Keep clients updated + +### 3. HTTP-Specific Features + +- **RESTful Endpoints**: Standard HTTP API design +- **Request/Response**: JSON for all non-streaming operations +- **File Uploads**: Support for document upload via multipart +- **Health Checks**: `/health` and `/ready` endpoints +- **API Documentation**: Auto-generated OpenAPI/Swagger + +### 4. Security Features + +- **Authentication**: OAuth 2.1 with JWT tokens +- **Authorization**: Resource-level permissions +- **CORS**: Configurable cross-origin policies +- **Rate Limiting**: Per-user and global limits +- **Request Validation**: Input sanitization + +## Integration Points + +### 1. Unified Tool Handling + +All existing tools work identically: + +```python +# server.py updates +async def create_server(dataset_path: str, transport: str = "stdio"): + if transport == "http" or transport == "both": + http_adapter = HttpAdapter(app) + handler = MessageHandler(dataset, http_adapter) + + @app.post("/mcp/v1/tools/call") + async def call_tool(request: ToolCallRequest): + return await handler.handle_tool_call(request.dict()) +``` + +### 2. 
Configuration + +Extended configuration for HTTP: + +```json +{ + "transport": { + "type": "http", + "http": { + "host": "0.0.0.0", + "port": 8080, + "cors": { + "origins": ["https://app.example.com"], + "credentials": true + }, + "auth": { + "enabled": true, + "issuer": "https://auth.example.com", + "audience": "contextframe-mcp" + }, + "rate_limit": { + "requests_per_minute": 60, + "burst": 10 + }, + "ssl": { + "enabled": false, + "cert": "/path/to/cert.pem", + "key": "/path/to/key.pem" + } + } + } +} +``` + +## API Endpoints + +### Core MCP Endpoints + +``` +POST /mcp/v1/initialize # Initialize session +POST /mcp/v1/tools/list # List available tools +POST /mcp/v1/tools/call # Call a tool +POST /mcp/v1/resources/list # List resources +POST /mcp/v1/resources/read # Read a resource +``` + +### Streaming Endpoints + +``` +GET /mcp/v1/sse/subscribe # Subscribe to changes (SSE) +GET /mcp/v1/sse/progress/{operation_id} # Stream operation progress (SSE) +POST /mcp/v1/subscriptions # Manage subscriptions +``` + +### Utility Endpoints + +``` +GET /health # Health check +GET /ready # Readiness check +GET /metrics # Prometheus metrics +GET /openapi.json # OpenAPI specification +``` + +## Example Usage + +### HTTP Client Example + +```python +import httpx +import asyncio +from httpx_sse import aconnect_sse + +# Initialize client +async with httpx.AsyncClient() as client: + # Authenticate + token = await authenticate(client) + headers = {"Authorization": f"Bearer {token}"} + + # Call a tool + response = await client.post( + "http://localhost:8080/mcp/v1/tools/call", + json={ + "name": "search_documents", + "arguments": {"query": "machine learning"} + }, + headers=headers + ) + + # Stream progress for batch operation + batch_response = await client.post( + "http://localhost:8080/mcp/v1/tools/call", + json={ + "name": "batch_enhance", + "arguments": {"documents": [...]} + }, + headers=headers + ) + + operation_id = batch_response.json()["operation_id"] + + # Stream progress via SSE + async with aconnect_sse( + 
client, + "GET", + f"http://localhost:8080/mcp/v1/sse/progress/{operation_id}", + headers=headers + ) as event_source: + async for sse in event_source.aiter_sse(): + print(f"Progress: {sse.data}") +``` + +### JavaScript/Browser Example + +```javascript +// Subscribe to changes. NOTE: the browser-native EventSource cannot send custom headers; this snippet assumes a header-capable polyfill (e.g. event-source-polyfill) or a fetch-based SSE client. +const eventSource = new EventSource( + 'http://localhost:8080/mcp/v1/sse/subscribe?resource_type=documents', + { headers: { 'Authorization': `Bearer ${token}` } } +); + +eventSource.onmessage = (event) => { + const change = JSON.parse(event.data); + console.log('Document changed:', change); +}; + +// Call a tool +const response = await fetch('http://localhost:8080/mcp/v1/tools/call', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${token}` + }, + body: JSON.stringify({ + name: 'add_document', + arguments: { + content: 'Document content', + metadata: { title: 'New Document' } + } + }) +}); +``` + +## Testing Strategy + +### 1. Unit Tests + +- HttpAdapter methods +- SSE streaming functionality +- Authentication/authorization +- Rate limiting logic + +### 2. Integration Tests + +- Full HTTP request/response cycle +- SSE connection management +- Tool execution via HTTP +- Concurrent request handling + +### 3. Performance Tests + +- Throughput benchmarking +- SSE connection scaling +- Memory usage under load +- Latency measurements + +## Migration Guide + +For users transitioning from stdio to HTTP: + +1. **Configuration**: Update transport settings +2. **Authentication**: Obtain OAuth tokens +3. **Client Updates**: Use HTTP client libraries +4. **Streaming**: Switch from polling to SSE +5. **Error Handling**: Handle HTTP status codes + +## Success Criteria + +1. **Feature Parity** + - All 43 tools work via HTTP + - Identical results to stdio transport + - No breaking changes + +2. **Performance** + - <50ms latency for simple operations + - Support 1000+ concurrent SSE connections + - Efficient streaming for large responses + +3. 
**Security** + - OAuth 2.1 compliance + - Secure by default configuration + - Comprehensive audit logging + +4. **Developer Experience** + - Clear API documentation + - SDK examples in multiple languages + - Migration guides and tutorials + +## Dependencies + +```toml +# pyproject.toml additions +[tool.poetry.dependencies] +fastapi = "^0.104.0" +uvicorn = { version = "^0.24.0", extras = ["standard"] } +sse-starlette = "^1.6.0" +python-jose = { version = "^3.3.0", extras = ["cryptography"] } +httpx = "^0.25.0" +slowapi = "^0.1.9" # Rate limiting +python-multipart = "^0.0.6" # File uploads +``` + +## Next Steps After Phase 3.5 + +With HTTP transport complete: +1. Finalize Phase 3.6 (Analytics & Performance tools) +2. Production deployment guides +3. Kubernetes manifests and Helm charts +4. Monitoring and observability setup +5. Performance optimization and caching \ No newline at end of file diff --git a/.claude/implementations/phase3.6_analytics_performance.md b/.claude/implementations/phase3.6_analytics_performance.md new file mode 100644 index 0000000..0c5a0e1 --- /dev/null +++ b/.claude/implementations/phase3.6_analytics_performance.md @@ -0,0 +1,285 @@ +# Phase 3.6: Analytics & Performance Tools Implementation Plan + +## Overview + +Phase 3.6 adds 8 analytics and performance tools to provide insights into dataset usage, query patterns, and optimization opportunities. These tools help users understand and optimize their ContextFrame deployments. + +## Timeline: 2-3 days + +## Implementation Structure + +``` +contextframe/mcp/analytics/ +├── __init__.py +├── stats.py # Dataset statistics and metrics +├── analyzer.py # Query and performance analysis +├── optimizer.py # Storage and index optimization +└── tools.py # Tool implementations +``` + +## Tools to Implement + +### 1. 
Analytics Tools (4 tools) + +#### `get_dataset_stats` +- **Purpose**: Comprehensive dataset statistics +- **Returns**: + - Total documents, collections, size + - Document type distribution + - Metadata field usage + - Embedding coverage + - Version history stats + - Relationship graph metrics + +#### `analyze_usage` +- **Purpose**: Usage patterns and access analytics +- **Parameters**: + - `time_range`: Period to analyze + - `group_by`: Hour, day, week + - `include_queries`: Include query patterns +- **Returns**: + - Access frequency by document/collection + - Query patterns and types + - Popular search terms + - User access patterns (if tracked) + +#### `query_performance` +- **Purpose**: Analyze query performance +- **Parameters**: + - `time_range`: Analysis period + - `query_type`: vector, text, hybrid, filter + - `min_duration`: Minimum query time to include +- **Returns**: + - Slow queries with explanations + - Query type distribution + - Filter efficiency analysis + - Optimization recommendations + +#### `relationship_analysis` +- **Purpose**: Analyze document relationships +- **Parameters**: + - `max_depth`: How deep to traverse + - `relationship_types`: Types to analyze + - `include_orphans`: Include unconnected docs +- **Returns**: + - Relationship graph statistics + - Clustering coefficients + - Connected components + - Orphaned documents + - Circular dependencies + +### 2. 
Performance Tools (4 tools) + +#### `optimize_storage` +- **Purpose**: Optimize Lance dataset storage +- **Parameters**: + - `operations`: ["compact", "vacuum", "reindex"] + - `dry_run`: Preview changes + - `target_version`: Optimize to specific version +- **Returns**: + - Space reclaimed + - Fragments consolidated + - Performance improvements + - Optimization log + +#### `index_recommendations` +- **Purpose**: Suggest index improvements +- **Parameters**: + - `analyze_queries`: Recent queries to analyze + - `workload_type`: "search", "analytics", "mixed" +- **Returns**: + - Missing index suggestions + - Redundant index identification + - Index usage statistics + - Implementation commands + +#### `benchmark_operations` +- **Purpose**: Benchmark key operations +- **Parameters**: + - `operations`: ["search", "insert", "update", "scan"] + - `sample_size`: Number of operations + - `concurrency`: Parallel operations +- **Returns**: + - Operation latencies (p50, p90, p99) + - Throughput metrics + - Resource utilization + - Comparison with baselines + +#### `export_metrics` +- **Purpose**: Export analytics for monitoring +- **Parameters**: + - `format`: "prometheus", "json", "csv" + - `metrics`: Specific metrics to export + - `labels`: Additional labels +- **Returns**: + - Formatted metrics + - Timestamp + - Ready for monitoring systems + +## Architecture Patterns + +### 1. StatsCollector Base Class +```python +class StatsCollector: + """Base class for collecting dataset statistics.""" + + async def collect_stats(self) -> Dict[str, Any]: + """Collect all statistics.""" + stats = { + "basic": await self._collect_basic_stats(), + "content": await self._collect_content_stats(), + "performance": await self._collect_performance_stats(), + "relationships": await self._collect_relationship_stats() + } + return stats +``` + +### 2. 
Query Analyzer +```python +class QueryAnalyzer: + """Analyze query patterns and performance.""" + + def __init__(self, dataset: FrameDataset): + self.dataset = dataset + self.query_log = [] # In production, use persistent storage + + async def analyze_query(self, query: str, duration: float): + """Analyze a single query.""" + # Extract query features + # Identify optimization opportunities + # Track patterns +``` + +### 3. Storage Optimizer +```python +class StorageOptimizer: + """Optimize Lance dataset storage.""" + + async def compact(self, dry_run: bool = True): + """Compact dataset fragments.""" + # Use Lance's optimize operations + # Track space savings + # Report improvements +``` + +## Integration Points + +### 1. With Existing Tools +- Query performance integrates with search tools +- Storage optimization works with batch operations +- Analytics complement subscription monitoring + +### 2. Transport Considerations +- All tools work with stdio transport +- Ready for HTTP streaming when added +- Metrics export designed for monitoring integration + +## Testing Strategy + +### 1. Unit Tests +- Mock dataset statistics +- Test metric calculations +- Verify optimization logic + +### 2. Integration Tests +- Real dataset analysis +- Performance benchmarking +- Storage optimization verification + +### 3. 
Performance Tests +- Analytics overhead measurement +- Optimization impact testing + +## Example Usage + +### Get Dataset Statistics +```python +result = await get_dataset_stats({ + "include_details": True, + "calculate_sizes": True +}) + +# Returns: +{ + "total_documents": 10000, + "total_collections": 25, + "storage_size_mb": 256.5, + "avg_document_size": 2650, + "embedding_coverage": 0.95, + "version_count": 142, + "relationships": { + "total": 3500, + "types": { + "child_of": 2000, + "references": 1500 + } + } +} +``` + +### Analyze Query Performance +```python +result = await query_performance({ + "time_range": "7d", + "min_duration": 100 # ms +}) + +# Returns: +{ + "slow_queries": [ + { + "query": "complex filter expression", + "avg_duration_ms": 250, + "count": 45, + "recommendation": "Add index on metadata.category" + } + ], + "query_distribution": { + "vector": 0.6, + "text": 0.3, + "hybrid": 0.1 + } +} +``` + +### Optimize Storage +```python +result = await optimize_storage({ + "operations": ["compact", "vacuum"], + "dry_run": False +}) + +# Returns: +{ + "space_reclaimed_mb": 45.2, + "fragments_before": 150, + "fragments_after": 12, + "duration_seconds": 8.5, + "version_created": 143 +} +``` + +## Success Criteria + +1. **Comprehensive Analytics** + - Complete dataset insights + - Actionable recommendations + - Performance visibility + +2. **Effective Optimization** + - Measurable performance gains + - Safe storage operations + - Clear improvement metrics + +3. **Tool Integration** + - Works with all transports + - Complements existing tools + - Production-ready metrics + +## Next Steps After Phase 3.6 + +With all 43 tools complete: +1. Phase 3.5: Add HTTP transport with SSE +2. Performance testing at scale +3. Production deployment guides +4. 
Integration examples \ No newline at end of file diff --git a/contextframe/frame.py b/contextframe/frame.py index e5211d6..590419c 100644 --- a/contextframe/frame.py +++ b/contextframe/frame.py @@ -703,12 +703,11 @@ def open( """ raw_uri = str(path) from lance.dataset import LanceDataset + if storage_options is None: ds = LanceDataset(raw_uri, version=version) else: - ds = LanceDataset( - raw_uri, version=version, storage_options=storage_options - ) + ds = LanceDataset(raw_uri, version=version, storage_options=storage_options) return cls(ds) # ------------------------------------------------------------------ @@ -864,11 +863,11 @@ def enrich( **enricher_kwargs, ) -> list[Any]: """Enrich documents in the dataset using LLM-powered analysis. - + This convenience method provides easy access to the enrichment functionality, allowing AI agents and users to populate schema fields with meaningful metadata. - + Parameters ---------- enrichments: @@ -893,12 +892,12 @@ def enrich( LLM model to use for enrichment **enricher_kwargs: Additional arguments for ContextEnricher - + Returns ------- list[EnrichmentResult] Results of the enrichment operations - + Examples -------- >>> # Basic enrichment @@ -906,7 +905,7 @@ def enrich( ... "context": "Explain what this document teaches", ... "tags": "Extract technology and concept tags" ... }) - + >>> # Custom metadata extraction >>> dataset.enrich({ ... "custom_metadata": { @@ -914,7 +913,7 @@ def enrich( ... "format": "json" ... } ... }, filter="context IS NULL") - + >>> # Using different model >>> dataset.enrich( ... {"context": "Summarize for AI developers"}, @@ -922,7 +921,7 @@ def enrich( ... 
) """ from contextframe.enrich import ContextEnricher - + enricher = ContextEnricher(model=model, **enricher_kwargs) return enricher.enrich_dataset( self, @@ -930,7 +929,7 @@ def enrich( filter=filter, skip_existing=skip_existing, batch_size=batch_size, - show_progress=show_progress + show_progress=show_progress, ) # ------------------------------------------------------------------ @@ -1967,7 +1966,7 @@ def create_scalar_index(self, column: str, *, replace: bool = True) -> None: ) # Delegate to Lance self._native.create_scalar_index(column, replace=replace) - + def enhance( self, enhancements: dict[str, str | dict[str, Any]], @@ -1977,12 +1976,12 @@ def enhance( show_progress: bool = True, provider: str = "openai", model: str = "gpt-4o-mini", - **kwargs + **kwargs, ) -> list[Any]: """Enhance documents in the dataset with LLM-generated metadata. - + This is a convenience method that wraps ContextEnhancer functionality. - + Args: enhancements: Map of field_name -> prompt or config dict filter: Optional Lance SQL filter @@ -1992,10 +1991,10 @@ def enhance( provider: LLM provider (openai, anthropic, etc.) model: Model name **kwargs: Additional provider-specific arguments - + Returns: List of enhancement results - + Example: >>> dataset.enhance({ ... "context": "Summarize what this document teaches", @@ -2007,7 +2006,7 @@ def enhance( ... }) """ from contextframe.enhance import ContextEnhancer - + enhancer = ContextEnhancer(provider=provider, model=model, **kwargs) return enhancer.enhance_dataset( self, @@ -2015,5 +2014,215 @@ def enhance( filter=filter, batch_size=batch_size, skip_existing=skip_existing, - show_progress=show_progress + show_progress=show_progress, ) + + # ------------------------------------------------------------------ + # Analytics and Performance Methods + # ------------------------------------------------------------------ + + def get_dataset_stats(self) -> dict[str, Any]: + """Get comprehensive dataset statistics using Lance's native stats. 
+ + Returns: + Dictionary containing: + - dataset_stats: Fragment counts, deleted rows, small files + - data_stats: Field-level statistics + - storage_size: Total size in bytes + - version_info: Current and latest versions + - index_info: List of indices + """ + stats = {} + + # Basic dataset stats + if hasattr(self._dataset, 'stats'): + dataset_stats = self._dataset.stats.dataset_stats() + stats['dataset_stats'] = { + 'num_fragments': dataset_stats.num_fragments, + 'num_deleted_rows': dataset_stats.num_deleted_rows, + 'num_small_files': dataset_stats.num_small_files, + } + + # Data statistics + data_stats = self._dataset.stats.data_stats() + if data_stats: + stats['data_stats'] = data_stats + + # Version info + stats['version_info'] = { + 'current_version': self._dataset.version, + 'latest_version': self._dataset.latest_version, + 'data_storage_version': self._dataset.data_storage_version, + } + + # Storage info + stats['storage'] = { + 'uri': self._dataset.uri, + 'num_rows': len(self), + } + + # Index info + if hasattr(self._dataset, 'list_indices'): + indices = self._dataset.list_indices() + stats['indices'] = [ + { + 'name': idx.name, + 'type': idx.type, + 'fields': idx.fields, + 'version': idx.version, + } + for idx in indices + ] + + return stats + + def get_fragment_stats(self) -> list[dict[str, Any]]: + """Get statistics for all fragments in the dataset. 
+ + Returns: + List of fragment statistics including: + - fragment_id: Fragment identifier + - num_rows: Number of rows after deletions + - num_deletions: Number of deleted rows + - physical_rows: Original row count + - files: List of data files + """ + fragments = [] + + for fragment in self._dataset.get_fragments(): + metadata = fragment.metadata + fragments.append( + { + 'fragment_id': fragment.fragment_id + if hasattr(fragment, 'fragment_id') + else len(fragments), + 'num_rows': metadata.num_rows, + 'num_deletions': metadata.num_deletions + if hasattr(metadata, 'num_deletions') + else 0, + 'physical_rows': metadata.physical_rows + if hasattr(metadata, 'physical_rows') + else metadata.num_rows, + 'files': [f.path() for f in metadata.files] + if hasattr(metadata, 'files') + else [], + } + ) + + return fragments + + def compact_files( + self, target_rows_per_fragment: int = 1024 * 1024, **kwargs + ) -> dict[str, Any]: + """Compact dataset files to optimize storage. + + Args: + target_rows_per_fragment: Target number of rows per fragment + **kwargs: Additional arguments passed to Lance optimizer + + Returns: + Dictionary with compaction metrics + """ + if not hasattr(self._dataset, 'optimize'): + raise NotImplementedError( + "Dataset optimization requires newer Lance version" + ) + + # Perform compaction + metrics = self._dataset.optimize.compact_files( + target_rows_per_fragment=target_rows_per_fragment, **kwargs + ) + + return { + 'fragments_compacted': getattr(metrics, 'fragments_compacted', 0), + 'files_removed': getattr(metrics, 'files_removed', 0), + 'files_added': getattr(metrics, 'files_added', 0), + } + + def optimize_indices(self, **kwargs) -> dict[str, Any]: + """Optimize dataset indices for better query performance. 
+ + Returns: + Dictionary with optimization results + """ + if not hasattr(self._dataset, 'optimize'): + raise NotImplementedError("Index optimization requires newer Lance version") + + # Optimize indices + self._dataset.optimize.optimize_indices(**kwargs) + + return { + 'status': 'completed', + 'indices_optimized': len(self._dataset.list_indices()) + if hasattr(self._dataset, 'list_indices') + else 0, + } + + def cleanup_old_versions( + self, older_than: _dt.timedelta | None = None + ) -> dict[str, Any]: + """Clean up old dataset versions to reclaim space. + + Args: + older_than: Only clean versions older than this duration + + Returns: + Dictionary with cleanup statistics + """ + cleanup_stats = self._dataset.cleanup_old_versions(older_than=older_than) + + return { + 'bytes_removed': getattr(cleanup_stats, 'bytes_removed', 0), + 'old_versions_removed': getattr(cleanup_stats, 'old_versions', 0), + } + + def list_indices(self) -> list[dict[str, Any]]: + """List all indices in the dataset. + + Returns: + List of index information dictionaries + """ + if not hasattr(self._dataset, 'list_indices'): + return [] + + indices = [] + for idx in self._dataset.list_indices(): + indices.append( + { + 'name': idx.name, + 'type': idx.type, + 'uuid': str(idx.uuid) if hasattr(idx, 'uuid') else None, + 'fields': idx.fields, + 'version': idx.version, + 'fragment_ids': list(idx.fragment_ids) + if hasattr(idx, 'fragment_ids') + else [], + } + ) + + return indices + + def get_version_history(self) -> list[dict[str, Any]]: + """Get version history with metadata. 
+ + Returns: + List of version information dictionaries + """ + versions = [] + + for version in range(self._dataset.latest_version + 1): + try: + # Checkout version to get metadata + versioned_ds = self._dataset.checkout_version(version) + versions.append( + { + 'version': version, + 'num_rows': len(versioned_ds), + 'schema_fields': len(versioned_ds.schema), + } + ) + except Exception: + # Version might be cleaned up + continue + + return versions diff --git a/contextframe/mcp/analytics/__init__.py b/contextframe/mcp/analytics/__init__.py new file mode 100644 index 0000000..39601d7 --- /dev/null +++ b/contextframe/mcp/analytics/__init__.py @@ -0,0 +1,47 @@ +"""Analytics and performance tools for ContextFrame MCP server. + +This module provides tools for: +- Dataset statistics and metrics +- Usage pattern analysis +- Query performance monitoring +- Relationship graph analysis +- Storage optimization +- Index recommendations +- Operation benchmarking +- Metrics export for monitoring +""" + +from .analyzer import QueryAnalyzer, RelationshipAnalyzer, UsageAnalyzer +from .optimizer import IndexAdvisor, PerformanceBenchmark, StorageOptimizer +from .stats import DatasetStats, StatsCollector +from .tools import ( + AnalyzeUsageHandler, + BenchmarkOperationsHandler, + ExportMetricsHandler, + GetDatasetStatsHandler, + IndexRecommendationsHandler, + OptimizeStorageHandler, + QueryPerformanceHandler, + RelationshipAnalysisHandler, +) + +__all__ = [ + # Core classes + "StatsCollector", + "DatasetStats", + "QueryAnalyzer", + "UsageAnalyzer", + "RelationshipAnalyzer", + "StorageOptimizer", + "IndexAdvisor", + "PerformanceBenchmark", + # Tool handlers + "GetDatasetStatsHandler", + "AnalyzeUsageHandler", + "QueryPerformanceHandler", + "RelationshipAnalysisHandler", + "OptimizeStorageHandler", + "IndexRecommendationsHandler", + "BenchmarkOperationsHandler", + "ExportMetricsHandler", +] diff --git a/contextframe/mcp/analytics/analyzer.py b/contextframe/mcp/analytics/analyzer.py new 
file mode 100644 index 0000000..605f2da --- /dev/null +++ b/contextframe/mcp/analytics/analyzer.py @@ -0,0 +1,681 @@ +"""Query, usage, and relationship analysis for ContextFrame datasets.""" + +import asyncio +import numpy as np +import pyarrow.compute as pc +import time +from collections import defaultdict, deque +from contextframe.frame import FrameDataset +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Any, Deque, Dict, List, Optional, Set, Tuple + + +@dataclass +class QueryMetrics: + """Metrics for a single query execution.""" + + query_type: str # "vector", "text", "hybrid", "filter" + query_text: str | None = None + filter_expression: str | None = None + duration_ms: float = 0.0 + rows_scanned: int = 0 + rows_returned: int = 0 + index_used: bool = False + timestamp: datetime = field(default_factory=datetime.now) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + return { + "type": self.query_type, + "query": self.query_text, + "filter": self.filter_expression, + "duration_ms": round(self.duration_ms, 2), + "rows_scanned": self.rows_scanned, + "rows_returned": self.rows_returned, + "index_used": self.index_used, + "timestamp": self.timestamp.isoformat(), + } + + +class QueryAnalyzer: + """Analyzes query patterns and performance.""" + + def __init__(self, dataset: FrameDataset, max_history: int = 10000): + """Initialize query analyzer. 
+ + Args: + dataset: The dataset to analyze + max_history: Maximum query history to maintain + """ + self.dataset = dataset + self.max_history = max_history + self.query_history: deque[QueryMetrics] = deque(maxlen=max_history) + self._query_cache: dict[str, list[QueryMetrics]] = defaultdict(list) + + def record_query(self, metrics: QueryMetrics) -> None: + """Record a query execution.""" + self.query_history.append(metrics) + self._query_cache[metrics.query_type].append(metrics) + + async def analyze_performance( + self, + time_range: timedelta | None = None, + query_type: str | None = None, + min_duration_ms: float = 0.0, + ) -> dict[str, Any]: + """Analyze query performance. + + Args: + time_range: Analyze queries within this time range + query_type: Filter by query type + min_duration_ms: Only include queries slower than this + + Returns: + Performance analysis results + """ + # Filter queries + queries = list(self.query_history) + + if time_range: + cutoff = datetime.now() - time_range + queries = [q for q in queries if q.timestamp >= cutoff] + + if query_type: + queries = [q for q in queries if q.query_type == query_type] + + if min_duration_ms > 0: + queries = [q for q in queries if q.duration_ms >= min_duration_ms] + + if not queries: + return {"message": "No queries match the criteria"} + + # Calculate statistics + durations = [q.duration_ms for q in queries] + rows_scanned = [q.rows_scanned for q in queries] + + # Group by query type + by_type = defaultdict(list) + for q in queries: + by_type[q.query_type].append(q) + + # Find slow queries + slow_queries = sorted(queries, key=lambda q: q.duration_ms, reverse=True)[:10] + + # Identify patterns + filter_patterns = defaultdict(int) + for q in queries: + if q.filter_expression: + # Simple pattern extraction (could be enhanced) + if "=" in q.filter_expression: + field = q.filter_expression.split("=")[0].strip() + filter_patterns[field] += 1 + + return { + "summary": { + "total_queries": len(queries), + 
"avg_duration_ms": round(np.mean(durations), 2), + "p50_duration_ms": round(np.percentile(durations, 50), 2), + "p90_duration_ms": round(np.percentile(durations, 90), 2), + "p99_duration_ms": round(np.percentile(durations, 99), 2), + "max_duration_ms": round(max(durations), 2), + "avg_rows_scanned": round(np.mean(rows_scanned), 0), + }, + "by_type": { + qtype: { + "count": len(queries), + "avg_duration_ms": round( + np.mean([q.duration_ms for q in queries]), 2 + ), + "index_usage_rate": sum(1 for q in queries if q.index_used) + / len(queries), + } + for qtype, queries in by_type.items() + }, + "slow_queries": [ + { + "query": q.to_dict(), + "optimization_hints": self._get_optimization_hints(q), + } + for q in slow_queries + ], + "filter_patterns": dict(filter_patterns), + } + + def _get_optimization_hints(self, query: QueryMetrics) -> list[str]: + """Generate optimization hints for a query.""" + hints = [] + + # Check index usage + if not query.index_used and query.query_type in ["vector", "text"]: + hints.append(f"Consider creating a {query.query_type} index") + + # Check scan efficiency + if query.rows_scanned > 0 and query.rows_returned > 0: + selectivity = query.rows_returned / query.rows_scanned + if selectivity < 0.01: # Less than 1% selectivity + hints.append("Very low selectivity - consider more specific filters") + + # Check duration + if query.duration_ms > 1000: + hints.append("Query taking over 1 second - review query complexity") + + # Filter suggestions + if query.filter_expression and "OR" in query.filter_expression: + hints.append("OR conditions can be slow - consider using IN operator") + + return hints + + +class UsageAnalyzer: + """Analyzes dataset usage patterns.""" + + def __init__(self, dataset: FrameDataset): + """Initialize usage analyzer.""" + self.dataset = dataset + self._access_log: dict[str, list[datetime]] = defaultdict(list) + self._operation_counts: dict[str, int] = defaultdict(int) + + def record_access(self, document_id: str, 
operation: str = "read") -> None: + """Record a document access.""" + self._access_log[document_id].append(datetime.now()) + self._operation_counts[operation] += 1 + + async def analyze_usage( + self, + time_range: timedelta | None = None, + group_by: str = "hour", + include_patterns: bool = True, + ) -> dict[str, Any]: + """Analyze usage patterns. + + Args: + time_range: Period to analyze + group_by: Grouping period (hour, day, week) + include_patterns: Include access pattern analysis + + Returns: + Usage analysis results + """ + # Filter by time range + cutoff = None + if time_range: + cutoff = datetime.now() - time_range + + # Get document metadata for enrichment + doc_metadata = await self._get_document_metadata() + + # Analyze access patterns + access_stats = self._analyze_access_patterns(cutoff) + + # Time-based analysis + time_stats = self._analyze_temporal_patterns(cutoff, group_by) + + # Collection usage + collection_stats = await self._analyze_collection_usage(cutoff) + + # Operation statistics + operation_stats = dict(self._operation_counts) + + results = { + "summary": { + "total_accesses": sum( + len(accesses) for accesses in self._access_log.values() + ), + "unique_documents": len(self._access_log), + "operations": operation_stats, + }, + "access_patterns": access_stats, + "temporal_patterns": time_stats, + "collection_usage": collection_stats, + } + + if include_patterns: + results["recommendations"] = self._generate_usage_recommendations( + access_stats, collection_stats + ) + + return results + + async def _get_document_metadata(self) -> dict[str, dict[str, Any]]: + """Get metadata for accessed documents.""" + if not self._access_log: + return {} + + doc_ids = list(self._access_log.keys()) + metadata = {} + + # Batch fetch metadata + for batch_start in range(0, len(doc_ids), 100): + batch_ids = doc_ids[batch_start : batch_start + 100] + filter_expr = " OR ".join(f"id = '{doc_id}'" for doc_id in batch_ids) + + scanner = self.dataset.scanner( + 
columns=["id", "record_type", "context"], filter=filter_expr + ) + + for batch in scanner.to_batches(): + ids = batch.column("id").to_pylist() + types = batch.column("record_type").to_pylist() + contexts = batch.column("context").to_pylist() + + for doc_id, doc_type, context in zip( + ids, types, contexts, strict=False + ): + metadata[doc_id] = { + "type": doc_type, + "collection_id": context.get("collection_id") + if context + else None, + } + + return metadata + + def _analyze_access_patterns(self, cutoff: datetime | None) -> dict[str, Any]: + """Analyze document access patterns.""" + access_counts = {} + recent_accesses = {} + + for doc_id, accesses in self._access_log.items(): + if cutoff: + accesses = [a for a in accesses if a >= cutoff] + + if accesses: + access_counts[doc_id] = len(accesses) + recent_accesses[doc_id] = max(accesses) + + if not access_counts: + return {} + + # Find hot documents + sorted_docs = sorted(access_counts.items(), key=lambda x: x[1], reverse=True) + hot_documents = sorted_docs[:10] + + # Calculate access distribution + counts = list(access_counts.values()) + + return { + "hot_documents": [ + {"id": doc_id, "access_count": count} for doc_id, count in hot_documents + ], + "access_distribution": { + "mean": round(np.mean(counts), 2), + "median": round(np.median(counts), 2), + "p90": round(np.percentile(counts, 90), 2), + "max": max(counts), + }, + "total_accessed": len(access_counts), + } + + def _analyze_temporal_patterns( + self, cutoff: datetime | None, group_by: str + ) -> dict[str, Any]: + """Analyze temporal access patterns.""" + # Flatten all accesses + all_accesses = [] + for accesses in self._access_log.values(): + if cutoff: + all_accesses.extend([a for a in accesses if a >= cutoff]) + else: + all_accesses.extend(accesses) + + if not all_accesses: + return {} + + # Group by time period + time_buckets = defaultdict(int) + + for access_time in all_accesses: + if group_by == "hour": + bucket = access_time.replace(minute=0, 
second=0, microsecond=0) + elif group_by == "day": + bucket = access_time.replace(hour=0, minute=0, second=0, microsecond=0) + elif group_by == "week": + # Start of week + days_since_monday = access_time.weekday() + bucket = access_time.replace(hour=0, minute=0, second=0, microsecond=0) + bucket -= timedelta(days=days_since_monday) + else: + bucket = access_time # Default to exact time + + time_buckets[bucket] += 1 + + # Convert to sorted list + sorted_buckets = sorted(time_buckets.items()) + + return { + "time_series": [ + {"time": t.isoformat(), "count": count} for t, count in sorted_buckets + ], + "peak_period": max(time_buckets.items(), key=lambda x: x[1])[0].isoformat(), + "total_periods": len(time_buckets), + } + + async def _analyze_collection_usage( + self, cutoff: datetime | None + ) -> dict[str, Any]: + """Analyze usage by collection.""" + doc_metadata = await self._get_document_metadata() + + collection_accesses = defaultdict(int) + collection_docs = defaultdict(set) + + for doc_id, accesses in self._access_log.items(): + if cutoff: + accesses = [a for a in accesses if a >= cutoff] + + if accesses and doc_id in doc_metadata: + coll_id = doc_metadata[doc_id].get("collection_id") + if coll_id: + collection_accesses[coll_id] += len(accesses) + collection_docs[coll_id].add(doc_id) + + if not collection_accesses: + return {} + + # Sort by access count + sorted_collections = sorted( + collection_accesses.items(), key=lambda x: x[1], reverse=True + ) + + return { + "most_accessed": [ + { + "collection_id": coll_id, + "access_count": count, + "unique_documents": len(collection_docs[coll_id]), + } + for coll_id, count in sorted_collections[:10] + ], + "total_collections": len(collection_accesses), + } + + def _generate_usage_recommendations( + self, access_stats: dict[str, Any], collection_stats: dict[str, Any] + ) -> list[str]: + """Generate recommendations based on usage patterns.""" + recommendations = [] + + # Hot document caching + if "hot_documents" in 
access_stats and access_stats["hot_documents"]: + top_doc = access_stats["hot_documents"][0] + if top_doc["access_count"] > 100: + recommendations.append( + f"Consider caching frequently accessed documents " + f"(top document accessed {top_doc['access_count']} times)" + ) + + # Access distribution + if "access_distribution" in access_stats: + dist = access_stats["access_distribution"] + if dist["max"] > dist["mean"] * 10: + recommendations.append( + "Highly skewed access pattern detected - " + "consider optimizing for hot path" + ) + + # Collection patterns + if "most_accessed" in collection_stats and collection_stats["most_accessed"]: + top_coll = collection_stats["most_accessed"][0] + recommendations.append( + f"Collection '{top_coll['collection_id']}' is most active - " + "ensure it has appropriate indices" + ) + + return recommendations + + +class RelationshipAnalyzer: + """Analyzes document relationships and graph structure.""" + + def __init__(self, dataset: FrameDataset): + """Initialize relationship analyzer.""" + self.dataset = dataset + self._graph_cache: dict[str, list[tuple[str, str]]] | None = None + + async def analyze_relationships( + self, + max_depth: int = 3, + relationship_types: list[str] | None = None, + include_orphans: bool = True, + ) -> dict[str, Any]: + """Analyze document relationship graph. 
+ + Args: + max_depth: Maximum traversal depth + relationship_types: Types to include (None = all) + include_orphans: Include unconnected documents + + Returns: + Relationship analysis results + """ + # Build relationship graph + graph = await self._build_relationship_graph(relationship_types) + + # Calculate graph metrics + metrics = self._calculate_graph_metrics(graph) + + # Find connected components + components = self._find_connected_components(graph) + + # Analyze relationship patterns + patterns = self._analyze_relationship_patterns(graph) + + # Find circular dependencies + cycles = self._find_cycles(graph, max_depth) + + results = { + "summary": metrics, + "components": { + "count": len(components), + "sizes": [len(c) for c in components[:10]], # Top 10 + "largest_component": len(components[0]) if components else 0, + }, + "patterns": patterns, + "cycles": { + "found": len(cycles) > 0, + "count": len(cycles), + "examples": cycles[:5], # First 5 cycles + }, + } + + if include_orphans: + orphans = await self._find_orphaned_documents(graph) + results["orphans"] = { + "count": len(orphans), + "document_ids": orphans[:20], # First 20 + } + + return results + + async def _build_relationship_graph( + self, relationship_types: list[str] | None = None + ) -> dict[str, list[tuple[str, str]]]: + """Build document relationship graph.""" + if self._graph_cache is not None: + return self._graph_cache + + graph = defaultdict(list) + + # Scan relationships + scanner = self.dataset.scanner(columns=["id", "relationships"]) + + for batch in scanner.to_batches(): + ids = batch.column("id").to_pylist() + relationships_list = batch.column("relationships").to_pylist() + + for doc_id, relationships in zip(ids, relationships_list, strict=False): + if relationships: + for rel in relationships: + if isinstance(rel, dict): + rel_type = rel.get("type", "unknown") + target = rel.get("target") + + if ( + relationship_types is None + or rel_type in relationship_types + ): + if target: + 
graph[doc_id].append((rel_type, target)) + + self._graph_cache = dict(graph) + return self._graph_cache + + def _calculate_graph_metrics( + self, graph: dict[str, list[tuple[str, str]]] + ) -> dict[str, Any]: + """Calculate basic graph metrics.""" + nodes = set(graph.keys()) + all_targets = set() + edge_count = 0 + + for edges in graph.values(): + edge_count += len(edges) + all_targets.update(target for _, target in edges) + + nodes.update(all_targets) + + # Degree distribution + in_degree = defaultdict(int) + out_degree = defaultdict(int) + + for source, edges in graph.items(): + out_degree[source] = len(edges) + for _, target in edges: + in_degree[target] += 1 + + degrees = list(out_degree.values()) + [0] * (len(nodes) - len(out_degree)) + + return { + "node_count": len(nodes), + "edge_count": edge_count, + "avg_degree": round(edge_count / len(nodes) if nodes else 0, 2), + "max_out_degree": max(out_degree.values()) if out_degree else 0, + "max_in_degree": max(in_degree.values()) if in_degree else 0, + "degree_distribution": { + "mean": round(np.mean(degrees), 2), + "median": round(np.median(degrees), 2), + "std": round(np.std(degrees), 2), + }, + } + + def _find_connected_components( + self, graph: dict[str, list[tuple[str, str]]] + ) -> list[set[str]]: + """Find connected components using DFS.""" + # Build undirected adjacency list + adjacency = defaultdict(set) + all_nodes = set(graph.keys()) + + for source, edges in graph.items(): + for _, target in edges: + adjacency[source].add(target) + adjacency[target].add(source) + all_nodes.add(target) + + # DFS to find components + visited = set() + components = [] + + for node in all_nodes: + if node not in visited: + component = set() + stack = [node] + + while stack: + current = stack.pop() + if current not in visited: + visited.add(current) + component.add(current) + stack.extend(adjacency[current] - visited) + + components.append(component) + + # Sort by size (largest first) + components.sort(key=len, 
reverse=True) + + return components + + def _analyze_relationship_patterns( + self, graph: dict[str, list[tuple[str, str]]] + ) -> dict[str, Any]: + """Analyze patterns in relationships.""" + type_counts = defaultdict(int) + type_pairs = defaultdict(int) + + for edges in graph.values(): + for rel_type, _ in edges: + type_counts[rel_type] += 1 + + # Count type pairs + edge_types = [rel_type for rel_type, _ in edges] + for i, type1 in enumerate(edge_types): + for type2 in edge_types[i + 1 :]: + pair = tuple(sorted([type1, type2])) + type_pairs[pair] += 1 + + return { + "type_distribution": dict(type_counts), + "common_pairs": [ + {"types": list(pair), "count": count} + for pair, count in sorted( + type_pairs.items(), key=lambda x: x[1], reverse=True + )[:5] + ], + } + + def _find_cycles( + self, graph: dict[str, list[tuple[str, str]]], max_depth: int + ) -> list[list[str]]: + """Find cycles in the relationship graph.""" + cycles = [] + + def dfs(node: str, path: list[str], visited: set[str]) -> None: + if len(path) > max_depth: + return + + if node in path: + # Found a cycle + cycle_start = path.index(node) + cycle = path[cycle_start:] + [node] + if len(cycle) > 2: # Ignore self-loops + cycles.append(cycle) + return + + if node in visited: + return + + visited.add(node) + path.append(node) + + for _, target in graph.get(node, []): + dfs(target, path.copy(), visited.copy()) + + # Start DFS from each node + for start_node in graph: + dfs(start_node, [], set()) + if len(cycles) >= 10: # Limit cycles found + break + + return cycles + + async def _find_orphaned_documents( + self, graph: dict[str, list[tuple[str, str]]] + ) -> list[str]: + """Find documents with no relationships.""" + # Get all documents + all_docs = set() + scanner = self.dataset.scanner(columns=["id"], limit=10000) + + for batch in scanner.to_batches(): + all_docs.update(batch.column("id").to_pylist()) + + # Find connected documents + connected = set(graph.keys()) + for edges in graph.values(): + 
connected.update(target for _, target in edges) + + # Orphans are documents not in the graph + orphans = list(all_docs - connected) + + return orphans[:100] # Return first 100 diff --git a/contextframe/mcp/analytics/optimizer.py b/contextframe/mcp/analytics/optimizer.py new file mode 100644 index 0000000..897942e --- /dev/null +++ b/contextframe/mcp/analytics/optimizer.py @@ -0,0 +1,648 @@ +"""Storage optimization, index recommendations, and performance benchmarking.""" + +import asyncio +import numpy as np +import time +from contextframe.frame import FrameDataset +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional, Set, Tuple + + +@dataclass +class OptimizationResult: + """Result of an optimization operation.""" + + operation: str + success: bool + metrics: dict[str, Any] + duration_seconds: float + timestamp: datetime + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + return { + "operation": self.operation, + "success": self.success, + "metrics": self.metrics, + "duration_seconds": round(self.duration_seconds, 2), + "timestamp": self.timestamp.isoformat(), + } + + +class StorageOptimizer: + """Optimizes Lance dataset storage using native capabilities.""" + + def __init__(self, dataset: FrameDataset): + """Initialize storage optimizer.""" + self.dataset = dataset + self._optimization_history: list[OptimizationResult] = [] + + async def optimize_storage( + self, + operations: list[str] = ["compact", "vacuum"], + dry_run: bool = False, + target_version: int | None = None, + ) -> dict[str, Any]: + """Optimize dataset storage. 
+ + Args: + operations: List of operations to perform + dry_run: Preview changes without applying + target_version: Optimize to specific version + + Returns: + Optimization results + """ + results = { + "operations": [], + "total_space_saved_mb": 0.0, + "total_duration_seconds": 0.0, + } + + # Get initial stats + initial_stats = self.dataset.get_dataset_stats() + + for operation in operations: + start_time = time.time() + + if operation == "compact": + result = await self._compact_files(dry_run) + elif operation == "vacuum": + result = await self._vacuum_old_versions(dry_run, target_version) + elif operation == "reindex": + result = await self._optimize_indices(dry_run) + else: + result = { + "error": f"Unknown operation: {operation}", + "success": False, + } + + duration = time.time() - start_time + + # Record result + opt_result = OptimizationResult( + operation=operation, + success=result.get("success", False), + metrics=result, + duration_seconds=duration, + timestamp=datetime.now(), + ) + self._optimization_history.append(opt_result) + + results["operations"].append(opt_result.to_dict()) + results["total_duration_seconds"] += duration + + if "space_saved_mb" in result: + results["total_space_saved_mb"] += result["space_saved_mb"] + + # Get final stats + if not dry_run: + final_stats = self.dataset.get_dataset_stats() + results["before"] = initial_stats + results["after"] = final_stats + + return results + + async def _compact_files(self, dry_run: bool) -> dict[str, Any]: + """Compact dataset files.""" + try: + if dry_run: + # Preview compaction + fragments = self.dataset.get_fragment_stats() + small_fragments = [f for f in fragments if f["num_rows"] < 10000] + + return { + "success": True, + "preview": True, + "fragments_to_compact": len(small_fragments), + "estimated_fragments_after": max( + 1, len(fragments) - len(small_fragments) + 1 + ), + } + else: + # Perform compaction + result = self.dataset.compact_files() + + return { + "success": True, + 
"fragments_compacted": result.get("fragments_compacted", 0), + "files_removed": result.get("files_removed", 0), + "files_added": result.get("files_added", 0), + } + except Exception as e: + return { + "success": False, + "error": str(e), + } + + async def _vacuum_old_versions( + self, dry_run: bool, target_version: int | None = None + ) -> dict[str, Any]: + """Clean up old dataset versions.""" + try: + # Calculate age threshold + if target_version is not None: + current_version = self.dataset._dataset.version + versions_to_keep = current_version - target_version + older_than = timedelta(days=versions_to_keep) # Rough estimate + else: + older_than = timedelta(days=7) # Default: keep last week + + if dry_run: + # Estimate cleanup + version_history = self.dataset.get_version_history() + old_versions = [ + v + for v in version_history + if v["version"] < (target_version or len(version_history) - 10) + ] + + return { + "success": True, + "preview": True, + "versions_to_remove": len(old_versions), + "estimated_space_mb": len(old_versions) * 10, # Rough estimate + } + else: + # Perform cleanup + result = self.dataset.cleanup_old_versions(older_than=older_than) + + return { + "success": True, + "bytes_removed": result.get("bytes_removed", 0), + "space_saved_mb": result.get("bytes_removed", 0) / (1024 * 1024), + "old_versions_removed": result.get("old_versions_removed", 0), + } + except Exception as e: + return { + "success": False, + "error": str(e), + } + + async def _optimize_indices(self, dry_run: bool) -> dict[str, Any]: + """Optimize dataset indices.""" + try: + indices = self.dataset.list_indices() + + if dry_run: + return { + "success": True, + "preview": True, + "indices_to_optimize": len(indices), + "index_names": [idx["name"] for idx in indices], + } + else: + # Perform optimization + result = self.dataset.optimize_indices() + + return { + "success": True, + "indices_optimized": result.get("indices_optimized", 0), + "status": result.get("status", "completed"), + 
async def get_recommendations(
    self, analyze_queries: bool = True, workload_type: str = "mixed"
) -> dict[str, Any]:
    """Get index recommendations.

    Args:
        analyze_queries: Accepted for API compatibility; not read by the
            current implementation.
        workload_type: Type of workload (search, analytics, mixed). Vector
            and full-text recommendations are only made for ``search`` and
            ``mixed``.

    Returns:
        Dict with ``current_indices`` (with a usage estimate each),
        ``recommendations`` sorted high to low priority,
        ``redundant_indices`` and ``index_coverage``.
    """
    # Get current indices
    current_indices = self.dataset.list_indices()
    indexed_fields = set()
    for idx in current_indices:
        indexed_fields.update(idx.get("fields", []))

    # Analyze schema
    schema_fields = self._analyze_schema()

    search_workload = workload_type in ["search", "mixed"]
    recommendations = []

    # Vector index recommendations
    if "embedding" not in indexed_fields and search_workload:
        recommendations.append(
            {
                "type": "vector",
                "field": "embedding",
                "reason": "No vector index found for embedding field",
                "priority": "high",
                "estimated_benefit": "10-100x faster similarity search",
                "command": "dataset.create_vector_index('embedding', metric='cosine', num_partitions=256)",
            }
        )

    # Scalar index recommendations
    scalar_candidates = self._identify_scalar_candidates(
        schema_fields, indexed_fields
    )
    for field, info in scalar_candidates.items():
        recommendations.append(
            {
                "type": "scalar",
                "field": field,
                "reason": info["reason"],
                "priority": info["priority"],
                "estimated_benefit": f"{info['benefit']}x faster filtering",
                "command": f"dataset.create_scalar_index('{field}')",
            }
        )

    # Full-text search index
    if "content" not in indexed_fields and search_workload:
        recommendations.append(
            {
                "type": "fts",
                "field": "content",
                "reason": "No full-text search index for content field",
                "priority": "medium",
                "estimated_benefit": "Enable text search capabilities",
                # NOTE(review): suggests a scalar index for a full-text
                # recommendation - confirm the intended dataset API call.
                "command": "dataset.create_scalar_index('content')",
            }
        )

    # Analyze redundant indices
    redundant = self._find_redundant_indices(current_indices)

    # Usage statistics
    usage_stats = await self._analyze_index_usage(current_indices)

    # An empty schema has nothing to cover; report 0% instead of dividing
    # by zero.
    if schema_fields:
        coverage_percent = round(len(indexed_fields) / len(schema_fields) * 100, 1)
    else:
        coverage_percent = 0.0

    priority_rank = {"high": 0, "medium": 1, "low": 2}

    return {
        "current_indices": [
            {
                "name": idx["name"],
                # .get mirrors the tolerant "fields" lookup above so an
                # index entry with missing metadata does not abort the report.
                "type": idx.get("type"),
                "fields": idx.get("fields", []),
                "usage": usage_stats.get(idx["name"], "unknown"),
            }
            for idx in current_indices
        ],
        "recommendations": sorted(
            recommendations,
            key=lambda rec: priority_rank.get(rec["priority"], 3),
        ),
        "redundant_indices": redundant,
        "index_coverage": {
            "total_fields": len(schema_fields),
            "indexed_fields": len(indexed_fields),
            "coverage_percent": coverage_percent,
        },
    }
from scalar indices.""" + candidates = {} + + # High-value fields for indexing + high_value_fields = { + "id": ("Primary key field", "high", 100), + "record_type": ("Frequently filtered field", "high", 50), + "created_at": ("Temporal queries", "medium", 20), + "updated_at": ("Temporal queries", "medium", 20), + "source_type": ("Content filtering", "medium", 10), + } + + for field, (reason, priority, benefit) in high_value_fields.items(): + if field in schema_fields and field not in indexed_fields: + candidates[field] = { + "reason": reason, + "priority": priority, + "benefit": benefit, + } + + # Check field usage patterns + for field, usage_count in self._field_usage.items(): + if field not in indexed_fields and usage_count > 10: + if field not in candidates: + candidates[field] = { + "reason": f"Frequently queried field ({usage_count} times)", + "priority": "medium" if usage_count > 50 else "low", + "benefit": min(usage_count, 50), + } + + return candidates + + def _find_redundant_indices( + self, current_indices: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + """Find potentially redundant indices.""" + redundant = [] + + # Check for duplicate indices on same fields + field_indices = {} + for idx in current_indices: + fields_key = tuple(sorted(idx.get("fields", []))) + if fields_key in field_indices: + redundant.append( + { + "index": idx["name"], + "reason": f"Duplicate of {field_indices[fields_key]}", + "action": "Consider removing", + } + ) + else: + field_indices[fields_key] = idx["name"] + + return redundant + + async def _analyze_index_usage( + self, indices: list[dict[str, Any]] + ) -> dict[str, str]: + """Analyze index usage (simplified).""" + usage = {} + + for idx in indices: + # In a real implementation, this would query Lance statistics + # For now, return estimated usage based on field + if idx["type"] == "vector" or idx["fields"][0] in ["id", "record_type"]: + usage[idx["name"]] = "high" + else: + usage[idx["name"]] = "medium" + + return usage + 
+ +class PerformanceBenchmark: + """Benchmarks dataset operations for performance analysis.""" + + def __init__(self, dataset: FrameDataset): + """Initialize performance benchmark.""" + self.dataset = dataset + self._results: list[dict[str, Any]] = [] + + async def benchmark_operations( + self, + operations: list[str] = ["search", "insert", "update", "scan"], + sample_size: int = 100, + concurrency: int = 1, + ) -> dict[str, Any]: + """Benchmark key dataset operations. + + Args: + operations: Operations to benchmark + sample_size: Number of operations per benchmark + concurrency: Number of concurrent operations + + Returns: + Benchmark results + """ + results = { + "configuration": { + "sample_size": sample_size, + "concurrency": concurrency, + "timestamp": datetime.now().isoformat(), + }, + "operations": {}, + "summary": {}, + } + + for operation in operations: + if operation == "search": + op_results = await self._benchmark_search(sample_size, concurrency) + elif operation == "insert": + op_results = await self._benchmark_insert(sample_size, concurrency) + elif operation == "update": + op_results = await self._benchmark_update(sample_size, concurrency) + elif operation == "scan": + op_results = await self._benchmark_scan(sample_size, concurrency) + else: + continue + + results["operations"][operation] = op_results + self._results.append( + { + "operation": operation, + "results": op_results, + "timestamp": datetime.now(), + } + ) + + # Calculate summary + results["summary"] = self._calculate_summary(results["operations"]) + + return results + + async def _benchmark_search( + self, sample_size: int, concurrency: int + ) -> dict[str, Any]: + """Benchmark search operations.""" + latencies = [] + + # Get sample documents for queries + sample_docs = [] + scanner = self.dataset.scanner( + columns=["id", "embedding"], + filter="embedding IS NOT NULL", + limit=min(10, sample_size), + ) + + for batch in scanner.to_batches(): + ids = batch.column("id").to_pylist() + 
embeddings = batch.column("embedding").to_pylist() + for doc_id, emb in zip(ids, embeddings, strict=False): + if emb: + sample_docs.append((doc_id, emb)) + + if not sample_docs: + return {"error": "No documents with embeddings found"} + + # Run search benchmarks + async def run_search(): + _, emb = sample_docs[np.random.randint(0, len(sample_docs))] + start = time.time() + try: + results = self.dataset.knn_search(query_vector=emb, k=10, filter=None) + duration = (time.time() - start) * 1000 # ms + return duration, len(results) + except Exception as e: + return None, str(e) + + # Run concurrent searches + for _ in range(0, sample_size, concurrency): + batch_tasks = [ + run_search() + for _ in range(min(concurrency, sample_size - len(latencies))) + ] + batch_results = await asyncio.gather(*batch_tasks) + + for duration, result in batch_results: + if duration is not None: + latencies.append(duration) + + if not latencies: + return {"error": "No successful search operations"} + + return self._calculate_latency_stats(latencies, "search") + + async def _benchmark_insert( + self, sample_size: int, concurrency: int + ) -> dict[str, Any]: + """Benchmark insert operations.""" + # For safety, we'll simulate inserts rather than actually inserting + # In production, this would create test documents + + latencies = [] + + # Simulate insert timing based on dataset characteristics + base_latency = 10.0 # ms + variance = 5.0 + + for _ in range(sample_size): + simulated_latency = base_latency + np.random.normal(0, variance) + latencies.append(max(0.1, simulated_latency)) + + return self._calculate_latency_stats(latencies, "insert (simulated)") + + async def _benchmark_update( + self, sample_size: int, concurrency: int + ) -> dict[str, Any]: + """Benchmark update operations.""" + # Similar to insert, simulate for safety + latencies = [] + + base_latency = 15.0 # ms (updates typically slower) + variance = 7.0 + + for _ in range(sample_size): + simulated_latency = base_latency + 
np.random.normal(0, variance) + latencies.append(max(0.1, simulated_latency)) + + return self._calculate_latency_stats(latencies, "update (simulated)") + + async def _benchmark_scan( + self, sample_size: int, concurrency: int + ) -> dict[str, Any]: + """Benchmark scan operations.""" + latencies = [] + + # Test different scan sizes + scan_sizes = [10, 100, 1000] + + for size in scan_sizes: + for _ in range(sample_size // len(scan_sizes)): + start = time.time() + try: + scanner = self.dataset.scanner(limit=size) + count = 0 + for batch in scanner.to_batches(): + count += len(batch) + duration = (time.time() - start) * 1000 # ms + latencies.append(duration) + except Exception: + continue + + if not latencies: + return {"error": "No successful scan operations"} + + return self._calculate_latency_stats(latencies, "scan") + + def _calculate_latency_stats( + self, latencies: list[float], operation: str + ) -> dict[str, Any]: + """Calculate latency statistics.""" + return { + "operation": operation, + "sample_count": len(latencies), + "latency_ms": { + "min": round(min(latencies), 2), + "p50": round(np.percentile(latencies, 50), 2), + "p90": round(np.percentile(latencies, 90), 2), + "p99": round(np.percentile(latencies, 99), 2), + "max": round(max(latencies), 2), + "mean": round(np.mean(latencies), 2), + "std": round(np.std(latencies), 2), + }, + "throughput_ops_per_sec": round(1000 / np.mean(latencies), 1), + } + + def _calculate_summary(self, operations: dict[str, Any]) -> dict[str, Any]: + """Calculate summary statistics across operations.""" + summary = { + "fastest_operation": None, + "slowest_operation": None, + "performance_score": 0.0, + } + + mean_latencies = {} + for op_name, op_data in operations.items(): + if "latency_ms" in op_data: + mean_latencies[op_name] = op_data["latency_ms"]["mean"] + + if mean_latencies: + summary["fastest_operation"] = min( + mean_latencies.items(), key=lambda x: x[1] + ) + summary["slowest_operation"] = max( + mean_latencies.items(), 
key=lambda x: x[1] + ) + + # Simple performance score (lower is better) + summary["performance_score"] = round( + sum(mean_latencies.values()) / len(mean_latencies), 2 + ) + + return summary + + def get_benchmark_history(self) -> list[dict[str, Any]]: + """Get history of benchmark results.""" + return [ + { + "operation": r["operation"], + "timestamp": r["timestamp"].isoformat(), + "mean_latency_ms": r["results"].get("latency_ms", {}).get("mean"), + } + for r in self._results + ] diff --git a/contextframe/mcp/analytics/stats.py b/contextframe/mcp/analytics/stats.py new file mode 100644 index 0000000..0adce8b --- /dev/null +++ b/contextframe/mcp/analytics/stats.py @@ -0,0 +1,412 @@ +"""Dataset statistics and metrics collection using Lance native capabilities.""" + +import asyncio +import numpy as np +import pyarrow as pa +import pyarrow.compute as pc +import time +from contextframe.frame import FrameDataset, FrameRecord +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional, Set, Tuple + + +@dataclass +class DatasetStats: + """Container for comprehensive dataset statistics.""" + + # Basic counts from Lance stats + total_documents: int = 0 + total_collections: int = 0 + total_relationships: int = 0 + + # Storage metrics from Lance + num_fragments: int = 0 + num_deleted_rows: int = 0 + num_small_files: int = 0 + storage_size_bytes: int = 0 + + # Version metrics + current_version: int = 0 + latest_version: int = 0 + version_count: int = 0 + + # Content metrics + document_types: dict[str, int] = field(default_factory=dict) + collection_sizes: dict[str, int] = field(default_factory=dict) + metadata_fields: dict[str, int] = field(default_factory=dict) + + # Embedding metrics + embedding_coverage: float = 0.0 + embedding_dimensions: set[int] = field(default_factory=set) + + # Relationship metrics + relationship_types: dict[str, int] = field(default_factory=dict) + avg_relationships_per_doc: float = 
0.0 + orphaned_documents: int = 0 + + # Index metrics + indices: list[dict[str, Any]] = field(default_factory=list) + indexed_fields: set[str] = field(default_factory=set) + + # Time-based metrics + oldest_document: datetime | None = None + newest_document: datetime | None = None + + # Performance metrics + avg_document_size_kb: float = 0.0 + fragment_efficiency: float = 0.0 # ratio of active to total rows + collection_time_seconds: float = 0.0 + + # Collection statistics + avg_collection_size: float = 0.0 + max_collection_size: int = 0 + min_collection_size: int = 0 + + def to_dict(self) -> dict[str, Any]: + """Convert stats to dictionary format.""" + return { + "summary": { + "total_documents": self.total_documents, + "total_collections": self.total_collections, + "total_relationships": self.total_relationships, + "storage_size_mb": round(self.storage_size_bytes / (1024 * 1024), 2), + }, + "storage": { + "num_fragments": self.num_fragments, + "num_deleted_rows": self.num_deleted_rows, + "num_small_files": self.num_small_files, + "avg_document_size_kb": round(self.avg_document_size_kb, 2), + "fragment_efficiency": round(self.fragment_efficiency, 3), + }, + "versions": { + "current": self.current_version, + "latest": self.latest_version, + "total_versions": self.version_count, + }, + "content": { + "document_types": self.document_types, + "collection_count": len(self.collection_sizes), + "collection_sizes": self.collection_sizes, + "metadata_fields": self.metadata_fields, + }, + "embeddings": { + "coverage": round(self.embedding_coverage, 3), + "dimensions": sorted(list(self.embedding_dimensions)), + }, + "relationships": { + "types": self.relationship_types, + "avg_per_document": round(self.avg_relationships_per_doc, 2), + "orphaned_documents": self.orphaned_documents, + }, + "indices": { + "count": len(self.indices), + "indexed_fields": sorted(list(self.indexed_fields)), + "details": self.indices, + }, + "time_range": { + "oldest": 
self.oldest_document.isoformat() + if self.oldest_document + else None, + "newest": self.newest_document.isoformat() + if self.newest_document + else None, + }, + } + + +class StatsCollector: + """Collects comprehensive statistics from a FrameDataset using Lance native features.""" + + def __init__(self, dataset: FrameDataset): + """Initialize stats collector. + + Args: + dataset: The FrameDataset to analyze + """ + self.dataset = dataset + self._stats = DatasetStats() + + async def collect_stats( + self, + include_content: bool = True, + include_fragments: bool = True, + include_relationships: bool = True, + sample_size: int | None = None, + ) -> DatasetStats: + """Collect all statistics using Lance native capabilities. + + Args: + include_content: Include content analysis + include_fragments: Include fragment-level analysis + include_relationships: Include relationship analysis + sample_size: If set, sample for expensive operations + + Returns: + DatasetStats object with comprehensive metrics + """ + start_time = time.time() + + # Get Lance native stats first + await self._collect_lance_stats() + + # Collect basic counts + await self._collect_basic_counts() + + # Run optional detailed analyses in parallel + tasks = [] + + if include_content: + tasks.append(self._collect_content_stats(sample_size)) + + if include_fragments: + tasks.append(self._collect_fragment_analysis()) + + if include_relationships: + tasks.append(self._collect_relationship_stats(sample_size)) + + if tasks: + await asyncio.gather(*tasks) + + # Calculate derived metrics + self._calculate_derived_metrics() + + collection_time = time.time() - start_time + self._stats.collection_time_seconds = collection_time + + return self._stats + + async def _collect_lance_stats(self) -> None: + """Collect statistics using Lance native methods.""" + # Get dataset stats from Lance + lance_stats = self.dataset.get_dataset_stats() + + # Extract Lance dataset stats + if 'dataset_stats' in lance_stats: + ds_stats = 
lance_stats['dataset_stats'] + self._stats.num_fragments = ds_stats.get('num_fragments', 0) + self._stats.num_deleted_rows = ds_stats.get('num_deleted_rows', 0) + self._stats.num_small_files = ds_stats.get('num_small_files', 0) + + # Version info + if 'version_info' in lance_stats: + v_info = lance_stats['version_info'] + self._stats.current_version = v_info.get('current_version', 0) + self._stats.latest_version = v_info.get('latest_version', 0) + self._stats.version_count = self._stats.latest_version + 1 + + # Storage info + if 'storage' in lance_stats: + self._stats.total_documents = lance_stats['storage'].get('num_rows', 0) + + # Index info + if 'indices' in lance_stats: + self._stats.indices = lance_stats['indices'] + for idx in self._stats.indices: + self._stats.indexed_fields.update(idx.get('fields', [])) + + async def _collect_basic_counts(self) -> None: + """Collect basic document and collection counts.""" + # Count collections using filter + collection_count = self.dataset.count_by_filter( + "record_type = 'collection_header'" + ) + self._stats.total_collections = collection_count + + async def _collect_fragment_analysis(self) -> None: + """Analyze fragment-level statistics.""" + fragments = self.dataset.get_fragment_stats() + + if fragments: + # Calculate storage size from fragments + total_size = 0 + active_rows = 0 + physical_rows = 0 + + for frag in fragments: + active_rows += frag['num_rows'] + physical_rows += frag['physical_rows'] + # Estimate size based on rows (if file info not available) + if self._stats.total_documents > 0: + avg_row_size = 1024 # Default estimate + total_size += frag['physical_rows'] * avg_row_size + + self._stats.storage_size_bytes = total_size + + # Calculate efficiency + if physical_rows > 0: + self._stats.fragment_efficiency = active_rows / physical_rows + + # Average document size + if self._stats.total_documents > 0: + self._stats.avg_document_size_kb = ( + self._stats.storage_size_bytes / self._stats.total_documents / 
1024 + ) + + async def _collect_content_stats(self, sample_size: int | None = None) -> None: + """Collect content-related statistics.""" + # Document type distribution + doc_types: Dict[str, int] = {} + collection_members: Dict[str, int] = {} + metadata_fields: Dict[str, int] = {} + oldest = None + newest = None + + # Use scanner with projection for efficiency + columns = ["id", "record_type", "context", "custom_metadata", "created_at"] + + # Sample if needed + if sample_size and sample_size < self._stats.total_documents: + # Use limit for sampling + scanner = self.dataset.scanner(columns=columns, limit=sample_size) + else: + scanner = self.dataset.scanner(columns=columns) + + # Process batches + for batch in scanner.to_batches(): + # Document types + if "record_type" in batch.column_names: + types = batch.column("record_type").to_pylist() + for doc_type in types: + doc_types[doc_type] = doc_types.get(doc_type, 0) + 1 + + # Collection membership + if "context" in batch.column_names: + contexts = batch.column("context").to_pylist() + for context in contexts: + if ( + context + and isinstance(context, dict) + and "collection_id" in context + ): + coll_id = context["collection_id"] + collection_members[coll_id] = ( + collection_members.get(coll_id, 0) + 1 + ) + + # Metadata fields + if "custom_metadata" in batch.column_names: + metadatas = batch.column("custom_metadata").to_pylist() + for metadata in metadatas: + if metadata and isinstance(metadata, dict): + for field in metadata.keys(): + metadata_fields[field] = metadata_fields.get(field, 0) + 1 + + # Time metrics + if "created_at" in batch.column_names: + timestamps = batch.column("created_at").to_pylist() + for ts_str in timestamps: + if ts_str: + try: + ts = datetime.fromisoformat(ts_str.replace("Z", "+00:00")) + if oldest is None or ts < oldest: + oldest = ts + if newest is None or ts > newest: + newest = ts + except (ValueError, AttributeError): + continue + + # Update stats + self._stats.document_types = 
doc_types + self._stats.collection_sizes = collection_members + self._stats.metadata_fields = metadata_fields + self._stats.oldest_document = oldest + self._stats.newest_document = newest + + async def _collect_relationship_stats(self, sample_size: int | None = None) -> None: + """Collect relationship statistics.""" + relationship_types: Dict[str, int] = {} + docs_with_relationships = 0 + total_relationships = 0 + + # Use projection for efficiency + columns = ["id", "relationships"] + + if sample_size and sample_size < self._stats.total_documents: + scanner = self.dataset.scanner(columns=columns, limit=sample_size) + scaling_factor = self._stats.total_documents / sample_size + else: + scanner = self.dataset.scanner(columns=columns) + scaling_factor = 1.0 + + # Process relationships + for batch in scanner.to_batches(): + if "relationships" in batch.column_names: + relationships_list = batch.column("relationships").to_pylist() + + for relationships in relationships_list: + if relationships and isinstance(relationships, list): + docs_with_relationships += 1 + for rel in relationships: + if isinstance(rel, dict): + rel_type = rel.get("type", "unknown") + relationship_types[rel_type] = ( + relationship_types.get(rel_type, 0) + 1 + ) + total_relationships += 1 + + # Scale if sampled + if scaling_factor > 1: + docs_with_relationships = int(docs_with_relationships * scaling_factor) + total_relationships = int(total_relationships * scaling_factor) + for rel_type in relationship_types: + relationship_types[rel_type] = int( + relationship_types[rel_type] * scaling_factor + ) + + # Update stats + self._stats.relationship_types = relationship_types + self._stats.total_relationships = total_relationships + + if self._stats.total_documents > 0: + self._stats.avg_relationships_per_doc = ( + total_relationships / self._stats.total_documents + ) + self._stats.orphaned_documents = ( + self._stats.total_documents - docs_with_relationships + ) + + async def 
_collect_embedding_stats(self, sample_size: int | None = None) -> None: + """Collect embedding statistics.""" + total_with_embeddings = 0 + embedding_dims = set() + + # Use projection + columns = ["embedding"] + + if sample_size and sample_size < self._stats.total_documents: + scanner = self.dataset.scanner(columns=columns, limit=sample_size) + scaling_factor = self._stats.total_documents / sample_size + else: + scanner = self.dataset.scanner(columns=columns) + scaling_factor = 1.0 + + # Process embeddings + for batch in scanner.to_batches(): + if "embedding" in batch.column_names: + embeddings = batch.column("embedding").to_pylist() + + for emb in embeddings: + if emb is not None and len(emb) > 0: + total_with_embeddings += 1 + embedding_dims.add(len(emb)) + + # Scale and update + if scaling_factor > 1: + total_with_embeddings = int(total_with_embeddings * scaling_factor) + + if self._stats.total_documents > 0: + self._stats.embedding_coverage = ( + total_with_embeddings / self._stats.total_documents + ) + + self._stats.embedding_dimensions = embedding_dims + + def _calculate_derived_metrics(self) -> None: + """Calculate metrics derived from collected stats.""" + # Collection size statistics + if self._stats.collection_sizes: + sizes = list(self._stats.collection_sizes.values()) + self._stats.avg_collection_size = sum(sizes) / len(sizes) + self._stats.max_collection_size = max(sizes) + self._stats.min_collection_size = min(sizes) diff --git a/contextframe/mcp/analytics/tools.py b/contextframe/mcp/analytics/tools.py new file mode 100644 index 0000000..19a71a4 --- /dev/null +++ b/contextframe/mcp/analytics/tools.py @@ -0,0 +1,810 @@ +"""MCP tool implementations for analytics and performance monitoring.""" + +import json +from .analyzer import QueryAnalyzer, RelationshipAnalyzer, UsageAnalyzer +from .optimizer import IndexAdvisor, PerformanceBenchmark, StorageOptimizer +from .stats import StatsCollector +from contextframe.frame import FrameDataset +from 
class ToolHandler:
    """Common base for the analytics tool handlers.

    Subclasses set ``name`` and ``description``, implement ``execute`` and
    usually override ``get_input_schema``.
    """

    name: str = ""
    description: str = ""

    def __init__(self, dataset: FrameDataset):
        """Keep a reference to the dataset the tool operates on."""
        self.dataset = dataset

    async def execute(self, **kwargs) -> dict[str, Any]:
        """Run the tool; the base class has no implementation."""
        raise NotImplementedError()

    def get_input_schema(self) -> dict[str, Any]:
        """Return the schema describing accepted parameters.

        The base schema is an object with no properties that rejects any
        additional ones; subclasses override this.
        """
        schema: dict[str, Any] = {"type": "object"}
        schema["properties"] = {}
        schema["additionalProperties"] = False
        return schema
def get_input_schema(self) -> dict[str, Any]:
    """Describe the parameters accepted by analyze_usage."""
    time_range = {
        "type": "string",
        "default": "7d",
        "description": "Analysis period (e.g., '7d', '24h', '30d')",
    }
    group_by = {
        "type": "string",
        "enum": ["hour", "day", "week"],
        "default": "hour",
        "description": "Grouping period for temporal analysis",
    }
    include_patterns = {
        "type": "boolean",
        "default": True,
        "description": "Include pattern analysis and recommendations",
    }

    properties = {
        "time_range": time_range,
        "group_by": group_by,
        "include_patterns": include_patterns,
    }
    return {
        "type": "object",
        "properties": properties,
        "additionalProperties": False,
    }
collected from actual usage + import random + + # Get some document IDs + scanner = self.dataset.scanner(columns=["id"], limit=100) + doc_ids = [] + for batch in scanner.to_batches(): + doc_ids.extend(batch.column("id").to_pylist()) + + # Simulate access patterns + for _ in range(500): + if doc_ids: + doc_id = random.choice(doc_ids) + operation = random.choice(["read", "search", "update"]) + self.usage_analyzer.record_access(doc_id, operation) + + async def execute(self, **kwargs) -> dict[str, Any]: + """Execute analyze_usage tool. + + Args: + time_range: Analysis period (e.g., "7d", "24h", "30d") + group_by: Grouping period - "hour", "day", "week" (default: "hour") + include_patterns: Include pattern analysis (default: True) + + Returns: + Usage analysis results + """ + time_range_str = kwargs.get("time_range", "7d") + group_by = kwargs.get("group_by", "hour") + include_patterns = kwargs.get("include_patterns", True) + + try: + # Parse time range + time_range = self._parse_time_range(time_range_str) + + # Analyze usage + analysis = await self.usage_analyzer.analyze_usage( + time_range=time_range, + group_by=group_by, + include_patterns=include_patterns, + ) + + return { + "success": True, + "analysis": analysis, + "period": time_range_str, + "timestamp": datetime.now().isoformat(), + } + + except Exception as e: + raise ToolError(f"Failed to analyze usage: {str(e)}") + + def _parse_time_range(self, time_range_str: str) -> timedelta: + """Parse time range string to timedelta.""" + if time_range_str.endswith("h"): + hours = int(time_range_str[:-1]) + return timedelta(hours=hours) + elif time_range_str.endswith("d"): + days = int(time_range_str[:-1]) + return timedelta(days=days) + elif time_range_str.endswith("w"): + weeks = int(time_range_str[:-1]) + return timedelta(weeks=weeks) + else: + raise ValueError(f"Invalid time range format: {time_range_str}") + + +class QueryPerformanceHandler(ToolHandler): + """Handler for query_performance tool.""" + + name = 
def get_input_schema(self) -> dict[str, Any]:
    """Describe the parameters accepted by query_performance."""
    time_range = {
        "type": "string",
        "default": "7d",
        "description": "Analysis period (e.g., '7d', '24h')",
    }
    query_type = {
        "type": "string",
        "enum": ["vector", "text", "hybrid", "filter"],
        "description": "Filter by query type",
    }
    min_duration_ms = {
        "type": "number",
        "minimum": 0,
        "default": 0,
        "description": "Minimum query duration to include (ms)",
    }

    properties = {
        "time_range": time_range,
        "query_type": query_type,
        "min_duration_ms": min_duration_ms,
    }
    return {
        "type": "object",
        "properties": properties,
        "additionalProperties": False,
    }
async def execute(self, **kwargs) -> dict[str, Any]:
    """Execute query_performance tool.

    Args:
        time_range: Analysis period as an integer plus unit suffix:
            ``h`` (hours), ``d`` (days) or ``w`` (weeks), e.g. "7d",
            "24h".  Defaults to "7d".  An empty or ``None`` value means
            no time restriction is passed to the analyzer.
        query_type: Filter by type - "vector", "text", "hybrid", "filter"
        min_duration_ms: Minimum query duration to include (default: 0)

    Returns:
        Dict with ``success``, the analyzer output under ``performance``,
        the requested ``period`` string and an ISO ``timestamp``.

    Raises:
        ToolError: If the time range cannot be parsed or the analysis
            fails; the original exception is attached as the cause.
    """
    time_range_str = kwargs.get("time_range", "7d")
    query_type = kwargs.get("query_type")
    min_duration_ms = kwargs.get("min_duration_ms", 0)

    try:
        # Parse time range.  An unrecognised suffix is rejected instead of
        # being ignored: silently falling back to "no restriction" would
        # return an all-time analysis labelled with the requested period.
        time_range = None
        if time_range_str:
            unit_keywords = {"h": "hours", "d": "days", "w": "weeks"}
            keyword = unit_keywords.get(time_range_str[-1:])
            if keyword is None:
                raise ValueError(f"Invalid time range format: {time_range_str}")
            time_range = timedelta(**{keyword: int(time_range_str[:-1])})

        # Analyze performance
        analysis = await self.query_analyzer.analyze_performance(
            time_range=time_range,
            query_type=query_type,
            min_duration_ms=min_duration_ms,
        )

        return {
            "success": True,
            "performance": analysis,
            "period": time_range_str,
            "timestamp": datetime.now().isoformat(),
        }

    except Exception as e:
        raise ToolError(f"Failed to analyze query performance: {str(e)}") from e
async def execute(self, **kwargs) -> dict[str, Any]:
    """Execute relationship_analysis tool.

    Args:
        max_depth: Maximum traversal depth (default: 3)
        relationship_types: List of types to analyze (default: all)
        include_orphans: Include orphaned documents (default: True)

    Returns:
        Dict with ``success``, the analyzer output under ``analysis`` and
        an ISO ``timestamp``.

    Raises:
        ToolError: If the underlying analysis raises.
    """
    # Resolve defaults up front so the analyzer always gets all three options.
    options = {
        "max_depth": kwargs.get("max_depth", 3),
        "relationship_types": kwargs.get("relationship_types"),
        "include_orphans": kwargs.get("include_orphans", True),
    }

    try:
        analysis = await self.relationship_analyzer.analyze_relationships(**options)

        response: dict[str, Any] = {"success": True, "analysis": analysis}
        response["timestamp"] = datetime.now().isoformat()
        return response

    except Exception as e:
        raise ToolError(f"Failed to analyze relationships: {str(e)}")
async def execute(self, **kwargs) -> dict[str, Any]:
    """Execute optimize_storage tool.

    Args:
        operations: List of operations - "compact", "vacuum", "reindex"
            (default: compact and vacuum)
        dry_run: Preview changes without applying (default: True)
        target_version: Target version for cleanup

    Returns:
        Dict with ``success``, the optimizer output under ``results``, the
        effective ``dry_run`` flag and an ISO ``timestamp``.

    Raises:
        ToolError: If an unknown operation is requested or the optimizer
            raises.
    """
    operations = kwargs.get("operations", ["compact", "vacuum"])
    dry_run = kwargs.get("dry_run", True)
    target_version = kwargs.get("target_version")

    try:
        # Reject anything outside the supported operation set.
        invalid_ops = set(operations).difference({"compact", "vacuum", "reindex"})
        if invalid_ops:
            raise ValueError(f"Invalid operations: {invalid_ops}")

        results = await self.storage_optimizer.optimize_storage(
            operations=operations,
            dry_run=dry_run,
            target_version=target_version,
        )

        response: dict[str, Any] = {"success": True, "results": results}
        response["dry_run"] = dry_run
        response["timestamp"] = datetime.now().isoformat()
        return response

    except Exception as e:
        raise ToolError(f"Failed to optimize storage: {str(e)}")
async def execute(self, **kwargs) -> dict[str, Any]:
    """Execute index_recommendations tool.

    Args:
        analyze_queries: Analyze recent query patterns (default: True)
        workload_type: Type of workload - "search", "analytics", "mixed"
            (default: "mixed")

    Returns:
        Dict with ``success``, the advisor output under
        ``recommendations``, the effective ``workload_type`` and an ISO
        ``timestamp``.

    Raises:
        ToolError: If the workload type is unknown or the advisor raises.
    """
    analyze_queries = kwargs.get("analyze_queries", True)
    workload_type = kwargs.get("workload_type", "mixed")

    try:
        # Only the three documented workload profiles are accepted.
        known_workloads = ("search", "analytics", "mixed")
        if workload_type not in known_workloads:
            raise ValueError(f"Invalid workload type: {workload_type}")

        recommendations = await self.index_advisor.get_recommendations(
            analyze_queries=analyze_queries,
            workload_type=workload_type,
        )

        response: dict[str, Any] = {
            "success": True,
            "recommendations": recommendations,
        }
        response["workload_type"] = workload_type
        response["timestamp"] = datetime.now().isoformat()
        return response

    except Exception as e:
        raise ToolError(f"Failed to get index recommendations: {str(e)}")
async def execute(self, **kwargs) -> dict[str, Any]:
    """Execute benchmark_operations tool.

    Args:
        operations: List of operations - "search", "insert", "update",
            "scan" (default: search and scan)
        sample_size: Number of operations per benchmark, 1-10000
            (default: 100)
        concurrency: Number of concurrent operations, 1-100 (default: 1)

    Returns:
        Dict with ``success``, the benchmark output under ``benchmarks``
        and an ISO ``timestamp``.

    Raises:
        ToolError: If an argument is out of range, an unknown operation is
            requested, or the benchmark raises.
    """
    operations = kwargs.get("operations", ["search", "scan"])
    sample_size = kwargs.get("sample_size", 100)
    concurrency = kwargs.get("concurrency", 1)

    try:
        # Reject anything outside the supported operation set.
        invalid_ops = set(operations).difference(
            {"search", "insert", "update", "scan"}
        )
        if invalid_ops:
            raise ValueError(f"Invalid operations: {invalid_ops}")

        # Enforce the same bounds the input schema advertises.
        if sample_size < 1 or sample_size > 10000:
            raise ValueError("sample_size must be between 1 and 10000")
        if concurrency < 1 or concurrency > 100:
            raise ValueError("concurrency must be between 1 and 100")

        results = await self.benchmark.benchmark_operations(
            operations=operations,
            sample_size=sample_size,
            concurrency=concurrency,
        )

        response: dict[str, Any] = {"success": True, "benchmarks": results}
        response["timestamp"] = datetime.now().isoformat()
        return response

    except Exception as e:
        raise ToolError(f"Failed to run benchmarks: {str(e)}")
def get_input_schema(self) -> dict[str, Any]:
    """Describe the parameters accepted by export_metrics."""
    export_format = {
        "type": "string",
        "enum": ["prometheus", "json", "csv"],
        "default": "json",
        "description": "Export format",
    }
    metrics = {
        "type": "array",
        "items": {"type": "string"},
        "description": "Specific metrics to export (default: all)",
    }
    labels = {
        "type": "object",
        "description": "Additional labels to include",
    }

    properties = {"format": export_format, "metrics": metrics, "labels": labels}
    return {
        "type": "object",
        "properties": properties,
        "additionalProperties": False,
    }
formatted = self._format_csv(stats_dict) + else: # json + formatted = { + "metrics": stats_dict, + "labels": labels, + "timestamp": datetime.now().isoformat(), + } + + return { + "success": True, + "format": export_format, + "data": formatted, + "timestamp": datetime.now().isoformat(), + } + + except Exception as e: + raise ToolError(f"Failed to export metrics: {str(e)}") + + def _format_prometheus(self, stats: dict[str, Any], labels: dict[str, str]) -> str: + """Format metrics in Prometheus format.""" + lines = [] + + # Format labels + label_str = "" + if labels: + label_parts = [f'{k}="{v}"' for k, v in labels.items()] + label_str = "{" + ",".join(label_parts) + "}" + + # Flatten stats and convert to Prometheus format + def flatten_dict(d: dict, prefix: str = ""): + for key, value in d.items(): + full_key = f"{prefix}_{key}" if prefix else key + full_key = full_key.replace(".", "_").replace(" ", "_") + + if isinstance(value, dict): + flatten_dict(value, full_key) + elif isinstance(value, (int, float)): + lines.append(f"# TYPE contextframe_{full_key} gauge") + lines.append(f"contextframe_{full_key}{label_str} {value}") + elif ( + isinstance(value, list) + and value + and isinstance(value[0], (int, float)) + ): + lines.append(f"# TYPE contextframe_{full_key} gauge") + lines.append(f"contextframe_{full_key}{label_str} {len(value)}") + + flatten_dict(stats) + + return "\n".join(lines) + + def _format_csv(self, stats: dict[str, Any]) -> str: + """Format metrics as CSV.""" + rows = [] + + # Flatten stats + def flatten_dict(d: dict, prefix: str = ""): + row = {} + for key, value in d.items(): + full_key = f"{prefix}.{key}" if prefix else key + + if isinstance(value, dict): + row.update(flatten_dict(value, full_key)) + elif isinstance(value, (int, float, str)): + row[full_key] = value + elif isinstance(value, list): + row[full_key] = len(value) + else: + row[full_key] = str(value) + return row + + flat_stats = flatten_dict(stats) + + # Create CSV + if flat_stats: + 
headers = list(flat_stats.keys()) + rows.append(",".join(headers)) + rows.append(",".join(str(flat_stats[h]) for h in headers)) + + return "\n".join(rows) diff --git a/contextframe/mcp/errors.py b/contextframe/mcp/errors.py index 0d73ebe..583d5a0 100644 --- a/contextframe/mcp/errors.py +++ b/contextframe/mcp/errors.py @@ -2,7 +2,6 @@ from typing import Any, Dict, Optional - # Standard JSON-RPC 2.0 error codes PARSE_ERROR = -32700 INVALID_REQUEST = -32600 @@ -21,23 +20,15 @@ class MCPError(Exception): """Base class for MCP errors with JSON-RPC error formatting.""" - def __init__( - self, - code: int, - message: str, - data: Optional[Any] = None - ): + def __init__(self, code: int, message: str, data: Any | None = None): super().__init__(message) self.code = code self.message = message self.data = data - def to_json_rpc(self) -> Dict[str, Any]: + def to_json_rpc(self) -> dict[str, Any]: """Convert error to JSON-RPC error format.""" - error_dict = { - "code": self.code, - "message": self.message - } + error_dict = {"code": self.code, "message": self.message} if self.data is not None: error_dict["data"] = self.data return error_dict @@ -46,23 +37,15 @@ def to_json_rpc(self) -> Dict[str, Any]: class ParseError(MCPError): """Invalid JSON was received by the server.""" - def __init__(self, data: Optional[Any] = None): - super().__init__( - PARSE_ERROR, - "Parse error", - data - ) + def __init__(self, data: Any | None = None): + super().__init__(PARSE_ERROR, "Parse error", data) class InvalidRequest(MCPError): """The JSON sent is not a valid Request object.""" - def __init__(self, data: Optional[Any] = None): - super().__init__( - INVALID_REQUEST, - "Invalid Request", - data - ) + def __init__(self, data: Any | None = None): + super().__init__(INVALID_REQUEST, "Invalid Request", data) class MethodNotFound(MCPError): @@ -70,32 +53,22 @@ class MethodNotFound(MCPError): def __init__(self, method: str): super().__init__( - METHOD_NOT_FOUND, - f"Method not found: {method}", - 
{"method": method} + METHOD_NOT_FOUND, f"Method not found: {method}", {"method": method} ) class InvalidParams(MCPError): """Invalid method parameter(s).""" - def __init__(self, message: str, data: Optional[Any] = None): - super().__init__( - INVALID_PARAMS, - f"Invalid params: {message}", - data - ) + def __init__(self, message: str, data: Any | None = None): + super().__init__(INVALID_PARAMS, f"Invalid params: {message}", data) class InternalError(MCPError): """Internal JSON-RPC error.""" - def __init__(self, message: str, data: Optional[Any] = None): - super().__init__( - INTERNAL_ERROR, - f"Internal error: {message}", - data - ) + def __init__(self, message: str, data: Any | None = None): + super().__init__(INTERNAL_ERROR, f"Internal error: {message}", data) class DatasetNotFound(MCPError): @@ -103,9 +76,7 @@ class DatasetNotFound(MCPError): def __init__(self, path: str): super().__init__( - DATASET_NOT_FOUND, - f"Dataset not found: {path}", - {"path": path} + DATASET_NOT_FOUND, f"Dataset not found: {path}", {"path": path} ) @@ -116,19 +87,15 @@ def __init__(self, document_id: str): super().__init__( DOCUMENT_NOT_FOUND, f"Document not found: {document_id}", - {"document_id": document_id} + {"document_id": document_id}, ) class EmbeddingError(MCPError): """Error generating embeddings.""" - def __init__(self, message: str, data: Optional[Any] = None): - super().__init__( - EMBEDDING_ERROR, - f"Embedding error: {message}", - data - ) + def __init__(self, message: str, data: Any | None = None): + super().__init__(EMBEDDING_ERROR, f"Embedding error: {message}", data) class InvalidSearchType(MCPError): @@ -138,7 +105,7 @@ def __init__(self, search_type: str): super().__init__( INVALID_SEARCH_TYPE, f"Invalid search type: {search_type}", - {"search_type": search_type, "valid_types": ["vector", "text", "hybrid"]} + {"search_type": search_type, "valid_types": ["vector", "text", "hybrid"]}, ) @@ -149,5 +116,12 @@ def __init__(self, message: str, filter_expr: str): 
class ToolError(MCPError):
    """Raised when an MCP tool fails while executing.

    The error is reported with the generic ``INTERNAL_ERROR`` code and the
    caller's message prefixed with ``"Tool error: "``.

    Args:
        message: Description of the failure.
        data: Optional extra payload carried on the error.
    """

    def __init__(self, message: str, data: Any | None = None):
        super().__init__(
            code=INTERNAL_ERROR,
            message=f"Tool error: {message}",
            data=data,
        )
async def handle_subscription(
    self, subscription: "Subscription", *, keepalive_interval: float = 30.0
) -> AsyncIterator[dict[str, Any]]:
    """Stream events for one subscription until it is cancelled.

    The subscription is registered, a dedicated queue is created for it,
    and events placed on that queue are yielded as they arrive. When no
    event arrives within ``keepalive_interval`` seconds a keepalive event
    is yielded instead.

    Args:
        subscription: Subscription to serve; its ``id``, ``resource_type``
            and ``filter`` are echoed in the first event.
        keepalive_interval: Seconds to wait for an event before emitting a
            keepalive. Defaults to the previously hard-coded 30 seconds.

    Yields:
        First a ``subscription_created`` event, then each queued event, with
        ``keepalive`` events during idle periods.
    """
    self._subscriptions[subscription.id] = subscription

    # Each subscription gets its own queue of pending events.
    queue: asyncio.Queue = asyncio.Queue()
    self._subscription_queues[subscription.id] = queue

    try:
        # Confirm the subscription before streaming anything else.
        yield {
            "type": "subscription_created",
            "subscription_id": subscription.id,
            "resource_type": subscription.resource_type,
            "filter": subscription.filter,
        }

        while subscription.id in self._subscriptions:
            try:
                change = await asyncio.wait_for(
                    queue.get(), timeout=keepalive_interval
                )
            except asyncio.TimeoutError:
                # asyncio.TimeoutError only became an alias of the builtin
                # TimeoutError in Python 3.11; catching the asyncio name
                # keeps keepalives working on 3.10 as well.
                yield {"type": "keepalive", "subscription_id": subscription.id}
            else:
                yield change

    finally:
        # Runs on normal exit and when the consumer closes the generator.
        self._subscription_queues.pop(subscription.id, None)
        self.cancel_subscription(subscription.id)
async def stream_operation_progress(
    self, operation_id: str, *, keepalive_interval: float = 30.0
) -> AsyncIterator[dict[str, Any]]:
    """Stream queued progress events for a single operation.

    Args:
        operation_id: Identifier of the operation to follow.
        keepalive_interval: Seconds to wait for an event before emitting a
            keepalive. Defaults to the previously hard-coded 30 seconds.

    Yields:
        A single ``error`` event when the operation has no progress queue.
        Otherwise each queued event in order, with ``keepalive`` events
        during idle periods. The stream ends after a ``complete`` event, or
        once the operation is no longer active and its queue is drained.
    """
    queue = self._operation_progress.get(operation_id)
    if queue is None:
        yield {"type": "error", "error": f"Operation {operation_id} not found"}
        return

    # Keep draining after the operation ends so late events are delivered.
    while operation_id in self._active_operations or not queue.empty():
        try:
            event = await asyncio.wait_for(queue.get(), timeout=keepalive_interval)
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is only an alias of the builtin
            # TimeoutError from Python 3.11 on; catch the asyncio name so
            # idle periods produce keepalives on 3.10 too.
            yield {"type": "keepalive", "operation_id": operation_id}
            continue

        yield event

        # A completion event is always the last one delivered.
        if event.get("type") == "complete":
            break
change.get("resource_type") != subscription.resource_type: + return False + + # Apply additional filters if present + if subscription.filter: + # This would be more sophisticated in production + # For now, simple string matching + filter_dict = ( + json.loads(subscription.filter) + if isinstance(subscription.filter, str) + else subscription.filter + ) + for key, value in filter_dict.items(): + if change.get(key) != value: + return False + + return True + + @property + def supports_streaming(self) -> bool: + """HTTP transport supports true streaming via SSE.""" + return True + + @property + def transport_type(self) -> str: + """Transport type identifier.""" + return "http" + + def get_streaming_adapter(self) -> SSEStreamingAdapter: + """Get the streaming adapter for this transport.""" + return self._streaming diff --git a/contextframe/mcp/transports/http/auth.py b/contextframe/mcp/transports/http/auth.py new file mode 100644 index 0000000..23fd212 --- /dev/null +++ b/contextframe/mcp/transports/http/auth.py @@ -0,0 +1,222 @@ +"""OAuth 2.1 authentication for HTTP transport.""" + +import logging +from datetime import datetime, timedelta +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jose import JWTError, jwt +from pydantic import BaseModel +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + +# OAuth2 scheme +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) + +# Configuration (would come from config in production) +SECRET_KEY = "your-secret-key-here" # Should be loaded from environment +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 + + +class Token(BaseModel): + """OAuth token response.""" + + access_token: str + token_type: str + expires_in: int + + +class TokenData(BaseModel): + """Token payload data.""" + + sub: str | None = None + exp: datetime | None = None + scopes: list[str] = [] + + +class User(BaseModel): + """User model.""" + + id: str + 
async def verify_token(self, token: str) -> TokenData:
    """Verify a JWT's signature and claims, then extract its payload.

    The token is decoded with this handler's secret key and algorithm, and
    its audience and issuer are checked against the configured values.

    Args:
        token: Encoded JWT as received from the client.

    Returns:
        TokenData holding the subject, the expiry as a timezone-aware UTC
        datetime (the Unix epoch when the token carries no ``exp`` claim),
        and the granted scopes (empty when absent).

    Raises:
        HTTPException: 401 when decoding or validation fails, or when the
            token has no ``sub`` claim.
    """
    # Local import: the file-level import only brings in datetime/timedelta.
    from datetime import timezone

    try:
        payload = jwt.decode(
            token,
            self.secret_key,
            algorithms=[self.algorithm],
            audience=self.audience,
            issuer=self.issuer,
        )
    except JWTError as e:
        logger.error("Token verification failed: %s", e)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        ) from e

    sub = payload.get("sub")
    if sub is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token: missing subject",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return TokenData(
        sub=sub,
        # Tokens are issued with UTC timestamps; decode "exp" as UTC rather
        # than as naive local time so the value does not shift with the
        # server's timezone.
        exp=datetime.fromtimestamp(payload.get("exp", 0), tz=timezone.utc),
        scopes=payload.get("scopes", []),
    )
async def check_permissions(self, user: "User", resource: str, action: str) -> bool:
    """Decide whether *user* may perform *action* on *resource*.

    Access is granted when the user's scopes contain any of: ``admin``,
    the exact ``<resource>:<action>`` scope, the resource wildcard
    ``<resource>:*`` or the action wildcard ``*:<action>``.

    Args:
        user: The authenticated user; only ``user.scopes`` is consulted.
        resource: Resource type (e.g., "documents", "tools").
        action: Action type (e.g., "read", "write", "execute").

    Returns:
        True if one of the accepted scopes is present, False otherwise.
    """
    accepted_scopes = (
        "admin",
        f"{resource}:{action}",
        f"{resource}:*",
        f"*:{action}",
    )
    return any(scope in user.scopes for scope in accepted_scopes)
def create_scope_checker(required_scopes: list[str]):
    """Build a dependency that enforces a set of scopes.

    Args:
        required_scopes: Scopes that must all be present on the user.

    Returns:
        An async dependency that resolves the authenticated user via
        ``require_auth``, returns that user when it holds every required
        scope, and raises a 403 ``HTTPException`` otherwise.
    """

    async def check_scopes(
        user: dict[str, Any] = Depends(require_auth),
    ) -> dict[str, Any]:
        granted = set(user.get("scopes", []))
        # Evaluated per call, like the original, so later changes to the
        # caller's list are still honored.
        missing = set(required_scopes) - granted

        if missing:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Required scopes: {required_scopes}",
            )

        return user

    return check_scopes
+ """ + + # Server settings + host: str = "0.0.0.0" + port: int = 8080 + workers: int = 1 + + # CORS settings + cors_enabled: bool = True + cors_origins: list[str] = field(default_factory=lambda: ["*"]) + cors_credentials: bool = True + cors_methods: list[str] = field(default_factory=lambda: ["*"]) + cors_headers: list[str] = field(default_factory=lambda: ["*"]) + + # Authentication settings + auth_enabled: bool = False + auth_secret_key: str | None = None + auth_algorithm: str = "HS256" + auth_issuer: str = "contextframe-mcp" + auth_audience: str = "contextframe-mcp" + auth_token_expire_minutes: int = 30 + + # Rate limiting settings + rate_limit_enabled: bool = True + rate_limit_requests_per_minute: int = 60 + rate_limit_burst: int = 10 + + # SSL/TLS settings + ssl_enabled: bool = False + ssl_cert: str | None = None + ssl_key: str | None = None + + # SSE settings + sse_max_connections: int = 1000 + sse_keepalive_interval: int = 25 # seconds + sse_max_age_seconds: int = 3600 # 1 hour + + # Session settings + session_enabled: bool = True + session_secret_key: str | None = None + session_max_age: int = 86400 # 24 hours + + # Performance settings + max_request_size: int = 10 * 1024 * 1024 # 10MB + request_timeout: int = 300 # 5 minutes + + @classmethod + def from_env(cls) -> "HTTPTransportConfig": + """Load configuration from environment variables. 
@classmethod
def from_file(cls, path: str) -> "HTTPTransportConfig":
    """Load configuration from a JSON file.

    The file uses the same sectioned layout that ``to_dict`` produces:
    top-level ``server``, ``cors``, ``auth``, ``rate_limit``, ``ssl``,
    ``sse``, ``session`` and ``performance`` objects. Every key emitted by
    ``to_dict`` is read back, along with the ``secret_key`` entries of
    ``auth`` and ``session``. Missing sections or keys keep the class
    defaults.

    Args:
        path: Path to the JSON configuration file.

    Returns:
        A new configuration instance populated from the file.
    """
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    config = cls()

    # section -> {json key: attribute name}
    layout = {
        "server": {"host": "host", "port": "port", "workers": "workers"},
        "cors": {
            "enabled": "cors_enabled",
            "origins": "cors_origins",
            "credentials": "cors_credentials",
            "methods": "cors_methods",
            "headers": "cors_headers",
        },
        "auth": {
            "enabled": "auth_enabled",
            "secret_key": "auth_secret_key",
            "algorithm": "auth_algorithm",
            "issuer": "auth_issuer",
            "audience": "auth_audience",
            "token_expire_minutes": "auth_token_expire_minutes",
        },
        "rate_limit": {
            "enabled": "rate_limit_enabled",
            "requests_per_minute": "rate_limit_requests_per_minute",
            "burst": "rate_limit_burst",
        },
        "ssl": {"enabled": "ssl_enabled", "cert": "ssl_cert", "key": "ssl_key"},
        "sse": {
            "max_connections": "sse_max_connections",
            "keepalive_interval": "sse_keepalive_interval",
            "max_age_seconds": "sse_max_age_seconds",
        },
        "session": {
            "enabled": "session_enabled",
            "secret_key": "session_secret_key",
            "max_age": "session_max_age",
        },
        "performance": {
            "max_request_size": "max_request_size",
            "request_timeout": "request_timeout",
        },
    }

    for section, keys in layout.items():
        # A missing or null section leaves every default untouched.
        values = data.get(section) or {}
        for json_key, attr in keys.items():
            setattr(config, attr, values.get(json_key, getattr(config, attr)))

    return config
def to_dict(self) -> dict[str, Any]:
    """Serialize the configuration into a sectioned dictionary.

    Returns:
        A dict with the sections ``server``, ``cors``, ``auth``,
        ``rate_limit``, ``ssl``, ``sse``, ``session`` and ``performance``.
        The auth and session secret keys are not included in the output.
    """
    server = {
        "host": self.host,
        "port": self.port,
        "workers": self.workers,
    }
    cors = {
        "enabled": self.cors_enabled,
        "origins": self.cors_origins,
        "credentials": self.cors_credentials,
        "methods": self.cors_methods,
        "headers": self.cors_headers,
    }
    auth = {
        "enabled": self.auth_enabled,
        "algorithm": self.auth_algorithm,
        "issuer": self.auth_issuer,
        "audience": self.auth_audience,
        "token_expire_minutes": self.auth_token_expire_minutes,
    }
    rate_limit = {
        "enabled": self.rate_limit_enabled,
        "requests_per_minute": self.rate_limit_requests_per_minute,
        "burst": self.rate_limit_burst,
    }
    ssl = {
        "enabled": self.ssl_enabled,
        "cert": self.ssl_cert,
        "key": self.ssl_key,
    }
    sse = {
        "max_connections": self.sse_max_connections,
        "keepalive_interval": self.sse_keepalive_interval,
        "max_age_seconds": self.sse_max_age_seconds,
    }
    session = {
        "enabled": self.session_enabled,
        "max_age": self.session_max_age,
    }
    performance = {
        "max_request_size": self.max_request_size,
        "request_timeout": self.request_timeout,
    }

    return {
        "server": server,
        "cors": cors,
        "auth": auth,
        "rate_limit": rate_limit,
        "ssl": ssl,
        "sse": sse,
        "session": session,
        "performance": performance,
    }
"enabled": true, + "origins": ["https://example.com", "http://localhost:3000"], + "credentials": true + }, + "auth": { + "enabled": true, + "secret_key": "your-secret-key-here", + "issuer": "contextframe-mcp", + "audience": "contextframe-mcp" + }, + "rate_limit": { + "enabled": true, + "requests_per_minute": 60, + "burst": 10 + }, + "ssl": { + "enabled": false, + "cert": "/path/to/cert.pem", + "key": "/path/to/key.pem" + }, + "sse": { + "max_connections": 1000, + "keepalive_interval": 25 + }, + "session": { + "enabled": true, + "secret_key": "your-session-secret", + "max_age": 86400 + } +} +""" diff --git a/contextframe/mcp/transports/http/security.py b/contextframe/mcp/transports/http/security.py new file mode 100644 index 0000000..bd49b1e --- /dev/null +++ b/contextframe/mcp/transports/http/security.py @@ -0,0 +1,257 @@ +"""Security features for HTTP transport including CORS and rate limiting.""" + +import asyncio +import logging +import time +from collections import defaultdict, deque +from fastapi import HTTPException, Request, Response, status +from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.base import BaseHTTPMiddleware +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + + +class RateLimiter: + """Token bucket rate limiter implementation.""" + + def __init__( + self, + requests_per_minute: int = 60, + burst: int = 10, + cleanup_interval: int = 300, # 5 minutes + ): + self.requests_per_minute = requests_per_minute + self.burst = burst + self.cleanup_interval = cleanup_interval + + # Token buckets for each client + self._buckets: dict[str, TokenBucket] = {} + self._last_cleanup = time.time() + self._lock = asyncio.Lock() + + async def check_rate_limit(self, client_id: str) -> bool: + """Check if client has available tokens. 
async def get_stats(self) -> dict[str, Any]:
    """Report a snapshot of the rate limiter's state.

    The snapshot is taken while holding the limiter's lock.

    Returns:
        A dict with ``active_buckets`` (number of per-client buckets
        currently tracked), ``requests_per_minute`` and ``burst_size``.
    """
    async with self._lock:
        snapshot: dict[str, Any] = {"active_buckets": len(self._buckets)}
        snapshot["requests_per_minute"] = self.requests_per_minute
        snapshot["burst_size"] = self.burst
        return snapshot
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware that rejects requests exceeding a per-client rate limit.

    Each request is attributed to a client via ``key_func`` and checked
    against the shared ``RateLimiter``. Requests over the limit receive a
    429 JSON response; all others are passed on and get informational
    ``X-RateLimit-*`` headers added to the response.
    """

    def __init__(
        self,
        app,
        rate_limiter: RateLimiter,
        # NOTE: was annotated ``callable | None``, which raises TypeError as
        # soon as the class body is executed (the builtin ``callable``
        # function does not support ``|``), making the module unimportable.
        key_func: Any | None = None,
        excluded_paths: list[str] | None = None,
    ):
        """Configure the middleware.

        Args:
            app: The ASGI application being wrapped.
            rate_limiter: Limiter consulted for every non-excluded request.
            key_func: Optional async callable taking the request and
                returning the client identifier string. Defaults to the
                authenticated user id, falling back to the client IP.
            excluded_paths: Exact URL paths that bypass rate limiting.
                ``None`` selects "/health", "/ready" and "/metrics"; an
                explicit empty list excludes nothing.
        """
        super().__init__(app)
        self.rate_limiter = rate_limiter
        self.key_func = key_func or self._default_key_func
        if excluded_paths is None:
            excluded_paths = ["/health", "/ready", "/metrics"]
        self.excluded_paths = set(excluded_paths)

    async def dispatch(self, request: Request, call_next):
        """Check the rate limit before processing the request.

        Returns:
            The downstream response with rate-limit headers added, or a
            429 JSON response when the client is over its limit.
        """
        # Skip rate limiting for excluded paths
        if request.url.path in self.excluded_paths:
            return await call_next(request)

        client_id = await self.key_func(request)

        if not await self.rate_limiter.check_rate_limit(client_id):
            # Return the 429 directly. An HTTPException raised here would
            # bypass the app's exception handlers, which sit inside user
            # middleware, and reach the client as a 500.
            return Response(
                content='{"detail":"Rate limit exceeded"}',
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                media_type="application/json",
                headers={
                    "Retry-After": "60",
                    "X-RateLimit-Limit": str(self.rate_limiter.requests_per_minute),
                },
            )

        response = await call_next(request)

        # Advertise the configured limits on successful responses.
        response.headers["X-RateLimit-Limit"] = str(
            self.rate_limiter.requests_per_minute
        )
        response.headers["X-RateLimit-Burst"] = str(self.rate_limiter.burst)

        return response

    async def _default_key_func(self, request: Request) -> str:
        """Derive the client identifier from the request.

        Returns:
            ``"user:<id>"`` when ``request.state.user`` is set (``unknown``
            if it carries no id), otherwise ``"ip:<host>"`` (``unknown`` if
            the client address is unavailable).
        """
        if hasattr(request.state, "user") and request.state.user:
            return f"user:{request.state.user.get('id', 'unknown')}"

        # Fall back to IP address
        client_ip = request.client.host if request.client else "unknown"
        return f"ip:{client_ip}"
config + + +# Security headers middleware +async def add_security_headers(request: Request, call_next): + """Add security headers to responses.""" + response = await call_next(request) + + # Add security headers + response.headers["X-Content-Type-Options"] = "nosniff" + response.headers["X-Frame-Options"] = "DENY" + response.headers["X-XSS-Protection"] = "1; mode=block" + response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" + + # Add CSP for API endpoints + if request.url.path.startswith("/mcp/"): + response.headers["Content-Security-Policy"] = ( + "default-src 'none'; frame-ancestors 'none';" + ) + + return response diff --git a/contextframe/mcp/transports/http/server.py b/contextframe/mcp/transports/http/server.py new file mode 100644 index 0000000..bcb4793 --- /dev/null +++ b/contextframe/mcp/transports/http/server.py @@ -0,0 +1,310 @@ +"""FastAPI/Starlette HTTP server implementation for MCP.""" + +import asyncio +import logging +from contextframe import FrameDataset +from contextframe.mcp.handler import MessageHandler +from contextframe.mcp.schemas import ( + InitializeParams, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + ResourceReadParams, + ToolCallParams, +) +from contextframe.mcp.transports.http.adapter import HttpAdapter +from contextframe.mcp.transports.http.auth import get_current_user, oauth2_scheme +from contextframe.mcp.transports.http.config import HTTPTransportConfig +from contextframe.mcp.transports.http.security import RateLimiter +from contextframe.mcp.transports.http.sse import SSEManager +from contextlib import asynccontextmanager +from fastapi import Depends, FastAPI, HTTPException, Request, Response +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, StreamingResponse +from sse_starlette.sse import EventSourceResponse +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + + +class MCPHTTPServer: + """HTTP server for MCP protocol using 
FastAPI.""" + + def __init__( + self, + dataset: FrameDataset, + config: HTTPTransportConfig | None = None, + ): + self.dataset = dataset + self.config = config or HTTPTransportConfig() + + # Validate configuration + errors = self.config.validate() + if errors: + raise ValueError( + f"Invalid HTTP transport configuration: {'; '.join(errors)}" + ) + + # Create components + self.adapter = HttpAdapter() + self.handler = MessageHandler(dataset, self.adapter) + self.sse_manager = SSEManager( + max_connections=self.config.sse_max_connections, + max_age_seconds=self.config.sse_max_age_seconds, + ) + self.rate_limiter = RateLimiter( + requests_per_minute=self.config.rate_limit_requests_per_minute, + burst=self.config.rate_limit_burst, + ) + + # Create FastAPI app + self.app = self._create_app() + + def _create_app(self) -> FastAPI: + """Create and configure FastAPI application.""" + + @asynccontextmanager + async def lifespan(app: FastAPI): + # Startup + await self.adapter.initialize() + await self.sse_manager.start() + logger.info( + f"MCP HTTP server started on {self.config.host}:{self.config.port}" + ) + if self.config.ssl_enabled: + logger.info("SSL/TLS enabled") + if self.config.auth_enabled: + logger.info("Authentication enabled") + yield + # Shutdown + await self.sse_manager.stop() + await self.adapter.shutdown() + logger.info("MCP HTTP server stopped") + + app = FastAPI( + title="ContextFrame MCP Server", + description="Model Context Protocol server for ContextFrame", + version="1.0.0", + lifespan=lifespan, + ) + + # Add CORS middleware if enabled + if self.config.cors_enabled: + app.add_middleware( + CORSMiddleware, + allow_origins=self.config.cors_origins, + allow_credentials=self.config.cors_credentials, + allow_methods=self.config.cors_methods, + allow_headers=self.config.cors_headers, + ) + + # Add routes + self._add_routes(app) + + return app + + def _add_routes(self, app: FastAPI) -> None: + """Add all MCP routes to the app.""" + + # Dependency for optional 
auth + auth_dep = Depends(get_current_user) if self.config.auth_enabled else None + + # Health check endpoints + @app.get("/health") + async def health(): + """Health check endpoint.""" + return {"status": "healthy"} + + @app.get("/ready") + async def ready(): + """Readiness check endpoint.""" + # Check if dataset is accessible + try: + await asyncio.to_thread(lambda: len(self.dataset)) + return {"status": "ready"} + except Exception as e: + raise HTTPException(status_code=503, detail=f"Dataset not ready: {e}") + + # MCP JSON-RPC endpoint + @app.post("/mcp/v1/jsonrpc") + async def jsonrpc_endpoint( + request: JSONRPCRequest, user: dict | None = auth_dep + ): + """Main JSON-RPC endpoint for MCP protocol.""" + # Apply rate limiting if enabled + if self.config.rate_limit_enabled: + if not await self.rate_limiter.check_rate_limit( + user["id"] if user else "anonymous" + ): + return JSONRPCResponse( + id=request.id, + error=JSONRPCError(code=-32000, message="Rate limit exceeded"), + ) + + # Handle the request + try: + response = await self.handler.handle_message(request.model_dump()) + return JSONRPCResponse(**response) + except Exception as e: + logger.error(f"Error handling request: {e}") + return JSONRPCResponse( + id=request.id, + error=JSONRPCError( + code=-32603, message="Internal error", data=str(e) + ), + ) + + # Convenience REST endpoints that wrap JSON-RPC + @app.post("/mcp/v1/initialize") + async def initialize(params: InitializeParams, user: dict | None = auth_dep): + """Initialize MCP session.""" + request = JSONRPCRequest( + method="initialize", params=params.model_dump(), id=1 + ) + return await jsonrpc_endpoint(request, user) + + @app.get("/mcp/v1/tools/list") + async def list_tools(user: dict | None = auth_dep): + """List available tools.""" + request = JSONRPCRequest(method="tools/list", id=1) + return await jsonrpc_endpoint(request, user) + + @app.post("/mcp/v1/tools/call") + async def call_tool(params: ToolCallParams, user: dict | None = auth_dep): + 
"""Call a tool.""" + request = JSONRPCRequest( + method="tools/call", params=params.model_dump(), id=1 + ) + response = await jsonrpc_endpoint(request, user) + + # Check if this is a batch operation that returns an operation_id + if ( + hasattr(response, "result") + and isinstance(response.result, dict) + and "operation_id" in response.result + ): + # Add operation_id to response headers for client convenience + return JSONResponse( + content=response.model_dump(), + headers={"X-Operation-Id": response.result["operation_id"]}, + ) + + return response + + @app.get("/mcp/v1/resources/list") + async def list_resources(user: dict | None = auth_dep): + """List available resources.""" + request = JSONRPCRequest(method="resources/list", id=1) + return await jsonrpc_endpoint(request, user) + + @app.post("/mcp/v1/resources/read") + async def read_resource( + params: ResourceReadParams, user: dict | None = auth_dep + ): + """Read a resource.""" + request = JSONRPCRequest( + method="resources/read", params=params.model_dump(), id=1 + ) + return await jsonrpc_endpoint(request, user) + + # SSE endpoints + @app.get("/mcp/v1/sse/progress/{operation_id}") + async def stream_progress(operation_id: str, user: dict | None = auth_dep): + """Stream progress updates for an operation via SSE.""" + # Create SSE stream + stream = await self.sse_manager.create_stream() + + async def event_generator(): + try: + async for event in self.adapter.stream_operation_progress( + operation_id + ): + yield await stream.send_json( + event, event_type=event.get("type") + ) + finally: + await self.sse_manager.close_stream(stream.client_id) + + return EventSourceResponse(event_generator()) + + @app.get("/mcp/v1/sse/subscribe") + async def subscribe_changes( + resource_type: str = "documents", + filter: str | None = None, + user: dict | None = auth_dep, + ): + """Subscribe to dataset changes via SSE.""" + # Create subscription + from contextframe.mcp.core.transport import Subscription + + subscription = 
Subscription( + id=str(asyncio.create_task(asyncio.sleep(0)).get_name()), + resource_type=resource_type, + filter=filter, + ) + + # Create SSE stream + stream = await self.sse_manager.create_stream() + + async def event_generator(): + try: + async for change in self.adapter.handle_subscription(subscription): + yield await stream.send_json( + change, event_type=change.get("type") + ) + finally: + await self.sse_manager.close_stream(stream.client_id) + self.adapter.cancel_subscription(subscription.id) + + return EventSourceResponse(event_generator()) + + # Metrics endpoint + @app.get("/metrics") + async def metrics(): + """Prometheus-compatible metrics endpoint.""" + metrics_data = { + "sse_active_connections": self.sse_manager.active_connections, + "rate_limiter_stats": await self.rate_limiter.get_stats(), + "dataset_size": len(self.dataset), + "adapter_subscriptions": len(self.adapter._subscriptions), + } + + # Format as Prometheus text format + lines = [] + for key, value in metrics_data.items(): + lines.append(f"# TYPE {key} gauge") + lines.append(f"{key} {value}") + + return Response(content="\n".join(lines), media_type="text/plain") + + # OpenAPI schema + @app.get("/openapi.json") + async def openapi(): + """Get OpenAPI schema.""" + return app.openapi() + + +async def create_http_server( + dataset_path: str, config: HTTPTransportConfig | None = None, **kwargs +) -> MCPHTTPServer: + """Create and configure an HTTP MCP server. 
+ + Args: + dataset_path: Path to the Lance dataset + host: Host to bind to + port: Port to bind to + **kwargs: Additional server configuration + + Returns: + Configured MCPHTTPServer instance + """ + # Open dataset + dataset = FrameDataset.open(dataset_path) + + # Create server + server = MCPHTTPServer( + dataset=dataset, + config=config, + ) + + return server diff --git a/contextframe/mcp/transports/http/sse.py b/contextframe/mcp/transports/http/sse.py new file mode 100644 index 0000000..088b68b --- /dev/null +++ b/contextframe/mcp/transports/http/sse.py @@ -0,0 +1,300 @@ +"""Server-Sent Events (SSE) implementation for HTTP transport.""" + +import asyncio +import json +import logging +import uuid +from collections.abc import AsyncIterator +from datetime import datetime +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + + +class SSEStream: + """Manages an SSE connection for streaming updates. + + This class handles the low-level SSE protocol details including + event formatting, keepalives, and connection management. + """ + + def __init__(self, client_id: str | None = None): + self.client_id = client_id or str(uuid.uuid4()) + self.created_at = datetime.now() + self._closed = False + self._event_id = 0 + + async def send_event( + self, + data: Any, + event_type: str | None = None, + event_id: str | None = None, + retry: int | None = None, + ) -> str: + """Send an SSE event. 
+ + Args: + data: The data to send (will be JSON encoded if not string) + event_type: Optional event type + event_id: Optional event ID for reconnection + retry: Optional retry timeout in milliseconds + + Returns: + The formatted SSE event string + """ + if self._closed: + raise RuntimeError("SSE stream is closed") + + # Format data + if not isinstance(data, str): + data = json.dumps(data) + + # Build SSE event + lines = [] + + if event_id is None: + self._event_id += 1 + event_id = str(self._event_id) + + if event_type: + lines.append(f"event: {event_type}") + + lines.append(f"id: {event_id}") + + if retry is not None: + lines.append(f"retry: {retry}") + + # Split data by newlines and format + for line in data.split('\n'): + lines.append(f"data: {line}") + + # SSE events end with double newline + return '\n'.join(lines) + '\n\n' + + async def send_json( + self, data: dict[str, Any], event_type: str | None = None + ) -> str: + """Send JSON data as an SSE event.""" + return await self.send_event(data, event_type=event_type) + + async def send_keepalive(self) -> str: + """Send a keepalive comment to maintain connection.""" + return f": keepalive {datetime.now().isoformat()}\n\n" + + async def stream_progress( + self, operation_id: str, progress_queue: asyncio.Queue + ) -> AsyncIterator[str]: + """Stream progress updates for an operation. + + Yields SSE-formatted events for each progress update. 
+ """ + logger.info(f"Starting progress stream for operation {operation_id}") + + try: + # Send initial event + yield await self.send_json( + { + "operation_id": operation_id, + "status": "started", + "timestamp": datetime.now().isoformat(), + }, + event_type="progress_start", + ) + + # Stream progress updates + while not self._closed: + try: + # Wait for progress with timeout for keepalive + event = await asyncio.wait_for( + progress_queue.get(), + timeout=25.0, # Send keepalive before 30s timeout + ) + + # Send progress event + yield await self.send_json( + event.get("data", event), + event_type=event.get("type", "progress"), + ) + + # Check if operation is complete + if event.get("type") == "complete": + logger.info(f"Operation {operation_id} completed") + break + + except TimeoutError: + # Send keepalive to maintain connection + yield await self.send_keepalive() + + except Exception as e: + logger.error(f"Error in progress stream: {e}") + yield await self.send_json( + {"error": str(e), "operation_id": operation_id}, event_type="error" + ) + + finally: + # Send final event + yield await self.send_json( + { + "operation_id": operation_id, + "status": "closed", + "timestamp": datetime.now().isoformat(), + }, + event_type="progress_end", + ) + + async def stream_changes( + self, subscription_id: str, change_iterator: AsyncIterator[dict[str, Any]] + ) -> AsyncIterator[str]: + """Stream dataset changes for a subscription. + + Yields SSE-formatted events for each change. 
+ """ + logger.info(f"Starting change stream for subscription {subscription_id}") + + try: + # Send initial event + yield await self.send_json( + { + "subscription_id": subscription_id, + "status": "connected", + "timestamp": datetime.now().isoformat(), + }, + event_type="subscription_start", + ) + + # Stream changes + async for change in change_iterator: + if self._closed: + break + + # Send change event + yield await self.send_json( + change, event_type=change.get("type", "change") + ) + + except Exception as e: + logger.error(f"Error in change stream: {e}") + yield await self.send_json( + {"error": str(e), "subscription_id": subscription_id}, + event_type="error", + ) + + finally: + # Send final event + yield await self.send_json( + { + "subscription_id": subscription_id, + "status": "closed", + "timestamp": datetime.now().isoformat(), + }, + event_type="subscription_end", + ) + + async def close(self) -> None: + """Close the SSE stream.""" + self._closed = True + logger.info(f"SSE stream {self.client_id} closed") + + @property + def is_closed(self) -> bool: + """Check if stream is closed.""" + return self._closed + + @property + def age_seconds(self) -> float: + """Get age of stream in seconds.""" + return (datetime.now() - self.created_at).total_seconds() + + +class SSEManager: + """Manages multiple SSE streams and their lifecycle.""" + + def __init__(self, max_connections: int = 1000, max_age_seconds: int = 3600): + self.streams: dict[str, SSEStream] = {} + self.max_connections = max_connections + self.max_age_seconds = max_age_seconds + self._cleanup_task: asyncio.Task | None = None + + async def start(self) -> None: + """Start the SSE manager and cleanup task.""" + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + logger.info("SSE manager started") + + async def stop(self) -> None: + """Stop the SSE manager and close all streams.""" + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except 
asyncio.CancelledError: + pass + + # Close all streams + for stream in list(self.streams.values()): + await stream.close() + + self.streams.clear() + logger.info("SSE manager stopped") + + async def create_stream(self, client_id: str | None = None) -> SSEStream: + """Create a new SSE stream.""" + # Check connection limit + if len(self.streams) >= self.max_connections: + raise RuntimeError(f"Maximum connections ({self.max_connections}) reached") + + stream = SSEStream(client_id) + self.streams[stream.client_id] = stream + logger.info(f"Created SSE stream {stream.client_id}") + return stream + + async def get_stream(self, client_id: str) -> SSEStream | None: + """Get an existing SSE stream.""" + return self.streams.get(client_id) + + async def close_stream(self, client_id: str) -> None: + """Close and remove an SSE stream.""" + if client_id in self.streams: + stream = self.streams[client_id] + await stream.close() + del self.streams[client_id] + logger.info(f"Closed SSE stream {client_id}") + + async def _cleanup_loop(self) -> None: + """Periodically clean up old or closed streams.""" + while True: + try: + await asyncio.sleep(60) # Check every minute + + # Find streams to close + to_close = [] + for client_id, stream in self.streams.items(): + if stream.is_closed or stream.age_seconds > self.max_age_seconds: + to_close.append(client_id) + + # Close old streams + for client_id in to_close: + await self.close_stream(client_id) + + if to_close: + logger.info(f"Cleaned up {len(to_close)} SSE streams") + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in SSE cleanup loop: {e}") + + @property + def active_connections(self) -> int: + """Get number of active connections.""" + return len(self.streams) + + @property + def stats(self) -> dict[str, Any]: + """Get SSE manager statistics.""" + return { + "active_connections": self.active_connections, + "max_connections": self.max_connections, + "oldest_stream_age": max( + (s.age_seconds 
for s in self.streams.values()), default=0 + ), + } diff --git a/contextframe/tests/test_mcp/test_analytics.py b/contextframe/tests/test_mcp/test_analytics.py new file mode 100644 index 0000000..3bfa53b --- /dev/null +++ b/contextframe/tests/test_mcp/test_analytics.py @@ -0,0 +1,410 @@ +"""Tests for MCP analytics tools.""" + +import asyncio +import pytest +from contextframe.frame import FrameDataset +from contextframe.mcp.analytics.analyzer import ( + QueryAnalyzer, + QueryMetrics, + RelationshipAnalyzer, + UsageAnalyzer, +) +from contextframe.mcp.analytics.optimizer import ( + IndexAdvisor, + PerformanceBenchmark, + StorageOptimizer, +) +from contextframe.mcp.analytics.stats import DatasetStats, StatsCollector +from contextframe.mcp.analytics.tools import ( + AnalyzeUsageHandler, + BenchmarkOperationsHandler, + ExportMetricsHandler, + GetDatasetStatsHandler, + IndexRecommendationsHandler, + OptimizeStorageHandler, + QueryPerformanceHandler, + RelationshipAnalysisHandler, +) +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, Mock, patch + + +class TestStatsCollector: + """Test StatsCollector functionality.""" + + @pytest.fixture + def mock_dataset(self): + """Create a mock dataset.""" + dataset = Mock(spec=FrameDataset) + + # Mock dataset methods + dataset.get_dataset_stats.return_value = { + "dataset_stats": { + "num_fragments": 5, + "num_deleted_rows": 10, + "num_small_files": 2, + }, + "version_info": { + "current_version": 10, + "latest_version": 10, + }, + "storage": { + "num_rows": 1000, + }, + "indices": [ + {"name": "embedding_idx", "type": "vector", "fields": ["embedding"]}, + {"name": "id_idx", "type": "scalar", "fields": ["id"]}, + ], + } + + dataset.get_fragment_stats.return_value = [ + { + "fragment_id": 0, + "num_rows": 200, + "physical_rows": 220, + "num_deletions": 20, + }, + { + "fragment_id": 1, + "num_rows": 800, + "physical_rows": 800, + "num_deletions": 0, + }, + ] + + dataset.count_by_filter.return_value = 25 # 
Collection count + + # Mock scanner for content stats + mock_batch = Mock() + mock_batch.column_names = ["record_type", "created_at"] + mock_batch.column.return_value.to_pylist.return_value = [ + "document", + "document", + "collection_header", + ] + + mock_scanner = Mock() + mock_scanner.to_batches.return_value = [mock_batch] + dataset.scanner.return_value = mock_scanner + + return dataset + + @pytest.mark.asyncio + async def test_collect_stats(self, mock_dataset): + """Test collecting comprehensive stats.""" + collector = StatsCollector(mock_dataset) + + stats = await collector.collect_stats( + include_content=True, + include_fragments=True, + include_relationships=False, + ) + + assert isinstance(stats, DatasetStats) + assert stats.total_documents == 1000 + assert stats.num_fragments == 5 + assert stats.total_collections == 25 + assert len(stats.indices) == 2 + + @pytest.mark.asyncio + async def test_stats_to_dict(self, mock_dataset): + """Test converting stats to dictionary.""" + collector = StatsCollector(mock_dataset) + stats = await collector.collect_stats() + + stats_dict = stats.to_dict() + + assert "summary" in stats_dict + assert "storage" in stats_dict + assert "versions" in stats_dict + assert "indices" in stats_dict + + +class TestQueryAnalyzer: + """Test QueryAnalyzer functionality.""" + + @pytest.fixture + def analyzer(self): + """Create a query analyzer with mock data.""" + dataset = Mock(spec=FrameDataset) + analyzer = QueryAnalyzer(dataset) + + # Add some test queries + for i in range(20): + metrics = QueryMetrics( + query_type="vector" if i % 2 == 0 else "text", + query_text=f"test query {i}", + duration_ms=50 + i * 10, + rows_scanned=100 + i * 50, + rows_returned=10 + i, + index_used=i % 3 == 0, + timestamp=datetime.now() - timedelta(hours=i), + ) + analyzer.record_query(metrics) + + return analyzer + + @pytest.mark.asyncio + async def test_analyze_performance(self, analyzer): + """Test query performance analysis.""" + analysis = await 
analyzer.analyze_performance( + time_range=timedelta(days=1), + min_duration_ms=100, + ) + + assert "summary" in analysis + assert "by_type" in analysis + assert "slow_queries" in analysis + + # Check summary stats + summary = analysis["summary"] + assert summary["total_queries"] > 0 + assert "avg_duration_ms" in summary + assert "p90_duration_ms" in summary + + +class TestStorageOptimizer: + """Test StorageOptimizer functionality.""" + + @pytest.fixture + def optimizer(self): + """Create storage optimizer.""" + dataset = Mock(spec=FrameDataset) + + # Mock optimization methods + dataset.compact_files.return_value = { + "fragments_compacted": 3, + "files_removed": 5, + "files_added": 2, + } + + dataset.cleanup_old_versions.return_value = { + "bytes_removed": 1024 * 1024 * 100, # 100MB + "old_versions_removed": 5, + } + + dataset.optimize_indices.return_value = { + "status": "completed", + "indices_optimized": 2, + } + + # Add fragment stats for dry run + dataset.get_fragment_stats.return_value = [ + {"num_rows": 5000, "size_bytes": 1024 * 1024}, + {"num_rows": 15000, "size_bytes": 3 * 1024 * 1024}, + {"num_rows": 3000, "size_bytes": 512 * 1024}, + ] + + # Add version history for vacuum dry run + dataset.get_version_history.return_value = [ + {"version": 1, "timestamp": datetime.now() - timedelta(days=10)}, + {"version": 2, "timestamp": datetime.now() - timedelta(days=5)}, + {"version": 3, "timestamp": datetime.now()}, + ] + + return StorageOptimizer(dataset) + + @pytest.mark.asyncio + async def test_optimize_storage(self, optimizer): + """Test storage optimization.""" + result = await optimizer.optimize_storage( + operations=["compact"], + dry_run=False, + ) + + assert result["operations"] + assert len(result["operations"]) == 1 + + op_result = result["operations"][0] + assert op_result["operation"] == "compact" + assert op_result["success"] + + @pytest.mark.asyncio + async def test_dry_run(self, optimizer): + """Test dry run mode.""" + result = await 
optimizer.optimize_storage( + operations=["compact", "vacuum"], + dry_run=True, + ) + + assert len(result["operations"]) == 2 + for op in result["operations"]: + assert op["metrics"].get("preview", False) + + +class TestAnalyticsTools: + """Test analytics tool handlers.""" + + @pytest.fixture + def mock_dataset(self): + """Create a comprehensive mock dataset.""" + dataset = Mock(spec=FrameDataset) + + # Set up all necessary mocks + dataset.get_dataset_stats.return_value = { + "dataset_stats": {"num_fragments": 5}, + "version_info": {"current_version": 10}, + "storage": {"num_rows": 1000}, + "indices": [], + } + + dataset.get_fragment_stats.return_value = [] + dataset.count_by_filter.return_value = 0 + dataset.list_indices.return_value = [] + dataset.scanner.return_value = Mock(to_batches=Mock(return_value=[])) + + return dataset + + @pytest.mark.asyncio + async def test_get_dataset_stats_handler(self, mock_dataset): + """Test get_dataset_stats tool handler.""" + handler = GetDatasetStatsHandler(mock_dataset) + + # Test schema + schema = handler.get_input_schema() + assert schema["type"] == "object" + assert "include_details" in schema["properties"] + + # Test execution + result = await handler.execute(include_details=False) + + assert result["success"] + assert "stats" in result + assert "timestamp" in result + + @pytest.mark.asyncio + async def test_analyze_usage_handler(self, mock_dataset): + """Test analyze_usage tool handler.""" + handler = AnalyzeUsageHandler(mock_dataset) + + # Test schema + schema = handler.get_input_schema() + assert "time_range" in schema["properties"] + assert "group_by" in schema["properties"] + + # Test execution + result = await handler.execute(time_range="7d", group_by="day") + + assert result["success"] + assert "analysis" in result + assert result["period"] == "7d" + + @pytest.mark.asyncio + async def test_optimize_storage_handler(self, mock_dataset): + """Test optimize_storage tool handler.""" + # Mock the optimization methods + 
mock_dataset.compact_files = Mock(return_value={"files_removed": 5}) + mock_dataset.cleanup_old_versions = Mock(return_value={"bytes_removed": 1000}) + + handler = OptimizeStorageHandler(mock_dataset) + + # Test schema + schema = handler.get_input_schema() + assert "operations" in schema["properties"] + assert "dry_run" in schema["properties"] + + # Test execution + result = await handler.execute(operations=["compact"], dry_run=True) + + assert result["success"] + assert result["dry_run"] + assert "results" in result + + @pytest.mark.asyncio + async def test_export_metrics_handler(self, mock_dataset): + """Test export_metrics tool handler.""" + handler = ExportMetricsHandler(mock_dataset) + + # Test different formats + json_result = await handler.execute(format="json") + assert json_result["success"] + assert json_result["format"] == "json" + assert isinstance(json_result["data"], dict) + + prom_result = await handler.execute(format="prometheus") + assert prom_result["success"] + assert prom_result["format"] == "prometheus" + assert isinstance(prom_result["data"], str) + assert "# TYPE" in prom_result["data"] + + csv_result = await handler.execute(format="csv") + assert csv_result["success"] + assert csv_result["format"] == "csv" + assert isinstance(csv_result["data"], str) + + @pytest.mark.asyncio + async def test_benchmark_operations_handler(self, mock_dataset): + """Test benchmark_operations tool handler.""" + # Mock knn_search + mock_dataset.knn_search = Mock(return_value=[]) + + # Mock scanner for sample data + mock_batch = Mock() + mock_batch.column.return_value.to_pylist.return_value = ["id1", "id2"] + mock_scanner = Mock() + mock_scanner.to_batches.return_value = [mock_batch] + mock_dataset.scanner.return_value = mock_scanner + + handler = BenchmarkOperationsHandler(mock_dataset) + + # Test schema validation + schema = handler.get_input_schema() + assert schema["properties"]["operations"]["items"]["enum"] == [ + "search", + "insert", + "update", + "scan", + 
class TestIndexAdvisor:
    """Test IndexAdvisor functionality."""

    @pytest.fixture
    def advisor(self):
        """Create an index advisor over a mocked dataset with one schema field."""
        dataset = Mock(spec=FrameDataset)
        dataset._dataset = Mock()
        dataset._dataset.schema = Mock()

        # Single nullable string field exposed through schema iteration
        mock_field = Mock()
        mock_field.name = "test_field"
        mock_field.type = "string"
        mock_field.nullable = True
        mock_field.metadata = None

        # Build a fresh iterator on every call. A shared
        # ``return_value=iter([...])`` is exhausted after the first pass, so
        # any second iteration over the schema would silently see no fields.
        dataset._dataset.schema.__iter__ = Mock(
            side_effect=lambda: iter([mock_field])
        )
        dataset.list_indices.return_value = []

        return IndexAdvisor(dataset)

    @pytest.mark.asyncio
    async def test_get_recommendations(self, advisor):
        """Test index recommendations."""
        # Record some query patterns
        advisor.record_query_pattern("id = '123'", ["id"])
        advisor.record_query_pattern("created_at > '2024-01-01'", ["created_at"])

        recommendations = await advisor.get_recommendations(
            analyze_queries=True,
            workload_type="mixed",
        )

        assert "current_indices" in recommendations
        assert "recommendations" in recommendations
        assert "index_coverage" in recommendations

        # Should recommend indices for frequently queried fields
        assert len(recommendations["recommendations"]) > 0
class TestHttpAdapter:
    """Test HttpAdapter class."""

    @pytest.fixture
    def adapter(self):
        """Create an HttpAdapter instance."""
        return HttpAdapter()

    @pytest.mark.asyncio
    async def test_initialize_shutdown(self, adapter):
        """Test adapter initialization and shutdown."""
        await adapter.initialize()
        assert adapter._active_streams == {}
        assert adapter._operation_progress == {}

        await adapter.shutdown()
        assert len(adapter._active_streams) == 0

    @pytest.mark.asyncio
    async def test_send_progress(self, adapter):
        """Test sending progress updates."""
        # Collect whatever the adapter forwards to registered handlers
        progress_received = []

        async def handler(progress):
            progress_received.append(progress)

        adapter.add_progress_handler(handler)

        progress = Progress(
            operation="test_op",
            current=1,
            total=10,
            status="Testing",
            details={"operation_id": "test-123"},
        )
        await adapter.send_progress(progress)

        # Handler must have been called exactly once with the same object
        assert len(progress_received) == 1
        assert progress_received[0] == progress

    @pytest.mark.asyncio
    async def test_create_operation(self, adapter):
        """Test creating and completing operations."""
        op_id = await adapter.create_operation("test_operation")
        assert op_id in adapter._active_operations
        assert op_id in adapter._operation_progress

        await adapter.complete_operation(op_id)
        assert op_id not in adapter._active_operations

    @pytest.mark.asyncio
    async def test_stream_operation_progress(self, adapter):
        """Test streaming operation progress."""
        op_id = await adapter.create_operation("test_operation")

        # Consume the stream in the background until the complete event
        async def stream_progress():
            events = []
            async for event in adapter.stream_operation_progress(op_id):
                events.append(event)
                if event.get("type") == "complete":
                    break
            return events

        stream_task = asyncio.create_task(stream_progress())

        await adapter.send_progress(
            Progress(
                operation="test_op",
                current=5,
                total=10,
                status="Halfway",
                details={"operation_id": op_id},
            )
        )

        await adapter.complete_operation(op_id)

        events = await stream_task
        assert len(events) >= 2  # At least progress and complete events
        assert events[-1]["type"] == "complete"

    @pytest.mark.asyncio
    async def test_subscription_handling(self, adapter):
        """Test subscription creation and notification."""
        sub = Subscription(
            id="test-sub", resource_type="documents", filter='{"type": "test"}'
        )

        events = []

        async def collect_events():
            async for event in adapter.handle_subscription(sub):
                events.append(event)
                if len(events) >= 2:  # Initial + one change
                    break

        collect_task = asyncio.create_task(collect_events())

        # Wait for subscription to be registered
        await asyncio.sleep(0.1)

        await adapter.notify_change(
            {"resource_type": "documents", "type": "test", "change": "created"}
        )

        adapter.cancel_subscription(sub.id)

        # asyncio.wait_for raises asyncio.TimeoutError, which is only an alias
        # of the builtin TimeoutError from Python 3.11 on; catch the asyncio
        # name so the timeout is also handled on 3.10.
        try:
            await asyncio.wait_for(collect_task, timeout=1.0)
        except asyncio.TimeoutError:
            collect_task.cancel()

        assert len(events) >= 1
        assert events[0]["type"] == "subscription_created"
+ + @pytest.mark.asyncio + async def test_send_event(self, stream): + """Test sending SSE events.""" + # Send simple event + result = await stream.send_event("Hello World") + assert "data: Hello World" in result + assert result.endswith("\n\n") + + # Send JSON event + result = await stream.send_json({"key": "value"}, event_type="test") + assert "event: test" in result + assert 'data: {"key": "value"}' in result + + @pytest.mark.asyncio + async def test_send_keepalive(self, stream): + """Test sending keepalive messages.""" + result = await stream.send_keepalive() + assert result.startswith(": keepalive") + assert result.endswith("\n\n") + + @pytest.mark.asyncio + async def test_stream_closed(self, stream): + """Test stream closure.""" + await stream.close() + assert stream.is_closed + + # Should raise error when trying to send after close + with pytest.raises(RuntimeError): + await stream.send_event("test") + + +class TestSSEManager: + """Test SSEManager class.""" + + @pytest.fixture + async def manager(self): + """Create and start an SSEManager instance.""" + manager = SSEManager(max_connections=10) + await manager.start() + yield manager + await manager.stop() + + @pytest.mark.asyncio + async def test_create_stream(self, manager): + """Test creating streams.""" + stream1 = await manager.create_stream() + assert stream1.client_id in manager.streams + + stream2 = await manager.create_stream("custom-id") + assert stream2.client_id == "custom-id" + assert len(manager.streams) == 2 + + @pytest.mark.asyncio + async def test_max_connections(self, manager): + """Test connection limit enforcement.""" + # Create max connections + for i in range(10): + await manager.create_stream(f"stream-{i}") + + # Should raise error on next connection + with pytest.raises(RuntimeError, match="Maximum connections"): + await manager.create_stream() + + @pytest.mark.asyncio + async def test_close_stream(self, manager): + """Test closing streams.""" + stream = await 
manager.create_stream("test-stream") + assert "test-stream" in manager.streams + + await manager.close_stream("test-stream") + assert "test-stream" not in manager.streams + assert stream.is_closed + + +class TestHTTPTransportConfig: + """Test HTTPTransportConfig class.""" + + def test_default_config(self): + """Test default configuration values.""" + config = HTTPTransportConfig() + assert config.host == "0.0.0.0" + assert config.port == 8080 + assert config.cors_enabled is True + assert config.auth_enabled is False + assert config.rate_limit_enabled is True + + def test_validation(self): + """Test configuration validation.""" + # Valid config + config = HTTPTransportConfig() + errors = config.validate() + assert len(errors) == 0 + + # Invalid port + config = HTTPTransportConfig(port=70000) + errors = config.validate() + assert any("port" in e for e in errors) + + # Auth without secret + config = HTTPTransportConfig(auth_enabled=True) + errors = config.validate() + assert any("secret key" in e for e in errors) + + def test_to_dict(self): + """Test converting config to dictionary.""" + config = HTTPTransportConfig( + host="localhost", port=8000, auth_enabled=True, auth_secret_key="secret" + ) + + data = config.to_dict() + assert data["server"]["host"] == "localhost" + assert data["server"]["port"] == 8000 + assert data["auth"]["enabled"] is True + + def test_from_env(self, monkeypatch): + """Test loading config from environment variables.""" + monkeypatch.setenv("MCP_HTTP_HOST", "127.0.0.1") + monkeypatch.setenv("MCP_HTTP_PORT", "9000") + monkeypatch.setenv("MCP_HTTP_AUTH_ENABLED", "true") + monkeypatch.setenv("MCP_HTTP_AUTH_SECRET_KEY", "env-secret") + + config = HTTPTransportConfig.from_env() + assert config.host == "127.0.0.1" + assert config.port == 9000 + assert config.auth_enabled is True + assert config.auth_secret_key == "env-secret"