diff --git a/.github/workflows/build-plugin.yaml b/.github/workflows/build-plugin.yaml
index be4783b..0cc64f3 100644
--- a/.github/workflows/build-plugin.yaml
+++ b/.github/workflows/build-plugin.yaml
@@ -72,6 +72,20 @@ jobs:
         working-directory: ${{ inputs.plugin_directory }}
         run: hatch -v env create hatch-test.py3.13
 
+      - name: Install ODBC Driver 18 for SQL Server
+        if: inputs.plugin_directory == 'sam-sql-database-tool'
+        run: |
+          curl https://packages.microsoft.com/keys/microsoft.asc | gpg --dearmor | sudo tee /usr/share/keyrings/microsoft-prod.gpg > /dev/null
+          curl https://packages.microsoft.com/config/ubuntu/$(lsb_release -rs)/prod.list | sudo tee /etc/apt/sources.list.d/mssql-release.list
+          sudo apt-get update
+          sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18 unixodbc-dev
+
+      - name: Install FreeTDS ODBC Driver
+        if: inputs.plugin_directory == 'sam-sql-database-tool'
+        run: |
+          sudo apt-get install -y freetds-dev tdsodbc unixodbc-dev
+          sudo odbcinst -i -d -f /usr/share/tdsodbc/odbcinst.ini
+
       - name: Test Plugin ${{ inputs.plugin_directory }} with Python 3.13
         id: test
         continue-on-error: true
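A quick way to confirm both drivers actually registered after these two CI steps is to list them from `pyodbc`. A minimal sanity-check sketch, assuming `pyodbc` is importable on the runner and that `msodbcsql18` and the FreeTDS `odbcinst` registration install these exact driver names:

```python
# Hypothetical CI sanity check: the two expected driver names below are
# assumptions based on the install steps above, not guaranteed on every image.
import pyodbc

installed = pyodbc.drivers()  # driver names parsed from the ODBC registry files
print("ODBC drivers:", installed)

for expected in ("ODBC Driver 18 for SQL Server", "FreeTDS"):
    if expected not in installed:
        raise SystemExit(f"Missing ODBC driver: {expected}")
```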
diff --git a/sam-sql-database-tool/README.md b/sam-sql-database-tool/README.md
index 47d6cd0..0b0d73b 100644
--- a/sam-sql-database-tool/README.md
+++ b/sam-sql-database-tool/README.md
@@ -7,7 +7,7 @@ Unlike the `sam-sql-database` agent, which provides a complete Natural-Language-
 ## Key Features
 
 - **Dynamic Tool Creation**: Define custom SQL query tools directly in your agent's YAML configuration. Each tool instance is completely independent.
-- **Multi-Database Support**: Natively supports PostgreSQL, MySQL, and MariaDB.
+- **Multi-Database Support**: Natively supports PostgreSQL, MySQL, MariaDB, and MSSQL.
 - **Dedicated Connections**: Each tool instance creates its own dedicated database connection, allowing for fine-grained configuration.
 - **Flexible Schema Handling**:
   - Automatic schema detection and summarization for LLM prompting.
@@ -54,7 +54,17 @@ tools:
 
 - `tool_name`: (Required) The function name the LLM will use to call the tool.
 - `tool_description`: (Optional) A clear description for the LLM explaining what the tool does.
-- `connection_string`: (Required) The full database connection string (e.g., `postgresql+psycopg2://user:password@host:port/dbname` for PostgreSQL, or `mysql+pymysql://user:password@host:port/dbname` for MySQL/MariaDB). It is highly recommended to use a single environment variable for the entire string.
+- `connection_string`: (Required) The full database connection string. It is highly recommended to use a single environment variable for the entire string. Supported formats:
+  - **PostgreSQL**: `postgresql+psycopg2://user:password@host:port/dbname`
+  - **MySQL**: `mysql+pymysql://user:password@host:port/dbname`
+  - **MariaDB**: `mysql+pymysql://user:password@host:port/dbname`
+  - **MSSQL (FreeTDS - Recommended)**: `mssql+pyodbc://user:password@host:port/dbname?driver=FreeTDS&TrustServerCertificate=yes`
+    - Open-source driver with simpler installation: `sudo apt-get install freetds-dev freetds-bin tdsodbc && sudo odbcinst -i -d -f /usr/share/tdsodbc/odbcinst.ini`
+    - Works well for standard SQL operations.
+  - **MSSQL (Microsoft ODBC)**: `mssql+pyodbc://user:password@host:port/dbname?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes`
+    - Official Microsoft driver with full feature support (Azure AD auth, Always Encrypted, etc.).
+    - Requires ODBC Driver 17 or 18 installed on the host system.
+    - Use `TrustServerCertificate=yes` for self-signed certificates or `Encrypt=no` to disable encryption.
 - `auto_detect_schema`: (Optional, default: `true`) If `true`, the plugin attempts to automatically detect the database schema. If `false`, you must provide `schema_summary_override`.
 - `schema_summary_override`: (Required if `auto_detect_schema` is `false`) A concise natural language summary of the schema, suitable for direct inclusion in an LLM prompt.
 - `max_enum_cardinality`: (Optional, default: `100`) Maximum number of distinct values to consider a column as an enum. Increase for columns like countries (190+), decrease for faster init times.
diff --git a/sam-sql-database-tool/pyproject.toml b/sam-sql-database-tool/pyproject.toml
index 7373c33..fe63474 100644
--- a/sam-sql-database-tool/pyproject.toml
+++ b/sam-sql-database-tool/pyproject.toml
@@ -10,7 +10,7 @@ type = "tool"
 
 [project]
 name = "sam_sql_database_tool"
-version = "0.1.1"
+version = "0.2.0"
 description = "A reusable SQL database tool for Solace Agent Mesh."
 readme = "README.md"
 authors = [
@@ -21,6 +21,7 @@ dependencies = [
     "SQLAlchemy==2.0.40",
     "PyMySQL==1.1.2",
     "psycopg2-binary==2.9.10",
+    "pyodbc==5.3.0",
     "pydantic==2.11.9",
     "PyYAML==6.0.2"
 ]
@@ -50,5 +51,5 @@ dependencies = [
     "solace-agent-mesh>=1.4.9,<2.0.0",
     "testing.postgresql",
     "psycopg2-binary",
-    "testcontainers[mysql,postgres]>=4.4.0"
+    "testcontainers[mysql,postgres,mssql]>=4.4.0"
 ]
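With `pyodbc` added to the dependencies, an MSSQL engine is built the same way as the existing PostgreSQL/MySQL paths. A minimal sketch of the README's recommended single-environment-variable pattern (the `DB_CONNECTION_STRING` name and the example URL are placeholders, not part of the plugin's API):

```python
# Sketch: create an engine from one env var, per the README guidance above.
# The env var name and example URL are illustrative only.
import os
from sqlalchemy import create_engine, text

# e.g. mssql+pyodbc://user:password@host:1433/dbname?driver=FreeTDS&TrustServerCertificate=yes
conn_str = os.environ["DB_CONNECTION_STRING"]

engine = create_engine(conn_str, pool_pre_ping=True)  # pre-ping guards stale pooled connections
with engine.connect() as conn:
    print(conn.execute(text("SELECT 1")).scalar())  # round-trip smoke test
```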
diff --git a/sam-sql-database-tool/src/sam_sql_database_tool/services/database_service.py b/sam-sql-database-tool/src/sam_sql_database_tool/services/database_service.py
index d4d4e10..79ad300 100644
--- a/sam-sql-database-tool/src/sam_sql_database_tool/services/database_service.py
+++ b/sam-sql-database-tool/src/sam_sql_database_tool/services/database_service.py
@@ -8,7 +8,7 @@
 from sqlalchemy.engine import Engine, Connection
 from sqlalchemy.exc import SQLAlchemyError
 
-from sqlalchemy import inspect, text, select, func, distinct, Table, MetaData
+from sqlalchemy import inspect, text, select, distinct, Table, MetaData
 import sqlalchemy as sa
 import yaml
 
@@ -247,6 +247,14 @@ def _get_approximate_row_count(self, table_name: str) -> Optional[int]:
                 "FROM pg_class "
                 "WHERE relname = :table_name"
             ).bindparams(table_name=table_name)
+        elif self.engine.dialect.name == 'mssql':
+            query = text(
+                "SELECT SUM(p.rows) AS estimate "
+                "FROM sys.tables t "
+                "INNER JOIN sys.partitions p ON t.object_id = p.object_id "
+                "WHERE t.name = :table_name "
+                "AND p.index_id IN (0, 1)"
+            ).bindparams(table_name=table_name)
         else:
             query = text(
                 "SELECT table_rows "
@@ -352,7 +360,7 @@ def _process_table(
 
             cardinality_ratio = cardinality / total_samples if total_samples > 0 else 1
 
-            is_string_type = col_type.upper().startswith(('VARCHAR', 'CHAR', 'TEXT', 'ENUM'))
+            is_string_type = col_type.upper().startswith(('VARCHAR', 'CHAR', 'TEXT', 'ENUM', 'NVARCHAR', 'NCHAR', 'NTEXT'))
             looks_like_enum = self._looks_like_enum_column(col_name)
             has_low_cardinality = cardinality_ratio < 0.3
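For reference, the new `sys.partitions` branch can be exercised standalone: summing partition rows for `index_id` 0 (heap) or 1 (clustered index) returns the optimizer's row estimate without scanning the table. A sketch assuming an existing SQLAlchemy `Engine` bound to an MSSQL database; the function name is illustrative:

```python
# Standalone sketch of the approximate row-count query added above.
# `engine` is assumed to be an MSSQL-backed SQLAlchemy Engine.
from typing import Optional

from sqlalchemy import text
from sqlalchemy.engine import Engine


def approx_row_count(engine: Engine, table_name: str) -> Optional[int]:
    # index_id 0 = heap, 1 = clustered index; together they cover the
    # base table's rows exactly once.
    query = text(
        "SELECT SUM(p.rows) AS estimate "
        "FROM sys.tables t "
        "INNER JOIN sys.partitions p ON t.object_id = p.object_id "
        "WHERE t.name = :table_name AND p.index_id IN (0, 1)"
    ).bindparams(table_name=table_name)
    with engine.connect() as conn:
        return conn.execute(query).scalar()  # None if the table does not exist
```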
diff --git a/sam-sql-database-tool/tests/conftest.py b/sam-sql-database-tool/tests/conftest.py
index 282b18a..da99a9a 100644
--- a/sam-sql-database-tool/tests/conftest.py
+++ b/sam-sql-database-tool/tests/conftest.py
@@ -2,6 +2,7 @@
 import sqlalchemy as sa
 from testcontainers.postgres import PostgresContainer
 from testcontainers.mysql import MySqlContainer
+from testcontainers.mssql import SqlServerContainer
 from dataclasses import dataclass
 from typing import Type, Callable
 
@@ -55,6 +56,32 @@ class DatabaseTestConfig:
             ),
             id="mariadb",
         ),
+        pytest.param(
+            DatabaseTestConfig(
+                name="mssql-msodbc18",
+                container_class=SqlServerContainer,
+                image="mcr.microsoft.com/mssql/server:2022-latest",
+                connection_url_fn=lambda c: (
+                    f"mssql+pyodbc://sa:{c.password}@"
+                    f"{c.get_container_host_ip()}:{c.get_exposed_port(1433)}/master"
+                    f"?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
+                ),
+            ),
+            id="mssql-msodbc18",
+        ),
+        pytest.param(
+            DatabaseTestConfig(
+                name="mssql-freetds",
+                container_class=SqlServerContainer,
+                image="mcr.microsoft.com/mssql/server:2022-latest",
+                connection_url_fn=lambda c: (
+                    f"mssql+pyodbc://sa:{c.password}@"
+                    f"{c.get_container_host_ip()}:{c.get_exposed_port(1433)}/master"
+                    f"?driver=FreeTDS&TrustServerCertificate=yes"
+                ),
+            ),
+            id="mssql-freetds",
+        ),
     ],
 )
 def db_config(request):
@@ -65,14 +92,20 @@
 
 @pytest.fixture(scope="session")
 def database_container(db_config: DatabaseTestConfig):
     """Starts and stops a Docker container for the configured database."""
-    with db_config.container_class(
-        db_config.image, dbname="test_db", username="test_user", password="test_password"
-    ) as container:
-        if db_config.name in ["mysql", "mariadb"]:
-            container.with_env("MYSQL_ROOT_PASSWORD", "root_password")
-        # Attach the config to the container object for easy access in other fixtures
-        container.db_config = db_config
-        yield container
+    if db_config.name.startswith("mssql"):
+        # MSSQL container has a different constructor signature
+        with db_config.container_class(db_config.image) as container:
+            container.db_config = db_config
+            yield container
+    else:
+        with db_config.container_class(
+            db_config.image, dbname="test_db", username="test_user", password="test_password"
+        ) as container:
+            if db_config.name in ["mysql", "mariadb"]:
+                container.with_env("MYSQL_ROOT_PASSWORD", "root_password")
+            # Attach the config to the container object for easy access in other fixtures
+            container.db_config = db_config
+            yield container
 
 @pytest.fixture(scope="function")
diff --git a/sam-sql-database-tool/tests/integration/test_config_features.py b/sam-sql-database-tool/tests/integration/test_config_features.py
index 1945a19..1bf745d 100644
--- a/sam-sql-database-tool/tests/integration/test_config_features.py
+++ b/sam-sql-database-tool/tests/integration/test_config_features.py
@@ -52,7 +52,7 @@ async def test_cache_ttl_seconds(self, db_tool_provider):
         # Re-initialize the tool with a short TTL
         tool_config_dict = db_tool_provider.tool_config.model_dump()
         tool_config_dict['cache_ttl_seconds'] = 2
-        
+
         custom_config = DatabaseConfig(**tool_config_dict)
         tool = SqlDatabaseTool(custom_config)
         await tool.init(component=None, tool_config={})
@@ -62,9 +62,18 @@
         assert "temp_col_for_ttl_test" not in initial_schema
 
         # 2. Directly alter the database schema
+        # MSSQL uses different ALTER TABLE syntax (no COLUMN keyword)
+        dialect = tool.db_service.engine.dialect.name
+        if dialect == 'mssql':
+            add_col_sql = "ALTER TABLE users ADD temp_col_for_ttl_test INT"
+            drop_col_sql = "ALTER TABLE users DROP COLUMN temp_col_for_ttl_test"
+        else:
+            add_col_sql = "ALTER TABLE users ADD COLUMN temp_col_for_ttl_test INT"
+            drop_col_sql = "ALTER TABLE users DROP COLUMN temp_col_for_ttl_test"
+
         conn = tool.db_service.engine.connect()
         try:
-            conn.execute(pytest.importorskip("sqlalchemy").text("ALTER TABLE users ADD COLUMN temp_col_for_ttl_test INT;"))
+            conn.execute(pytest.importorskip("sqlalchemy").text(add_col_sql))
             conn.commit()
 
             # 3. Wait for the cache TTL to expire
@@ -79,13 +88,13 @@
 
             # 6. Get the schema a third time; it should now be refreshed from the background job.
             refreshed_schema = tool.db_service.get_optimized_schema_for_llm()
-            
+
             # 7. Assert that the new column is now present
             assert "temp_col_for_ttl_test" in refreshed_schema, "Schema should be refreshed with the new column"
 
         finally:
             # Clean up the added column
-            conn.execute(pytest.importorskip("sqlalchemy").text("ALTER TABLE users DROP COLUMN temp_col_for_ttl_test;"))
+            conn.execute(pytest.importorskip("sqlalchemy").text(drop_col_sql))
             conn.commit()
             conn.close()
         await tool.cleanup(component=None, tool_config={})
@@ -99,9 +108,18 @@
         assert "temp_col_for_clear_test" not in initial_schema
 
         # 2. Directly alter the database schema
+        # MSSQL uses different ALTER TABLE syntax (no COLUMN keyword)
+        dialect = tool.db_service.engine.dialect.name
+        if dialect == 'mssql':
+            add_col_sql = "ALTER TABLE users ADD temp_col_for_clear_test INT"
+            drop_col_sql = "ALTER TABLE users DROP COLUMN temp_col_for_clear_test"
+        else:
+            add_col_sql = "ALTER TABLE users ADD COLUMN temp_col_for_clear_test INT"
+            drop_col_sql = "ALTER TABLE users DROP COLUMN temp_col_for_clear_test"
+
         conn = tool.db_service.engine.connect()
         try:
-            conn.execute(pytest.importorskip("sqlalchemy").text("ALTER TABLE users ADD COLUMN temp_col_for_clear_test INT;"))
+            conn.execute(pytest.importorskip("sqlalchemy").text(add_col_sql))
             conn.commit()
 
             # 3. Manually clear the cache
@@ -115,6 +133,6 @@
 
         finally:
             # Clean up the added column
-            conn.execute(pytest.importorskip("sqlalchemy").text("ALTER TABLE users DROP COLUMN temp_col_for_clear_test;"))
+            conn.execute(pytest.importorskip("sqlalchemy").text(drop_col_sql))
             conn.commit()
             conn.close()
diff --git a/sam-sql-database-tool/tests/integration/test_sql_tool.py b/sam-sql-database-tool/tests/integration/test_sql_tool.py
index 7922aef..feb7e6a 100644
--- a/sam-sql-database-tool/tests/integration/test_sql_tool.py
+++ b/sam-sql-database-tool/tests/integration/test_sql_tool.py
@@ -90,7 +90,7 @@ async def test_cache_hit_performance(self, db_tool_provider: SqlDatabaseTool):
 
         first_call_start = time.time()
         schema1 = db_tool_provider.db_service.get_optimized_schema_for_llm()
-        first_call_time = time.time() - first_call_start
+        _first_call_time = time.time() - first_call_start
 
         second_call_start = time.time()
         schema2 = db_tool_provider.db_service.get_optimized_schema_for_llm()
@@ -140,10 +140,12 @@ async def test_multi_table_join_query(self, db_tool_provider: SqlDatabaseTool):
 
     async def test_aggregation_with_join(self, db_tool_provider: SqlDatabaseTool):
         """Test a query that uses aggregation across joined tables."""
+        # Cast rating to Float to ensure consistent decimal results across databases
+        # (MSSQL returns integer for AVG of integer columns)
         query = (
             sa.select(
                 categories.c.name,
-                sa.func.avg(reviews.c.rating).label("average_rating")
+                sa.func.avg(sa.cast(reviews.c.rating, sa.Float)).label("average_rating")
             )
             .join(product_categories, categories.c.id == product_categories.c.category_id)
             .join(products, product_categories.c.product_id == products.c.id)
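The cast is needed because SQL Server's `AVG` keeps its argument's type, so `AVG` over an `INT` column truncates to an integer, while PostgreSQL and MySQL return a decimal; casting the operand first pins one behavior everywhere. A self-contained sketch of the pattern, using an in-memory SQLite table purely as a stand-in backend:

```python
# Sketch of the AVG-truncation fix: cast to Float before averaging so every
# dialect returns a decimal. SQLite here is only for illustration.
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
meta = sa.MetaData()
ratings = sa.Table("ratings", meta, sa.Column("rating", sa.Integer))
meta.create_all(engine)

with engine.begin() as conn:
    conn.execute(ratings.insert(), [{"rating": 4}, {"rating": 5}])
    avg = conn.execute(
        sa.select(sa.func.avg(sa.cast(ratings.c.rating, sa.Float)))
    ).scalar()
    print(avg)  # 4.5 everywhere; an uncast AVG(rating) would yield 4 on MSSQL
```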
diff --git a/sam-sql-database-tool/tests/test_data.py b/sam-sql-database-tool/tests/test_data.py
index 7a883ba..577dffe 100644
--- a/sam-sql-database-tool/tests/test_data.py
+++ b/sam-sql-database-tool/tests/test_data.py
@@ -4,20 +4,21 @@
 metadata = sa.MetaData()
 
 # Define tables using SQLAlchemy's generic types
+# Note: autoincrement=False is required to allow explicit ID inserts on MSSQL
 users = sa.Table('users', metadata,
-    sa.Column('id', sa.Integer, primary_key=True),
+    sa.Column('id', sa.Integer, primary_key=True, autoincrement=False),
     sa.Column('name', sa.String(100)),
     sa.Column('email', sa.String(100), unique=True),
     sa.Column('created_at', sa.DateTime)
 )
 
 categories = sa.Table('categories', metadata,
-    sa.Column('id', sa.Integer, primary_key=True),
+    sa.Column('id', sa.Integer, primary_key=True, autoincrement=False),
     sa.Column('name', sa.String(100), unique=True)
 )
 
 products = sa.Table('products', metadata,
-    sa.Column('id', sa.Integer, primary_key=True),
+    sa.Column('id', sa.Integer, primary_key=True, autoincrement=False),
     sa.Column('name', sa.String(100)),
     sa.Column('description', sa.Text),
     sa.Column('price', sa.Float)
@@ -29,14 +30,14 @@
 )
 
 orders = sa.Table('orders', metadata,
-    sa.Column('id', sa.Integer, primary_key=True),
+    sa.Column('id', sa.Integer, primary_key=True, autoincrement=False),
     sa.Column('user_id', sa.Integer, sa.ForeignKey('users.id', ondelete='CASCADE')),
     sa.Column('order_date', sa.Date),
     sa.Column('status', sa.String(50))
 )
 
 order_items = sa.Table('order_items', metadata,
-    sa.Column('id', sa.Integer, primary_key=True),
+    sa.Column('id', sa.Integer, primary_key=True, autoincrement=False),
     sa.Column('order_id', sa.Integer, sa.ForeignKey('orders.id', ondelete='CASCADE')),
     sa.Column('product_id', sa.Integer, sa.ForeignKey('products.id', ondelete='CASCADE')),
     sa.Column('quantity', sa.Integer),
@@ -44,7 +45,7 @@
 )
 
 reviews = sa.Table('reviews', metadata,
-    sa.Column('id', sa.Integer, primary_key=True),
+    sa.Column('id', sa.Integer, primary_key=True, autoincrement=False),
    sa.Column('product_id', sa.Integer, sa.ForeignKey('products.id', ondelete='CASCADE')),
     sa.Column('user_id', sa.Integer, sa.ForeignKey('users.id', ondelete='CASCADE')),
     sa.Column('rating', sa.Integer),
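Context on the `autoincrement=False` change: on MSSQL an `Integer` primary key compiles to an `IDENTITY` column by default, and inserting explicit IDs into an identity column fails unless `SET IDENTITY_INSERT ... ON` is issued first. Disabling autoincrement emits a plain `INT` primary key instead, which the fixtures' explicit-ID inserts rely on. A sketch comparing the two compiled DDL forms (no database required; the table names are illustrative):

```python
# Sketch: compare MSSQL DDL for the default vs. autoincrement=False columns.
import sqlalchemy as sa
from sqlalchemy.dialects import mssql
from sqlalchemy.schema import CreateTable

meta = sa.MetaData()

# Default: compiles with IDENTITY on mssql, so explicit-ID inserts fail
# unless IDENTITY_INSERT is toggled on for the table.
with_identity = sa.Table("demo_identity", meta,
    sa.Column("id", sa.Integer, primary_key=True),
)

# Test-data form: a plain INT primary key that accepts explicit IDs.
plain_pk = sa.Table("demo_plain", meta,
    sa.Column("id", sa.Integer, primary_key=True, autoincrement=False),
)

print(CreateTable(with_identity).compile(dialect=mssql.dialect()))
print(CreateTable(plain_pk).compile(dialect=mssql.dialect()))
```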