Skip to content

Commit

Permalink
Refactor Snowflake query and table detail methods
Browse files Browse the repository at this point in the history
- Replaced Snowpark session with direct Snowflake connector for more flexible query handling
- Implemented async Snowflake connection context manager
- Improved table detail and query methods to use JSON parsing and CSV output
- Added error handling and connection management for Snowflake interactions
  • Loading branch information
jordanrburger committed Jan 27, 2025
1 parent 82f86b2 commit d792774
Showing 1 changed file with 66 additions and 67 deletions.
133 changes: 66 additions & 67 deletions src/keboola_mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,18 @@
import logging
import os
import tempfile
from typing import Any, Dict, List, Optional, TypedDict, cast
import json
import csv
from typing import Any, Dict, List, Optional, TypedDict, cast, AsyncGenerator
from io import StringIO
from contextlib import asynccontextmanager

import pandas as pd
from mcp.server.fastmcp import FastMCP
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col
import snowflake.connector
from snowflake.connector.connection import SnowflakeConnection

from .client import KeboolaClient
from .config import Config
Expand Down Expand Up @@ -120,11 +126,18 @@ async def list_components() -> str:
description="Get detailed information about a Keboola table including its DB identifier and column information",
)
async def get_table_detail(table_id: str) -> TableDetail:
"""Get structured information about a specific table including its DB identifier."""
table = cast(Dict[str, Any], keboola.storage_client.tables.detail(table_id))
"""Get detailed information about a table."""
table = await get_table_metadata(table_id)
table = json.loads(table)

# Get column info
columns = table.get("columns", [])
column_info = [{"name": col, "db_identifier": f'"{col}"'} for col in columns]
column_info = [
TableColumnInfo(
name=col,
db_identifier=f'"{col}"'
) for col in columns
]

return {
"id": table["id"],
Expand All @@ -145,14 +158,7 @@ async def query_table_data(
where: Optional[str] = None,
limit: Optional[int] = None,
) -> str:
"""Query table data using proper DB identifiers.
Args:
table_id: The table ID in Keboola
columns: List of column names to select. If None, selects all columns
where: WHERE clause without the 'WHERE' keyword
limit: Maximum number of rows to return
"""
"""Query table data using proper DB identifiers."""
table_info = await get_table_detail(table_id)

# Build column list with proper identifiers
Expand All @@ -172,64 +178,29 @@ async def query_table_data(
if limit:
query += f" LIMIT {limit}"

return await query_table(query)
result: str = await query_table(query)
return result

@mcp.tool()
async def query_table(sql_query: str) -> str:
    """Execute a Snowflake SQL query against the Storage and return CSV text.

    Args:
        sql_query: SQL query to execute (Snowflake syntax). All database
            identifiers (table names, column names, schema names) must be
            enclosed in double quotes, e.g.
            SELECT "column" FROM "database"."schema"."table";

    Returns:
        The result set serialized as CSV: a header row with the column
        names followed by one row per record. A query that yields no rows
        produces only the header row.

    Raises:
        ValueError: propagated from ``snowflake_connection`` when the
            Snowflake credentials are not fully configured.
        Exception: any error raised by the Snowflake connector while
            connecting or executing the statements is propagated unchanged.
    """
    # Function-scope import keeps this block self-contained.
    import asyncio

    # Get current database
    db = await get_current_db()

    def _run_blocking(conn: SnowflakeConnection) -> str:
        """Run the query on the synchronous cursor and render it as CSV."""
        # The connector's cursor is a plain (non-async) context manager and
        # its execute()/fetchall() are blocking calls; they must not be
        # awaited or entered with ``async with``.
        with conn.cursor() as cursor:
            # NOTE(review): ``db`` is interpolated unquoted, exactly as
            # before. Confirm what get_current_db() returns before quoting.
            cursor.execute(f"USE DATABASE {db}")
            cursor.execute(sql_query)
            rows = cursor.fetchall()
            # description can be None when the statement yields no result set.
            header = [column[0] for column in (cursor.description or [])]

        # Convert to CSV
        output = StringIO()
        writer = csv.writer(output)
        writer.writerow(header)
        writer.writerows(rows)
        return output.getvalue()

    # Execute query in a worker thread so the event loop stays responsive.
    async with snowflake_connection() as conn:
        return await asyncio.to_thread(_run_blocking, conn)

# Tools
@mcp.tool()
Expand Down Expand Up @@ -322,4 +293,32 @@ async def list_bucket_tables_tool(bucket_id: str) -> str:
for table in tables
)

@asynccontextmanager
async def snowflake_connection() -> AsyncGenerator[SnowflakeConnection, None]:
    """Open a Snowflake connection for the duration of an ``async with`` block.

    The connection parameters are read from ``config``. The connection is
    closed when the block exits, whether normally or through an exception.

    Yields:
        An open ``SnowflakeConnection``. The connection object itself is
        synchronous: its cursors and their methods are blocking calls.

    Raises:
        ValueError: if any of the account, user, password, warehouse,
            database or role settings is missing or empty.
    """
    # Function-scope import keeps this block self-contained.
    import asyncio

    connection_parameters = {
        "account": config.snowflake_account,
        "user": config.snowflake_user,
        "password": config.snowflake_password,
        "warehouse": config.snowflake_warehouse,
        "database": config.snowflake_database,
        "role": config.snowflake_role,
    }
    if not all(connection_parameters.values()):
        raise ValueError("Snowflake credentials not fully configured in environment variables")

    # connect() performs blocking network I/O; run it in a worker thread so
    # the event loop is not stalled while the session is established.
    conn: SnowflakeConnection = await asyncio.to_thread(
        snowflake.connector.connect, **connection_parameters
    )
    try:
        yield conn
    finally:
        # close() is a blocking call on the connector as well, so it is
        # offloaded the same way as connect().
        await asyncio.to_thread(conn.close)

return mcp

0 comments on commit d792774

Please sign in to comment.