
Commit

Fix table metadata tool to use resource properly, prompt to use db identifier, remove simple table query, use the new server proxy
davidesner committed Feb 9, 2025
1 parent 5a4f8b7 commit 86e4854
Showing 2 changed files with 140 additions and 154 deletions.
87 changes: 27 additions & 60 deletions src/keboola_mcp_server/config.py
@@ -13,71 +13,36 @@ class Config:
"""Server configuration."""

storage_token: str
query_proxy_url: str
storage_api_url: str = "https://connection.keboola.com"
log_level: str = "INFO"
# Add Snowflake credentials
snowflake_account: Optional[str] = None
snowflake_user: Optional[str] = None
snowflake_password: Optional[str] = None
snowflake_warehouse: Optional[str] = None
snowflake_database: Optional[str] = None
snowflake_schema: Optional[str] = None
snowflake_role: Optional[str] = None
_log_level: str = "INFO"

def __init__(
self,
storage_token: str,
storage_api_url: str = "https://connection.keboola.com",
snowflake_account: Optional[str] = None,
snowflake_user: Optional[str] = None,
snowflake_password: Optional[str] = None,
snowflake_warehouse: Optional[str] = None,
snowflake_database: Optional[str] = None,
snowflake_role: Optional[str] = None,
snowflake_schema: Optional[str] = None,
log_level: str = "INFO",
query_proxy_url: str,
storage_api_url: str,
):
self.storage_token = storage_token
self.query_proxy_url = query_proxy_url
self.storage_api_url = storage_api_url
self.snowflake_account = snowflake_account
self.snowflake_user = snowflake_user
self.snowflake_password = snowflake_password
self.snowflake_warehouse = snowflake_warehouse
self.snowflake_database = snowflake_database
self.snowflake_role = snowflake_role
self.snowflake_schema = snowflake_schema
self.log_level = log_level
self._log_level = os.getenv("KBC_LOG_LEVEL", "INFO")

@classmethod
def from_env(cls) -> "Config":
"""Create config from environment variables."""
# Add debug logging using logger instead of print
for env_var in [
"KBC_SNOWFLAKE_ACCOUNT",
"KBC_SNOWFLAKE_USER",
"KBC_SNOWFLAKE_PASSWORD",
"KBC_SNOWFLAKE_WAREHOUSE",
"KBC_SNOWFLAKE_DATABASE",
"KBC_SNOWFLAKE_ROLE",
"KBC_SNOWFLAKE_SCHEMA",
]:
logger.debug(f"Reading {env_var}: {'set' if os.getenv(env_var) else 'not set'}")

"""Create configuration from environment variables."""
storage_token = os.getenv("KBC_STORAGE_TOKEN")
if not storage_token:
raise ValueError("KBC_STORAGE_TOKEN environment variable is required")

query_proxy_url = os.getenv("KBC_QUERY_PROXY_URL")
if not query_proxy_url:
raise ValueError("KBC_QUERY_PROXY_URL environment variable is required")

return cls(
storage_token=storage_token,
query_proxy_url=query_proxy_url,
storage_api_url=os.getenv("KBC_STORAGE_API_URL", "https://connection.keboola.com"),
snowflake_account=os.getenv("KBC_SNOWFLAKE_ACCOUNT"),
snowflake_user=os.getenv("KBC_SNOWFLAKE_USER"),
snowflake_password=os.getenv("KBC_SNOWFLAKE_PASSWORD"),
snowflake_warehouse=os.getenv("KBC_SNOWFLAKE_WAREHOUSE"),
snowflake_database=os.getenv("KBC_SNOWFLAKE_DATABASE"),
snowflake_role=os.getenv("KBC_SNOWFLAKE_ROLE"),
snowflake_schema=os.getenv("KBC_SNOWFLAKE_SCHEMA"),
log_level=os.getenv("KBC_LOG_LEVEL", "INFO"),
)

def validate(self) -> None:
@@ -86,17 +51,19 @@ def validate(self) -> None:
raise ValueError("Storage token not configured")
if not self.storage_api_url:
raise ValueError("Storage API URL is required")
if self.log_level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
raise ValueError(f"Invalid log level: {self.log_level}")
if not self.query_proxy_url:
raise ValueError("Query proxy URL is required")
if self._log_level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
raise ValueError(f"Invalid log level: {self._log_level}")

def has_snowflake_config(self) -> bool:
"""Check if Snowflake configuration is complete."""
return all(
[
self.snowflake_account,
self.snowflake_user,
self.snowflake_password,
self.snowflake_warehouse,
self.snowflake_database,
]
)
@property
def log_level(self) -> str:
"""Get the configured log level."""
return self._log_level

@log_level.setter
def log_level(self, value: str) -> None:
"""Set the log level."""
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
raise ValueError(f"Invalid log level: {value}")
self._log_level = value
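
For orientation, a minimal usage sketch of the configuration as it stands after this commit, assuming the package layout shown in the diff (keboola_mcp_server.config); the token and URL values are placeholders:

import os

from keboola_mcp_server.config import Config

# Both of these are now required; Config.from_env() raises ValueError if either is missing.
os.environ["KBC_STORAGE_TOKEN"] = "my-storage-token"                    # placeholder
os.environ["KBC_QUERY_PROXY_URL"] = "https://proxy.example.com/query"  # placeholder

# Optional, with the defaults shown in the diff.
os.environ.setdefault("KBC_STORAGE_API_URL", "https://connection.keboola.com")
os.environ.setdefault("KBC_LOG_LEVEL", "INFO")

config = Config.from_env()
config.validate()            # raises ValueError for a missing token/URL or an invalid log level
config.log_level = "DEBUG"   # the new property setter rejects values outside the allowed list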
207 changes: 113 additions & 94 deletions src/keboola_mcp_server/server.py
@@ -10,6 +10,7 @@

import pandas as pd
from mcp.server.fastmcp import FastMCP
import httpx

from .client import KeboolaClient
from .config import Config
@@ -110,7 +111,7 @@ async def list_components() -> str:
)
async def get_table_detail(table_id: str) -> TableDetail:
"""Get detailed information about a table."""
table = await get_table_metadata(table_id)
table = cast(Dict[str, Any], keboola.storage_client.tables.detail(table_id))

# Get column info
columns = table.get("columns", [])
@@ -127,108 +128,115 @@ async def get_table_detail(table_id: str) -> TableDetail:
"columns": columns,
"column_identifiers": column_info,
"db_identifier": db_path_manager.get_table_db_path(table),
"schema_identifier": table["id"].split(".")[0],
"table_identifier": table["id"].split(".")[1],
}

@mcp.tool()
async def query_table_data(
table_id: str,
columns: Optional[List[str]] = None,
where: Optional[str] = None,
limit: Optional[int] = None,
) -> str:
"""
Query table data using database identifiers with proper formatting. This method safely constructs
and executes SQL queries by handling database identifiers and query parameters.
Parameters:
table_id (str): The table identifier in format 'bucket.table_name' (e.g., 'in.c-fraudDetection.test_identify')
columns (List[str], optional): List of column names to select. If None, selects all columns (*)
where (str, optional): WHERE clause conditions without the 'WHERE' keyword
limit (int, optional): Maximum number of rows to return
Returns:
str: Query results in string format
Examples:
# Select all columns with limit
query_table_data('in.c-fraudDetection.test_identify', limit=5)
# Select specific columns with condition
query_table_data(
'in.c-fraudDetection.test_identify',
columns=['TransactionID', 'DeviceType'],
where="DeviceType = 'mobile'",
limit=10
)
Note:
This method is preferred over direct SQL queries as it:
- Automatically handles proper database identifiers
- Prevents SQL injection through proper parameter handling
- Uses configuration for database name
- Provides a simpler interface for common query patterns
"""
table_info = await get_table_detail(table_id)

if columns:
column_map = {
col["name"]: col["db_identifier"] for col in table_info["column_identifiers"]
}
select_clause = ", ".join(column_map[col] for col in columns)
else:
select_clause = "*"

query = f"SELECT {select_clause} FROM {table_info['db_identifier']}"

if where:
query += f" WHERE {where}"

if limit:
query += f" LIMIT {limit}"

result: str = await query_table(query)
return result
# @mcp.tool()
# async def query_table_data(
# table_id: str,
# columns: Optional[List[str]] = None,
# where: Optional[str] = None,
# limit: Optional[int] = None,
# ) -> str:
# """
# Query table data using database identifiers with proper formatting. This method safely constructs
# and executes SQL queries by handling database identifiers and query parameters.

# Parameters:
# table_id (str): The table identifier in format 'bucket.table_name' (e.g., 'in.c-fraudDetection.test_identify')
# columns (List[str], optional): List of column names to select. If None, selects all columns (*)
# where (str, optional): WHERE clause conditions without the 'WHERE' keyword
# limit (int, optional): Maximum number of rows to return

# Returns:
# str: Query results in string format

# Examples:
# # Select all columns with limit
# query_table_data('in.c-fraudDetection.test_identify', limit=5)

# # Select specific columns with condition
# query_table_data(
# 'in.c-fraudDetection.test_identify',
# columns=['TransactionID', 'DeviceType'],
# where="DeviceType = 'mobile'",
# limit=10
# )

# Note:
# This method is preferred over direct SQL queries as it:
# - Automatically handles proper database identifiers
# - Prevents SQL injection through proper parameter handling
# - Uses configuration for database name
# - Provides a simpler interface for common query patterns
# """
# table_info = await get_table_detail(table_id)

# if columns:
# column_map = {
# col["name"]: col["db_identifier"] for col in table_info["column_identifiers"]
# }
# select_clause = ", ".join(column_map[col] for col in columns)
# else:
# select_clause = "*"

# query = f"SELECT {select_clause} FROM {table_info['db_identifier']}"

# if where:
# query += f" WHERE {where}"

# if limit:
# query += f" LIMIT {limit}"

# result: str = await query_table(query)
# return result

@mcp.tool()
async def query_table(sql_query: str) -> str:
"""
Execute a Snowflake SQL query to get data from the Storage. Before forming the query always check the
get_table_metadata tool to get the correct database name and table name.
Execute a SQL query through the proxy service to get data from Storage.
Before forming the query always check the get_table_metadata tool to get
the correct database name and table name.
- The {{db_identifier}} is available in the tool response.
Note: SQL queries must include the full path including database name, e.g.:
'SELECT * FROM SAPI_10025."in.c-fraudDetection"."test_identify"'. Snowflake is case sensitive so always
'SELECT * FROM {{db_identifier}}."test_identify"'. Snowflake is case sensitive so always
wrap the column names in double quotes.
"""
conn = None
cursor = None

try:
conn = create_snowflake_connection(config)
cursor = conn.cursor()
cursor.execute(sql_query)
result = cursor.fetchall()
columns = [col[0] for col in cursor.description]

# Convert to CSV
output = StringIO()
writer = csv.writer(output)
writer.writerow(columns)
writer.writerows(result)

return output.getvalue()

except snowflake.connector.errors.ProgrammingError as e:
raise ValueError(f"Snowflake query error: {str(e)}")

async with httpx.AsyncClient() as client:
response = await client.post(
config.query_proxy_url,
json={"query": sql_query},
headers={
"X-StorageApi-Token": config.storage_token,
"Content-Type": "application/json"
}
)

if response.status_code != 200:
raise ValueError(f"Query proxy error: {response.text}")

# Parse the JSON response
data = response.json()

# Extract columns and rows from the result
columns = data["result"]["columns"]
rows = data["result"]["rows"]

# Convert to CSV format
output = StringIO()
writer = csv.writer(output)
writer.writerow(columns)
writer.writerows(rows)

return output.getvalue()

except httpx.HTTPError as e:
raise ValueError(f"HTTP error during query execution: {str(e)}")
except Exception as e:
raise ValueError(
f"Unexpected error during query execution: {str(e)}")

finally:
if cursor:
cursor.close()
if conn:
conn.close()
raise ValueError(f"Unexpected error during query execution: {str(e)}")

# Tools
@mcp.tool()
@@ -271,9 +279,20 @@ async def get_bucket_metadata(bucket_id: str) -> str:
async def get_table_metadata(table_id: str) -> Dict[str, Any]:
"""Get detailed information about a specific table including its DB identifier and column information."""
# Get table details directly from the storage client
table = cast(Dict[str, Any],
keboola.storage_client.tables.detail(table_id))
return table
table = await get_table_detail(table_id)
return (
f"Table Information:\n"
f"ID: {table['id']}\n"
f"Name: {table['name']}\n"
f"Primary Key: {', '.join(table['primary_key']) if table['primary_key'] else 'None'}\n"
f"Created: {table['created']}\n"
f"Row Count: {table['row_count']}\n"
f"Data Size: {table['data_size_bytes']} bytes\n"
f"Columns: {', '.join(table['columns'])}\n"
f"Database Identifier: {table['db_identifier']}\n"
f"Schema: {table['schema_identifier']}\n"
f"Table: {table['table_identifier']}"
)
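
As an illustration of the convention described in the query_table docstring, a hypothetical sketch of assembling a fully qualified, case-safe query from the reported identifiers; all identifier and column values below are placeholders:

# Placeholders; in practice the value comes from the "Database Identifier"
# line of the get_table_metadata output.
db_identifier = 'SAPI_10025."in.c-fraudDetection"'
table_name = "test_identify"
columns = ["TransactionID", "DeviceType"]

# Snowflake is case sensitive, so column names are wrapped in double quotes.
select_clause = ", ".join(f'"{col}"' for col in columns)
sql_query = f'SELECT {select_clause} FROM {db_identifier}."{table_name}" LIMIT 10'
# -> SELECT "TransactionID", "DeviceType" FROM SAPI_10025."in.c-fraudDetection"."test_identify" LIMIT 10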

@mcp.tool()
async def list_component_configs(component_id: str) -> str:
