Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .github/workflows/build-plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,20 @@ jobs:
working-directory: ${{ inputs.plugin_directory }}
run: hatch -v env create hatch-test.py3.13

- name: Install ODBC Driver 18 for SQL Server
if: inputs.plugin_directory == 'sam-sql-database-tool'
run: |
curl https://packages.microsoft.com/keys/microsoft.asc | gpg --dearmor | sudo tee /usr/share/keyrings/microsoft-prod.gpg > /dev/null
curl https://packages.microsoft.com/config/ubuntu/$(lsb_release -rs)/prod.list | sudo tee /etc/apt/sources.list.d/mssql-release.list
sudo apt-get update
sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18 unixodbc-dev

- name: Install FreeTDS ODBC Driver
if: inputs.plugin_directory == 'sam-sql-database-tool'
run: |
sudo apt-get install -y freetds-dev tdsodbc unixodbc-dev
sudo odbcinst -i -d -f /usr/share/tdsodbc/odbcinst.ini

- name: Test Plugin ${{ inputs.plugin_directory }} with Python 3.13
id: test
continue-on-error: true
Expand Down
14 changes: 12 additions & 2 deletions sam-sql-database-tool/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Unlike the `sam-sql-database` agent, which provides a complete Natural-Language-
## Key Features

- **Dynamic Tool Creation**: Define custom SQL query tools directly in your agent's YAML configuration. Each tool instance is completely independent.
- **Multi-Database Support**: Natively supports PostgreSQL, MySQL, and MariaDB.
- **Multi-Database Support**: Natively supports PostgreSQL, MySQL, MariaDB, and MSSQL.
- **Dedicated Connections**: Each tool instance creates its own dedicated database connection, allowing for fine-grained configuration.
- **Flexible Schema Handling**:
- Automatic schema detection and summarization for LLM prompting.
Expand Down Expand Up @@ -54,7 +54,17 @@ tools:

- `tool_name`: (Required) The function name the LLM will use to call the tool.
- `tool_description`: (Optional) A clear description for the LLM explaining what the tool does.
- `connection_string`: (Required) The full database connection string (e.g., `postgresql+psycopg2://user:password@host:port/dbname` for PostgreSQL, or `mysql+pymysql://user:password@host:port/dbname` for MySQL/MariaDB). It is highly recommended to use a single environment variable for the entire string.
- `connection_string`: (Required) The full database connection string. It is highly recommended to use a single environment variable for the entire string. Supported formats:
- **PostgreSQL**: `postgresql+psycopg2://user:password@host:port/dbname`
- **MySQL**: `mysql+pymysql://user:password@host:port/dbname`
- **MariaDB**: `mysql+pymysql://user:password@host:port/dbname`
- **MSSQL (FreeTDS - Recommended)**: `mssql+pyodbc://user:password@host:port/dbname?driver=FreeTDS&TrustServerCertificate=yes`
  - Open-source driver with simpler installation: `sudo apt-get install freetds-dev freetds-bin tdsodbc unixodbc-dev && sudo odbcinst -i -d -f /usr/share/tdsodbc/odbcinst.ini`
  - Works well for standard SQL operations, but does not provide Microsoft-specific features such as Azure AD authentication or Always Encrypted (use the Microsoft ODBC driver below if you need those).
- **MSSQL (Microsoft ODBC)**: `mssql+pyodbc://user:password@host:port/dbname?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes`
- Official Microsoft driver with full feature support (Azure AD auth, Always Encrypted, etc.).
  - Requires Microsoft ODBC Driver 17 or 18 installed on the host system; the `driver=` value in the connection string must match the installed version (e.g. use `driver=ODBC+Driver+17+for+SQL+Server` with Driver 17).
- Use `TrustServerCertificate=yes` for self-signed certificates or `Encrypt=no` to disable encryption.
- `auto_detect_schema`: (Optional, default: `true`) If `true`, the plugin attempts to automatically detect the database schema. If `false`, you must provide `schema_summary_override`.
- `schema_summary_override`: (Required if `auto_detect_schema` is `false`) A concise natural language summary of the schema, suitable for direct inclusion in an LLM prompt.
- `max_enum_cardinality`: (Optional, default: `100`) Maximum number of distinct values to consider a column as an enum. Increase for columns like countries (190+), decrease for faster init times.
Expand Down
5 changes: 3 additions & 2 deletions sam-sql-database-tool/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ type = "tool"

[project]
name = "sam_sql_database_tool"
version = "0.1.1"
version = "0.2.0"
description = "A reusable SQL database tool for Solace Agent Mesh."
readme = "README.md"
authors = [
Expand All @@ -21,6 +21,7 @@ dependencies = [
"SQLAlchemy==2.0.40",
"PyMySQL==1.1.2",
"psycopg2-binary==2.9.10",
"pyodbc==5.3.0",
"pydantic==2.11.9",
"PyYAML==6.0.2"
]
Expand Down Expand Up @@ -50,5 +51,5 @@ dependencies = [
"solace-agent-mesh>=1.4.9,<2.0.0",
"testing.postgresql",
"psycopg2-binary",
"testcontainers[mysql,postgres]>=4.4.0"
"testcontainers[mysql,postgres,mssql]>=4.4.0"
]
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from sqlalchemy.engine import Engine, Connection
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import inspect, text, select, func, distinct, Table, MetaData
from sqlalchemy import inspect, text, select, distinct, Table, MetaData

import sqlalchemy as sa
import yaml
Expand Down Expand Up @@ -247,6 +247,14 @@ def _get_approximate_row_count(self, table_name: str) -> Optional[int]:
"FROM pg_class "
"WHERE relname = :table_name"
).bindparams(table_name=table_name)
elif self.engine.dialect.name == 'mssql':
query = text(
"SELECT SUM(p.rows) AS estimate "
"FROM sys.tables t "
"INNER JOIN sys.partitions p ON t.object_id = p.object_id "
"WHERE t.name = :table_name "
"AND p.index_id IN (0, 1)"
).bindparams(table_name=table_name)
else:
query = text(
"SELECT table_rows "
Expand Down Expand Up @@ -352,7 +360,7 @@ def _process_table(

cardinality_ratio = cardinality / total_samples if total_samples > 0 else 1

is_string_type = col_type.upper().startswith(('VARCHAR', 'CHAR', 'TEXT', 'ENUM'))
is_string_type = col_type.upper().startswith(('VARCHAR', 'CHAR', 'TEXT', 'ENUM', 'NVARCHAR', 'NCHAR', 'NTEXT'))
looks_like_enum = self._looks_like_enum_column(col_name)
has_low_cardinality = cardinality_ratio < 0.3

Expand Down
49 changes: 41 additions & 8 deletions sam-sql-database-tool/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sqlalchemy as sa
from testcontainers.postgres import PostgresContainer
from testcontainers.mysql import MySqlContainer
from testcontainers.mssql import SqlServerContainer
from dataclasses import dataclass
from typing import Type, Callable

Expand Down Expand Up @@ -55,6 +56,32 @@ class DatabaseTestConfig:
),
id="mariadb",
),
pytest.param(
DatabaseTestConfig(
name="mssql-msodbc18",
container_class=SqlServerContainer,
image="mcr.microsoft.com/mssql/server:2022-latest",
connection_url_fn=lambda c: (
f"mssql+pyodbc://sa:{c.password}@"
f"{c.get_container_host_ip()}:{c.get_exposed_port(1433)}/master"
f"?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
),
),
id="mssql-msodbc18",
),
pytest.param(
DatabaseTestConfig(
name="mssql-freetds",
container_class=SqlServerContainer,
image="mcr.microsoft.com/mssql/server:2022-latest",
connection_url_fn=lambda c: (
f"mssql+pyodbc://sa:{c.password}@"
f"{c.get_container_host_ip()}:{c.get_exposed_port(1433)}/master"
f"?driver=FreeTDS&TrustServerCertificate=yes"
),
),
id="mssql-freetds",
),
],
)
def db_config(request):
Expand All @@ -65,14 +92,20 @@ def db_config(request):
@pytest.fixture(scope="session")
def database_container(db_config: DatabaseTestConfig):
"""Starts and stops a Docker container for the configured database."""
with db_config.container_class(
db_config.image, dbname="test_db", username="test_user", password="test_password"
) as container:
if db_config.name in ["mysql", "mariadb"]:
container.with_env("MYSQL_ROOT_PASSWORD", "root_password")
# Attach the config to the container object for easy access in other fixtures
container.db_config = db_config
yield container
if db_config.name.startswith("mssql"):
# MSSQL container has a different constructor signature
with db_config.container_class(db_config.image) as container:
container.db_config = db_config
yield container
else:
with db_config.container_class(
db_config.image, dbname="test_db", username="test_user", password="test_password"
) as container:
if db_config.name in ["mysql", "mariadb"]:
container.with_env("MYSQL_ROOT_PASSWORD", "root_password")
# Attach the config to the container object for easy access in other fixtures
container.db_config = db_config
yield container


@pytest.fixture(scope="function")
Expand Down
30 changes: 24 additions & 6 deletions sam-sql-database-tool/tests/integration/test_config_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def test_cache_ttl_seconds(self, db_tool_provider):
# Re-initialize the tool with a short TTL
tool_config_dict = db_tool_provider.tool_config.model_dump()
tool_config_dict['cache_ttl_seconds'] = 2

custom_config = DatabaseConfig(**tool_config_dict)
tool = SqlDatabaseTool(custom_config)
await tool.init(component=None, tool_config={})
Expand All @@ -62,9 +62,18 @@ async def test_cache_ttl_seconds(self, db_tool_provider):
assert "temp_col_for_ttl_test" not in initial_schema

# 2. Directly alter the database schema
# MSSQL uses different ALTER TABLE syntax (no COLUMN keyword)
dialect = tool.db_service.engine.dialect.name
if dialect == 'mssql':
add_col_sql = "ALTER TABLE users ADD temp_col_for_ttl_test INT"
drop_col_sql = "ALTER TABLE users DROP COLUMN temp_col_for_ttl_test"
else:
add_col_sql = "ALTER TABLE users ADD COLUMN temp_col_for_ttl_test INT"
drop_col_sql = "ALTER TABLE users DROP COLUMN temp_col_for_ttl_test"

conn = tool.db_service.engine.connect()
try:
conn.execute(pytest.importorskip("sqlalchemy").text("ALTER TABLE users ADD COLUMN temp_col_for_ttl_test INT;"))
conn.execute(pytest.importorskip("sqlalchemy").text(add_col_sql))
conn.commit()

# 3. Wait for the cache TTL to expire
Expand All @@ -79,13 +88,13 @@ async def test_cache_ttl_seconds(self, db_tool_provider):

# 6. Get the schema a third time; it should now be refreshed from the background job.
refreshed_schema = tool.db_service.get_optimized_schema_for_llm()

# 7. Assert that the new column is now present
assert "temp_col_for_ttl_test" in refreshed_schema, "Schema should be refreshed with the new column"

finally:
# Clean up the added column
conn.execute(pytest.importorskip("sqlalchemy").text("ALTER TABLE users DROP COLUMN temp_col_for_ttl_test;"))
conn.execute(pytest.importorskip("sqlalchemy").text(drop_col_sql))
conn.commit()
conn.close()
await tool.cleanup(component=None, tool_config={})
Expand All @@ -99,9 +108,18 @@ async def test_clear_cache(self, db_tool_provider):
assert "temp_col_for_clear_test" not in initial_schema

# 2. Directly alter the database schema
# MSSQL uses different ALTER TABLE syntax (no COLUMN keyword)
dialect = tool.db_service.engine.dialect.name
if dialect == 'mssql':
add_col_sql = "ALTER TABLE users ADD temp_col_for_clear_test INT"
drop_col_sql = "ALTER TABLE users DROP COLUMN temp_col_for_clear_test"
else:
add_col_sql = "ALTER TABLE users ADD COLUMN temp_col_for_clear_test INT"
drop_col_sql = "ALTER TABLE users DROP COLUMN temp_col_for_clear_test"

conn = tool.db_service.engine.connect()
try:
conn.execute(pytest.importorskip("sqlalchemy").text("ALTER TABLE users ADD COLUMN temp_col_for_clear_test INT;"))
conn.execute(pytest.importorskip("sqlalchemy").text(add_col_sql))
conn.commit()

# 3. Manually clear the cache
Expand All @@ -115,6 +133,6 @@ async def test_clear_cache(self, db_tool_provider):

finally:
# Clean up the added column
conn.execute(pytest.importorskip("sqlalchemy").text("ALTER TABLE users DROP COLUMN temp_col_for_clear_test;"))
conn.execute(pytest.importorskip("sqlalchemy").text(drop_col_sql))
conn.commit()
conn.close()
6 changes: 4 additions & 2 deletions sam-sql-database-tool/tests/integration/test_sql_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def test_cache_hit_performance(self, db_tool_provider: SqlDatabaseTool):

first_call_start = time.time()
schema1 = db_tool_provider.db_service.get_optimized_schema_for_llm()
first_call_time = time.time() - first_call_start
_first_call_time = time.time() - first_call_start

second_call_start = time.time()
schema2 = db_tool_provider.db_service.get_optimized_schema_for_llm()
Expand Down Expand Up @@ -140,10 +140,12 @@ async def test_multi_table_join_query(self, db_tool_provider: SqlDatabaseTool):

async def test_aggregation_with_join(self, db_tool_provider: SqlDatabaseTool):
"""Test a query that uses aggregation across joined tables."""
# Cast rating to Float to ensure consistent decimal results across databases
# (MSSQL returns integer for AVG of integer columns)
query = (
sa.select(
categories.c.name,
sa.func.avg(reviews.c.rating).label("average_rating")
sa.func.avg(sa.cast(reviews.c.rating, sa.Float)).label("average_rating")
)
.join(product_categories, categories.c.id == product_categories.c.category_id)
.join(products, product_categories.c.product_id == products.c.id)
Expand Down
13 changes: 7 additions & 6 deletions sam-sql-database-tool/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,21 @@
metadata = sa.MetaData()

# Define tables using SQLAlchemy's generic types
# Note: autoincrement=False is required to allow explicit ID inserts on MSSQL
users = sa.Table('users', metadata,
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('id', sa.Integer, primary_key=True, autoincrement=False),
sa.Column('name', sa.String(100)),
sa.Column('email', sa.String(100), unique=True),
sa.Column('created_at', sa.DateTime)
)

categories = sa.Table('categories', metadata,
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('id', sa.Integer, primary_key=True, autoincrement=False),
sa.Column('name', sa.String(100), unique=True)
)

products = sa.Table('products', metadata,
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('id', sa.Integer, primary_key=True, autoincrement=False),
sa.Column('name', sa.String(100)),
sa.Column('description', sa.Text),
sa.Column('price', sa.Float)
Expand All @@ -29,22 +30,22 @@
)

orders = sa.Table('orders', metadata,
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('id', sa.Integer, primary_key=True, autoincrement=False),
sa.Column('user_id', sa.Integer, sa.ForeignKey('users.id', ondelete='CASCADE')),
sa.Column('order_date', sa.Date),
sa.Column('status', sa.String(50))
)

order_items = sa.Table('order_items', metadata,
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('id', sa.Integer, primary_key=True, autoincrement=False),
sa.Column('order_id', sa.Integer, sa.ForeignKey('orders.id', ondelete='CASCADE')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('products.id', ondelete='CASCADE')),
sa.Column('quantity', sa.Integer),
sa.Column('price_per_unit', sa.Float)
)

reviews = sa.Table('reviews', metadata,
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('id', sa.Integer, primary_key=True, autoincrement=False),
sa.Column('product_id', sa.Integer, sa.ForeignKey('products.id', ondelete='CASCADE')),
sa.Column('user_id', sa.Integer, sa.ForeignKey('users.id', ondelete='CASCADE')),
sa.Column('rating', sa.Integer),
Expand Down
Loading