Skip to content

Commit

Permalink
Added dedicated database.py file including implementation and separat…
Browse files Browse the repository at this point in the history
…ion of the DB logic from the sever.py

There is a new class including connection_management that uses pattern matching to check various connection DB_name and verify which works. This makes a more robust solution in case there are various DB options inside of the project - the tool just tries a valid connection string and keep (cache) the connection that worked.
  • Loading branch information
radektomasek committed Jan 31, 2025
1 parent 28f31c8 commit 916efbf
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 68 deletions.
2 changes: 1 addition & 1 deletion src/keboola_mcp_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,4 @@ def main(args: Optional[List[str]] = None) -> None:


if __name__ == "__main__":
asyncio.run(main())
main()
210 changes: 210 additions & 0 deletions src/keboola_mcp_server/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
"""Database connection management for Keboola MCP server."""

import logging
from contextlib import contextmanager

import snowflake.connector
from dataclasses import dataclass
from typing import Optional, List, Tuple, Dict, Any

from .config import Config

logger = logging.getLogger(__name__)

class DatabaseConnectionError(Exception):
"""Raised when database connection fails after trying all available patterns."""
pass

@dataclass
class ConnectionPattern:
"""Represents a database connection pattern to try."""
name: str
get_database: callable
required_config: List[str] = None

def __post_init__(self):
if self.required_config is None:
self.required_config = []

class DatabasePathManager:
"""Manages database paths and connections for Keboola tables."""

def __init__(self, config, connection_manager):
self.config = config
self.connection_manager = connection_manager
self._db_path_cache = {}

def get_current_db(self) -> str:
"""
Get the current database path, using the connection manager's patterns.
Returns:
str: The database path to use
"""
try:
# Try to get a working DB connection from our patterns
return self.connection_manager.find_working_connection()
except Exception as e:
logger.warning(f"Failed to get working connection, falling back to token-based path: {e}")
return f"KEBOOLA_{self.config.storage_token.split('-')[0]}"

def get_table_db_path(self, table: Dict[str, Any]) -> str:
"""
Get the database path for a specific table.
Args:
table: Dictionary containing table information
Returns:
str: The full database path for the table
"""
table_id = table["id"]
if table_id in self._db_path_cache:
return self._db_path_cache[table_id]

db_path = self.get_current_db()
table_name = table["name"]
table_path = table["id"]

if table.get("sourceTable"):
db_path = f"KEBOOLA_{table['sourceTable']['project']['id']}"
table_path = table["sourceTable"]["id"]

table_identifier = f'"{db_path}"."{".".join(table_path.split(".")[:-1])}"."{table_name}"'

self._db_path_cache[table_id] = table_identifier
return table_identifier


class ConnectionManager:
"""Manages database connections and connection string patterns for Keboola."""

def __init__(self, config):
self.config = config
self.patterns = self._initialize_patterns()

def _initialize_patterns(self) -> List[ConnectionPattern]:
"""Initialize the list of connection patterns to try."""
return [
ConnectionPattern(
name="Configured Environment",
get_database=lambda: self.config.snowflake_database,
required_config=['snowflake_database']
),
ConnectionPattern(
name="KEBOOLA Token Pattern",
get_database=lambda: f"KEBOOLA_{self.config.storage_token.split('-')[0]}",
required_config=['storage_token']
)
]

def _validate_config_for_pattern(self, pattern: ConnectionPattern) -> bool:
"""Check if all required configuration is present for a pattern."""
return all(
hasattr(self.config, attr) and getattr(self.config, attr)
for attr in pattern.required_config
)

@contextmanager
def _create_test_connection(self, database: str):
"""Create a test connection with the given database name."""
if not self.config.has_snowflake_config():
raise ValueError("Snowflake credentials are not fully configured")

conn = snowflake.connector.connect(
account=self.config.snowflake_account,
user=self.config.snowflake_user,
password=self.config.snowflake_password,
warehouse=self.config.snowflake_warehouse,
database=database,
schema=self.config.snowflake_schema,
role=self.config.snowflake_role,
)

try:
yield conn
finally:
conn.close()


def _test_connection(self, database: str) -> bool:
"""Test if connection works with a simple query."""
try:
with self._create_test_connection(database) as conn:
cur = conn.cursor()
cur.execute('SELECT 1')
return True
except Exception as e:
logger.debug(f"Connection test failed for {database}: {str(e)}")
return False

def find_working_connection(self) -> str:
"""
Try different connection patterns and return the first working one.
Returns:
str: Working database name
Raises:
DatabaseConnectionError: If no working connection pattern is found
"""
results = []

for pattern in self.patterns:
try:
if not self._validate_config_for_pattern(pattern):
results.append((pattern.name, "N/A", "Missing required configuration"))
continue

database = pattern.get_database()
logger.debug(f"Testing ${pattern.name}: {database}")

if self._test_connection(database):
logger.info(f"Successfully connected using {pattern.name}: {database}")
return database

results.append((pattern.name, database, "Connection test failed"))
except Exception as e:
results.append((pattern.name, "N/A", str(e)))
continue

error_msg = "No working connection pattern found:\n"
for pattern, db, error in results:
error_msg += f" - {pattern} ({db}): {error}\n"
raise DatabaseConnectionError(error_msg)


def create_snowflake_connection(config: Config) -> snowflake.connector.connection:
"""Create a return a Snowflake connection using configured credentials.
Args:
config: Configuration object containing Snowflake credentials
Returns:
snowflake.connector.connection: established Snowflake connection
Raises:
ValueError: If credentials are not fully configured or connection fails
"""
if not config.has_snowflake_config():
raise ValueError("Snowflake credentials are not fully configured")

try:
connection_manager = ConnectionManager(config)
database = connection_manager.find_working_connection()

conn = snowflake.connector.connect(
account=config.snowflake_account,
user=config.snowflake_user,
password=config.snowflake_password,
warehouse=config.snowflake_warehouse,
database=database,
schema=config.snowflake_schema,
role=config.snowflake_role,
)

return conn
except DatabaseConnectionError as e:
raise ValueError(f"Failed to find working database connection: {str(e)}")
except Exception as e:
raise ValueError(f"Failed to create Snowflake connection: {str(e)}")
72 changes: 5 additions & 67 deletions src/keboola_mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,11 @@
from typing import Any, AsyncGenerator, Dict, List, Optional, TypedDict, cast

import pandas as pd
import snowflake.connector
from mcp.server.fastmcp import FastMCP
from snowflake.connector.connection import SnowflakeConnection
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col

from .client import KeboolaClient
from .config import Config
from .database import create_snowflake_connection, ConnectionManager, DatabasePathManager

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -47,38 +44,6 @@ class TableDetail(TypedDict):
column_identifiers: List[TableColumnInfo]
db_identifier: str


def create_snowflake_connection(config: Config) -> snowflake.connector.connection:
"""Create a return a Snowflake connection using configured credentials.
Args:
config: Configuration object containing Snowflake credentials
Returns:
snowflake.connector.connection: established Snowflake connection
Raises:
ValueError: If credentials are not fully configured or connection fails
"""
if not config.has_snowflake_config():
raise ValueError("Snowflake credentials are not fully configured")

try:
conn = snowflake.connector.connect(
account=config.snowflake_account,
user=config.snowflake_user,
password=config.snowflake_password,
warehouse=config.snowflake_warehouse,
database=config.snowflake_database,
schema=config.snowflake_schema,
role=config.snowflake_role,
)

return conn
except Exception as e:
raise ValueError(f"Failed to create Snowflake connection: {str(e)}")


def create_server(config: Optional[Config] = None) -> FastMCP:
"""Create and configure the MCP server.
Expand All @@ -103,6 +68,9 @@ def create_server(config: Optional[Config] = None) -> FastMCP:
"Keboola Explorer", dependencies=["keboola.storage-api-client", "httpx", "pandas"]
)

connection_manager = ConnectionManager(config)
db_path_manager = DatabasePathManager(config, connection_manager)

# Create Keboola client instance
try:
keboola = KeboolaClient(config.storage_token, config.storage_api_url)
Expand All @@ -111,36 +79,6 @@ def create_server(config: Optional[Config] = None) -> FastMCP:
raise
logger.info("Successfully initialized Keboola client")

async def get_table_db_path(table: dict) -> str:
"""Get the database path for a specific table."""

db_path = await get_current_db()
table_name = table["name"]
table_path = table["id"]
if table.get("sourceTable"):
db_path = f"KEBOOLA_{table['sourceTable']['project']['id']}"
table_path = table["sourceTable"]["id"]

table_identifier = f'"{db_path}"."{".".join(table_path.split(".")[:-1])}"."{table_name}"'
return table_identifier

async def get_current_db() -> str:
"""
Get the current database.
Returns:
str: The database name
Raises:
ValueError: If database name is not configured or not a string
"""
db_name = config.snowflake_database
if db_name is None:
raise ValueError("Database name is not configured")
if not isinstance(db_name, str):
raise ValueError(f"Database name must be a string, got {type(db_name)}")
return db_name

# Resources
@mcp.resource("keboola://buckets")
async def list_buckets() -> List[BucketInfo]:
Expand Down Expand Up @@ -184,7 +122,7 @@ async def get_table_detail(table_id: str) -> TableDetail:
"data_size_bytes": table.get("dataSizeBytes", 0),
"columns": columns,
"column_identifiers": column_info,
"db_identifier": await get_table_db_path(table),
"db_identifier": db_path_manager.get_table_db_path(table),
}

@mcp.tool()
Expand Down

0 comments on commit 916efbf

Please sign in to comment.