Skip to content

feat(ibis): Added BE Support for MySQL SSL Connection #1024

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from abc import ABC
from enum import Enum

from pydantic import BaseModel, Field, SecretStr
from starlette.status import (
Expand Down Expand Up @@ -99,6 +100,8 @@ class MySqlConnectionInfo(BaseModel):
database: SecretStr
user: SecretStr
password: SecretStr
ssl_mode: SecretStr | None = Field(alias="sslMode", default=None)
ssl_ca: SecretStr | None = Field(alias="sslCA", default=None)
kwargs: dict[str, str] | None = Field(
description="Additional keyword arguments to pass to PyMySQL", default=None
)
Expand Down Expand Up @@ -207,3 +210,9 @@ class UnprocessableEntityError(CustomHttpError):

class NotFoundError(CustomHttpError):
status_code = HTTP_404_NOT_FOUND


class SSLMode(str, Enum):
DISABLED = "disabled"
ENABLED = "enabled"
VERIFY_CA = "verify_ca"
40 changes: 37 additions & 3 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import base64
import ssl
from enum import Enum, StrEnum, auto
from json import loads
from typing import Optional

import ibis
from google.oauth2 import service_account
Expand All @@ -27,6 +29,7 @@
QuerySnowflakeDTO,
QueryTrinoDTO,
SnowflakeConnectionInfo,
SSLMode,
TrinoConnectionInfo,
)

Expand Down Expand Up @@ -130,15 +133,19 @@ def get_mssql_connection(cls, info: MSSqlConnectionInfo) -> BaseBackend:
**info.kwargs if info.kwargs else dict(),
)

@staticmethod
def get_mysql_connection(info: MySqlConnectionInfo) -> BaseBackend:
@classmethod
def get_mysql_connection(cls, info: MySqlConnectionInfo) -> BaseBackend:
ssl_context = cls._create_ssl_context(info)
kwargs = {"ssl": ssl_context} if ssl_context else {}
if info.kwargs:
kwargs.update(info.kwargs)
return ibis.mysql.connect(
host=info.host.get_secret_value(),
port=int(info.port.get_secret_value()),
database=info.database.get_secret_value(),
user=info.user.get_secret_value(),
password=info.password.get_secret_value(),
**info.kwargs if info.kwargs else dict(),
**kwargs,
)

@staticmethod
Expand Down Expand Up @@ -175,3 +182,30 @@ def get_trino_connection(info: TrinoConnectionInfo) -> BaseBackend:
@staticmethod
def _escape_special_characters_for_odbc(value: str) -> str:
return "{" + value.replace("}", "}}") + "}"

@staticmethod
def _create_ssl_context(info: ConnectionInfo) -> Optional[ssl.SSLContext]:
ssl_mode = (
info.ssl_mode.get_secret_value() if hasattr(info, "ssl_mode") else None
)

if ssl_mode == SSLMode.VERIFY_CA and not info.ssl_ca:
raise ValueError("SSL CA must be provided when SSL mode is VERIFY CA")

if not ssl_mode or ssl_mode == SSLMode.DISABLED:
return None

ctx = ssl.create_default_context()
ctx.check_hostname = False

if ssl_mode == SSLMode.ENABLED:
ctx.verify_mode = ssl.CERT_NONE
elif ssl_mode == SSLMode.VERIFY_CA:
ctx.verify_mode = ssl.CERT_REQUIRED
ctx.load_verify_locations(
cadata=base64.b64decode(info.ssl_ca.get_secret_value()).decode("utf-8")
if info.ssl_ca
else None
)
Comment on lines +205 to +209
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add error handling for base64 decoding.

The base64 decoding of SSL CA certificate should be wrapped in a try-except block to handle potential decoding errors gracefully.

-            ctx.load_verify_locations(
-                cadata=base64.b64decode(info.ssl_ca.get_secret_value()).decode("utf-8")
-                if info.ssl_ca
-                else None
-            )
+            if info.ssl_ca:
+                try:
+                    ca_data = base64.b64decode(info.ssl_ca.get_secret_value()).decode("utf-8")
+                    ctx.load_verify_locations(cadata=ca_data)
+                except (base64.binascii.Error, UnicodeDecodeError) as e:
+                    raise ValueError(f"Invalid SSL CA certificate: {str(e)}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
ctx.load_verify_locations(
cadata=base64.b64decode(info.ssl_ca.get_secret_value()).decode("utf-8")
if info.ssl_ca
else None
)
if info.ssl_ca:
try:
ca_data = base64.b64decode(info.ssl_ca.get_secret_value()).decode("utf-8")
ctx.load_verify_locations(cadata=ca_data)
except (base64.binascii.Error, UnicodeDecodeError) as e:
raise ValueError(f"Invalid SSL CA certificate: {str(e)}")


return ctx
45 changes: 45 additions & 0 deletions ibis-server/tests/routers/v2/connector/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import pandas as pd
import pytest
import sqlalchemy
from pymysql import OperationalError
from sqlalchemy import text
from testcontainers.mysql import MySqlContainer

from app.model import SSLMode
from app.model.validator import rules
from tests.conftest import file_path

Expand Down Expand Up @@ -111,6 +113,13 @@ def mysql(request) -> MySqlContainer:
return mysql


@pytest.fixture(scope="module")
def mysql_ssl_off(request) -> MySqlContainer:
mysql = MySqlContainer(image="mysql:8.0.40").with_command("--ssl=0").start()
request.addfinalizer(mysql.stop)
return mysql


async def test_query(client, manifest_str, mysql: MySqlContainer):
connection_info = _to_connection_info(mysql)
response = await client.post(
Expand Down Expand Up @@ -405,6 +414,42 @@ async def test_metadata_db_version(client, mysql: MySqlContainer):
assert response.text == '"8.0.40"'


@pytest.mark.parametrize(
"ssl_mode, expected_exception, expected_error",
[
(SSLMode.ENABLED, OperationalError, "Bad handshake"),
(
SSLMode.VERIFY_CA,
ValueError,
"SSL CA must be provided when SSL mode is VERIFY CA",
),
],
)
async def test_connection_invalid_ssl_mode(
client, mysql_ssl_off: MySqlContainer, ssl_mode, expected_exception, expected_error
):
connection_info = _to_connection_info(mysql_ssl_off)
connection_info["sslMode"] = ssl_mode

with pytest.raises(expected_exception) as excinfo:
await client.post(
url=f"{base_url}/metadata/version",
json={"connectionInfo": connection_info},
)
assert expected_error in str(excinfo.value)


async def test_connection_valid_ssl_mode(client, mysql_ssl_off: MySqlContainer):
connection_info = _to_connection_info(mysql_ssl_off)
connection_info["sslMode"] = SSLMode.DISABLED
response = await client.post(
url=f"{base_url}/metadata/version",
json={"connectionInfo": connection_info},
)
assert response.status_code == 200
assert response.text == '"8.0.40"'


def _to_connection_info(mysql: MySqlContainer):
return {
"host": mysql.get_container_host_ip(),
Expand Down
Loading