diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index b38645d8..a22f4c8e 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -1,15 +1,15 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -// INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be -// taken up in future - -#include "connection.h" -#include "connection_pool.h" -#include -#include -#include +#include "connection/connection.h" +#include "connection/connection_pool.h" #include +#include +#include +#include +#include +#include +#include #define SQL_COPT_SS_ACCESS_TOKEN 1256 // Custom attribute ID for access token #define SQL_MAX_SMALL_INT 32767 // Maximum value for SQLSMALLINT @@ -22,15 +22,18 @@ static SqlHandlePtr getEnvHandle() { DriverLoader::getInstance().loadDriver(); } SQLHANDLE env = nullptr; - SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env); + SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_ENV, SQL_NULL_HANDLE, + &env); if (!SQL_SUCCEEDED(ret)) { ThrowStdException("Failed to allocate environment handle"); } - ret = SQLSetEnvAttr_ptr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3_80, 0); + ret = SQLSetEnvAttr_ptr(env, SQL_ATTR_ODBC_VERSION, + reinterpret_cast(SQL_OV_ODBC3_80), 0); if (!SQL_SUCCEEDED(ret)) { ThrowStdException("Failed to set environment attributes"); } - return std::make_shared(static_cast(SQL_HANDLE_ENV), env); + return std::make_shared( + static_cast(SQL_HANDLE_ENV), env); }(); return envHandle; @@ -47,7 +50,7 @@ Connection::Connection(const std::wstring& conn_str, bool use_pool) } Connection::~Connection() { - disconnect(); // fallback if user forgets to disconnect + disconnect(); // fallback if user forgets to disconnect } // Allocates connection handle @@ -55,9 +58,11 @@ void Connection::allocateDbcHandle() { auto _envHandle = getEnvHandle(); SQLHANDLE dbc = nullptr; LOG("Allocate SQL Connection Handle"); - SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_DBC, _envHandle->get(), &dbc); + SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_DBC, _envHandle->get(), + &dbc); checkError(ret); - _dbcHandle = std::make_shared(static_cast(SQL_HANDLE_DBC), dbc); + _dbcHandle = std::make_shared( + static_cast(SQL_HANDLE_DBC), dbc); } void Connection::connect(const py::dict& attrs_before) { @@ -71,7 +76,7 @@ void Connection::connect(const py::dict& attrs_before) { } } SQLWCHAR* connStrPtr; -#if defined(__APPLE__) || defined(__linux__) // macOS/Linux specific handling +#if defined(__APPLE__) || defined(__linux__) // macOS/Linux handling LOG("Creating connection string buffer for macOS/Linux"); std::vector connStrBuffer = WStringToSQLWCHAR(_connStr); // Ensure the buffer is null-terminated @@ -94,15 +99,16 @@ void Connection::disconnect() { LOG("Disconnecting from database"); SQLRETURN ret = SQLDisconnect_ptr(_dbcHandle->get()); checkError(ret); - _dbcHandle.reset(); // triggers SQLFreeHandle via destructor, if last owner - } - else { + // triggers SQLFreeHandle via destructor, if last owner + _dbcHandle.reset(); + } else { LOG("No connection handle to disconnect"); } } -// TODO: Add an exception class in C++ for error handling, DB spec compliant -void Connection::checkError(SQLRETURN ret) const{ +// TODO(microsoft): Add an exception class in C++ for error handling, +// DB spec compliant +void Connection::checkError(SQLRETURN ret) const { if (!SQL_SUCCEEDED(ret)) { ErrorInfo err = SQLCheckError_Wrap(SQL_HANDLE_DBC, _dbcHandle, ret); std::string errorMsg = WideToUTF8(err.ddbcErrorMsg); @@ -116,7 +122,8 @@ void Connection::commit() { } updateLastUsed(); LOG("Committing transaction"); - SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_COMMIT); + SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), + SQL_COMMIT); checkError(ret); } @@ -126,7 +133,8 @@ void Connection::rollback() { } updateLastUsed(); LOG("Rolling back transaction"); - SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_ROLLBACK); + SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), + SQL_ROLLBACK); checkError(ret); } @@ -136,9 +144,11 @@ void Connection::setAutocommit(bool enable) { } SQLINTEGER value = enable ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF; LOG("Setting SQL Connection Attribute"); - SQLRETURN ret = SQLSetConnectAttr_ptr(_dbcHandle->get(), SQL_ATTR_AUTOCOMMIT, reinterpret_cast(static_cast(value)), 0); + SQLRETURN ret = SQLSetConnectAttr_ptr( + _dbcHandle->get(), SQL_ATTR_AUTOCOMMIT, + reinterpret_cast(static_cast(value)), 0); checkError(ret); - if(value == SQL_AUTOCOMMIT_ON) { + if (value == SQL_AUTOCOMMIT_ON) { LOG("SQL Autocommit set to True"); } else { LOG("SQL Autocommit set to False"); @@ -153,7 +163,9 @@ bool Connection::getAutocommit() const { LOG("Get SQL Connection Attribute"); SQLINTEGER value; SQLINTEGER string_length; - SQLRETURN ret = SQLGetConnectAttr_ptr(_dbcHandle->get(), SQL_ATTR_AUTOCOMMIT, &value, sizeof(value), &string_length); + SQLRETURN ret = SQLGetConnectAttr_ptr(_dbcHandle->get(), + SQL_ATTR_AUTOCOMMIT, &value, + sizeof(value), &string_length); checkError(ret); return value == SQL_AUTOCOMMIT_ON; } @@ -165,59 +177,61 @@ SqlHandlePtr Connection::allocStatementHandle() { updateLastUsed(); LOG("Allocating statement handle"); SQLHANDLE stmt = nullptr; - SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbcHandle->get(), &stmt); + SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbcHandle->get(), + &stmt); checkError(ret); - return std::make_shared(static_cast(SQL_HANDLE_STMT), stmt); + return std::make_shared( + static_cast(SQL_HANDLE_STMT), stmt); } SQLRETURN Connection::setAttribute(SQLINTEGER attribute, py::object value) { LOG("Setting SQL attribute"); - //SQLPOINTER ptr = nullptr; - //SQLINTEGER length = 0; - static std::string buffer; // to hold sensitive data temporarily - + // SQLPOINTER ptr = nullptr; + // SQLINTEGER length = 0; + static std::string buffer; // to hold sensitive data temporarily + if (py::isinstance(value)) { // Get the integer value - long long longValue = value.cast(); - + int64_t longValue = value.cast(); + SQLRETURN ret = SQLSetConnectAttr_ptr( - _dbcHandle->get(), - attribute, - (SQLPOINTER)(SQLULEN)longValue, + _dbcHandle->get(), + attribute, + reinterpret_cast(static_cast(longValue)), SQL_IS_INTEGER); - + if (!SQL_SUCCEEDED(ret)) { LOG("Failed to set attribute"); - } - else { + } else { LOG("Set attribute successfully"); } return ret; - } - else if (py::isinstance(value)) { + } else if (py::isinstance(value)) { try { - static std::vector wstr_buffers; // Keep buffers alive + // Keep buffers alive + static std::vector wstr_buffers; std::string utf8_str = value.cast(); - + // Convert to wide string std::wstring wstr = Utf8ToWString(utf8_str); if (wstr.empty() && !utf8_str.empty()) { LOG("Failed to convert string value to wide string"); return SQL_ERROR; } - + // Limit static buffer growth for memory safety constexpr size_t MAX_BUFFER_COUNT = 100; if (wstr_buffers.size() >= MAX_BUFFER_COUNT) { - // Remove oldest 50% of entries when limit reached - wstr_buffers.erase(wstr_buffers.begin(), wstr_buffers.begin() + (MAX_BUFFER_COUNT / 2)); + // Remove oldest 50% of entries when limit reached + wstr_buffers.erase(wstr_buffers.begin(), + wstr_buffers.begin() + (MAX_BUFFER_COUNT / 2)); } - + wstr_buffers.push_back(wstr); - + SQLPOINTER ptr; SQLINTEGER length; - + #if defined(__APPLE__) || defined(__linux__) // For macOS/Linux, convert wstring to SQLWCHAR buffer std::vector sqlwcharBuffer = WStringToSQLWCHAR(wstr); @@ -225,60 +239,63 @@ SQLRETURN Connection::setAttribute(SQLINTEGER attribute, py::object value) { LOG("Failed to convert wide string to SQLWCHAR buffer"); return SQL_ERROR; } - + ptr = sqlwcharBuffer.data(); - length = static_cast(sqlwcharBuffer.size() * sizeof(SQLWCHAR)); + length = static_cast( + sqlwcharBuffer.size() * sizeof(SQLWCHAR)); #else // On Windows, wchar_t and SQLWCHAR are the same size ptr = const_cast(wstr_buffers.back().c_str()); - length = static_cast(wstr.length() * sizeof(SQLWCHAR)); + length = static_cast( + wstr.length() * sizeof(SQLWCHAR)); #endif - - SQLRETURN ret = SQLSetConnectAttr_ptr(_dbcHandle->get(), attribute, ptr, length); + + SQLRETURN ret = SQLSetConnectAttr_ptr(_dbcHandle->get(), + attribute, ptr, length); if (!SQL_SUCCEEDED(ret)) { LOG("Failed to set string attribute"); - } - else { + } else { LOG("Set string attribute successfully"); } return ret; - } - catch (const std::exception& e) { - LOG("Exception during string attribute setting: " + std::string(e.what())); + } catch (const std::exception& e) { + LOG("Exception during string attribute setting: " + + std::string(e.what())); return SQL_ERROR; } - } - else if (py::isinstance(value) || py::isinstance(value)) { + } else if (py::isinstance(value) || + py::isinstance(value)) { try { static std::vector buffers; std::string binary_data = value.cast(); - + // Limit static buffer growth constexpr size_t MAX_BUFFER_COUNT = 100; if (buffers.size() >= MAX_BUFFER_COUNT) { - // Remove oldest 50% of entries when limit reached - buffers.erase(buffers.begin(), buffers.begin() + (MAX_BUFFER_COUNT / 2)); + // Remove oldest 50% of entries when limit reached + buffers.erase(buffers.begin(), + buffers.begin() + (MAX_BUFFER_COUNT / 2)); } - + buffers.emplace_back(std::move(binary_data)); SQLPOINTER ptr = const_cast(buffers.back().c_str()); - SQLINTEGER length = static_cast(buffers.back().size()); - - SQLRETURN ret = SQLSetConnectAttr_ptr(_dbcHandle->get(), attribute, ptr, length); + SQLINTEGER length = static_cast( + buffers.back().size()); + + SQLRETURN ret = SQLSetConnectAttr_ptr(_dbcHandle->get(), + attribute, ptr, length); if (!SQL_SUCCEEDED(ret)) { LOG("Failed to set attribute with binary data"); - } - else { + } else { LOG("Set attribute successfully with binary data"); } return ret; - } - catch (const std::exception& e) { - LOG("Exception during binary attribute setting: " + std::string(e.what())); + } catch (const std::exception& e) { + LOG("Exception during binary attribute setting: " + + std::string(e.what())); return SQL_ERROR; } - } - else { + } else { LOG("Unsupported attribute value type"); return SQL_ERROR; } @@ -294,10 +311,12 @@ void Connection::applyAttrsBefore(const py::dict& attrs) { } // Apply all supported attributes - SQLRETURN ret = setAttribute(key, py::reinterpret_borrow(item.second)); + SQLRETURN ret = setAttribute( + key, py::reinterpret_borrow(item.second)); if (!SQL_SUCCEEDED(ret)) { std::string attrName = std::to_string(key); - std::string errorMsg = "Failed to set attribute " + attrName + " before connect"; + std::string errorMsg = "Failed to set attribute " + attrName + + " before connect"; ThrowStdException(errorMsg); } } @@ -308,8 +327,9 @@ bool Connection::isAlive() const { ThrowStdException("Connection handle not allocated"); } SQLUINTEGER status; - SQLRETURN ret = SQLGetConnectAttr_ptr(_dbcHandle->get(), SQL_ATTR_CONNECTION_DEAD, - &status, 0, nullptr); + SQLRETURN ret = SQLGetConnectAttr_ptr(_dbcHandle->get(), + SQL_ATTR_CONNECTION_DEAD, + &status, 0, nullptr); return SQL_SUCCEEDED(ret) && status == SQL_CD_FALSE; } @@ -340,11 +360,14 @@ std::chrono::steady_clock::time_point Connection::lastUsed() const { return _lastUsed; } -ConnectionHandle::ConnectionHandle(const std::string& connStr, bool usePool, const py::dict& attrsBefore) +ConnectionHandle::ConnectionHandle(const std::string& connStr, + bool usePool, + const py::dict& attrsBefore) : _usePool(usePool) { _connStr = Utf8ToWString(connStr); if (_usePool) { - _conn = ConnectionPoolManager::getInstance().acquireConnection(_connStr, attrsBefore); + _conn = ConnectionPoolManager::getInstance().acquireConnection( + _connStr, attrsBefore); } else { _conn = std::make_shared(_connStr, false); _conn->connect(attrsBefore); @@ -408,16 +431,17 @@ py::object Connection::getInfo(SQLUSMALLINT infoType) const { if (!_dbcHandle) { ThrowStdException("Connection handle not allocated"); } - + // First call with NULL buffer to get required length SQLSMALLINT requiredLen = 0; - SQLRETURN ret = SQLGetInfo_ptr(_dbcHandle->get(), infoType, NULL, 0, &requiredLen); - + SQLRETURN ret = SQLGetInfo_ptr(_dbcHandle->get(), infoType, NULL, 0, + &requiredLen); + if (!SQL_SUCCEEDED(ret)) { checkError(ret); return py::none(); } - + // For zero-length results if (requiredLen == 0) { py::dict result; @@ -426,37 +450,39 @@ py::object Connection::getInfo(SQLUSMALLINT infoType) const { result["info_type"] = infoType; return result; } - - // Cap buffer allocation to SQL_MAX_SMALL_INT to prevent excessive memory usage + + // Cap buffer allocation to SQL_MAX_SMALL_INT to prevent excessive + // memory usage SQLSMALLINT allocSize = requiredLen + 10; if (allocSize > SQL_MAX_SMALL_INT) { allocSize = SQL_MAX_SMALL_INT; } std::vector buffer(allocSize, 0); // Extra padding for safety - + // Get the actual data - avoid using std::min SQLSMALLINT bufferSize = requiredLen + 10; if (bufferSize > SQL_MAX_SMALL_INT) { bufferSize = SQL_MAX_SMALL_INT; } - + SQLSMALLINT returnedLen = 0; - ret = SQLGetInfo_ptr(_dbcHandle->get(), infoType, buffer.data(), bufferSize, &returnedLen); - + ret = SQLGetInfo_ptr(_dbcHandle->get(), infoType, buffer.data(), + bufferSize, &returnedLen); + if (!SQL_SUCCEEDED(ret)) { checkError(ret); return py::none(); } - + // Create a dictionary with the raw data py::dict result; - + // IMPORTANT: Pass exactly what SQLGetInfo returned // No null-terminator manipulation, just pass the raw data result["data"] = py::bytes(buffer.data(), returnedLen); result["length"] = returnedLen; result["info_type"] = infoType; - + return result; } @@ -471,26 +497,31 @@ void ConnectionHandle::setAttr(int attribute, py::object value) { if (!_conn) { ThrowStdException("Connection not established"); } - + // Use existing setAttribute with better error handling - SQLRETURN ret = _conn->setAttribute(static_cast(attribute), value); + SQLRETURN ret = _conn->setAttribute( + static_cast(attribute), value); if (!SQL_SUCCEEDED(ret)) { // Get detailed error information from ODBC try { - ErrorInfo errorInfo = SQLCheckError_Wrap(SQL_HANDLE_DBC, _conn->getDbcHandle(), ret); - - std::string errorMsg = "Failed to set connection attribute " + std::to_string(attribute); + ErrorInfo errorInfo = SQLCheckError_Wrap( + SQL_HANDLE_DBC, _conn->getDbcHandle(), ret); + + std::string errorMsg = "Failed to set connection attribute " + + std::to_string(attribute); if (!errorInfo.ddbcErrorMsg.empty()) { // Convert wstring to string for concatenation - std::string ddbcErrorStr = WideToUTF8(errorInfo.ddbcErrorMsg); + std::string ddbcErrorStr = WideToUTF8( + errorInfo.ddbcErrorMsg); errorMsg += ": " + ddbcErrorStr; } - + LOG("Connection setAttribute failed: {}", errorMsg); ThrowStdException(errorMsg); } catch (...) { // Fallback to generic error if detailed error retrieval fails - std::string errorMsg = "Failed to set connection attribute " + std::to_string(attribute); + std::string errorMsg = "Failed to set connection attribute " + + std::to_string(attribute); LOG("Connection setAttribute failed: {}", errorMsg); ThrowStdException(errorMsg); } diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index 68c2a216..885be39a 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -1,18 +1,17 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -// INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be -// taken up in future. - #pragma once -#include "ddbc_bindings.h" +#include +#include +#include "../ddbc_bindings.h" // Represents a single ODBC database connection. // Manages connection handles. // Note: This class does NOT implement pooling logic directly. class Connection { -public: + public: Connection(const std::wstring& connStr, bool fromPool); ~Connection(); @@ -50,7 +49,7 @@ class Connection { // Add getter for DBC handle for error reporting const SqlHandlePtr& getDbcHandle() const { return _dbcHandle; } -private: + private: void allocateDbcHandle(); void checkError(SQLRETURN ret) const; void applyAttrsBefore(const py::dict& attrs_before); @@ -63,8 +62,9 @@ class Connection { }; class ConnectionHandle { -public: - ConnectionHandle(const std::string& connStr, bool usePool, const py::dict& attrsBefore = py::dict()); + public: + ConnectionHandle(const std::string& connStr, bool usePool, + const py::dict& attrsBefore = py::dict()); ~ConnectionHandle(); void close(); @@ -78,8 +78,8 @@ class ConnectionHandle { // Get information about the driver and data source py::object getInfo(SQLUSMALLINT infoType) const; -private: + private: std::shared_ptr _conn; bool _usePool; std::wstring _connStr; -}; \ No newline at end of file +}; diff --git a/mssql_python/pybind/connection/connection_pool.cpp b/mssql_python/pybind/connection/connection_pool.cpp index 60dd5415..cc2c4825 100644 --- a/mssql_python/pybind/connection/connection_pool.cpp +++ b/mssql_python/pybind/connection/connection_pool.cpp @@ -1,16 +1,17 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -// INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be -// taken up in future. - -#include "connection_pool.h" +#include "connection/connection_pool.h" #include +#include +#include ConnectionPool::ConnectionPool(size_t max_size, int idle_timeout_secs) - : _max_size(max_size), _idle_timeout_secs(idle_timeout_secs), _current_size(0) {} + : _max_size(max_size), _idle_timeout_secs(idle_timeout_secs), + _current_size(0) {} -std::shared_ptr ConnectionPool::acquire(const std::wstring& connStr, const py::dict& attrs_before) { +std::shared_ptr ConnectionPool::acquire( + const std::wstring& connStr, const py::dict& attrs_before) { std::vector> to_disconnect; std::shared_ptr valid_conn = nullptr; { @@ -21,7 +22,8 @@ std::shared_ptr ConnectionPool::acquire(const std::wstring& connStr, // Phase 1: Remove stale connections, collect for later disconnect _pool.erase(std::remove_if(_pool.begin(), _pool.end(), [&](const std::shared_ptr& conn) { - auto idle_time = std::chrono::duration_cast(now - conn->lastUsed()).count(); + auto idle_time = std::chrono::duration_cast< + std::chrono::seconds>(now - conn->lastUsed()).count(); if (idle_time > _idle_timeout_secs) { to_disconnect.push_back(conn); return true; @@ -30,7 +32,8 @@ std::shared_ptr ConnectionPool::acquire(const std::wstring& connStr, }), _pool.end()); size_t pruned = before - _pool.size(); - _current_size = (_current_size >= pruned) ? (_current_size - pruned) : 0; + _current_size = (_current_size >= pruned) ? + (_current_size - pruned) : 0; // Phase 2: Attempt to reuse healthy connections while (!_pool.empty()) { @@ -56,7 +59,8 @@ std::shared_ptr ConnectionPool::acquire(const std::wstring& connStr, valid_conn->connect(attrs_before); ++_current_size; } else if (!valid_conn) { - throw std::runtime_error("ConnectionPool::acquire: pool size limit reached"); + throw std::runtime_error( + "ConnectionPool::acquire: pool size limit reached"); } } @@ -76,8 +80,7 @@ void ConnectionPool::release(std::shared_ptr conn) { if (_pool.size() < _max_size) { conn->updateLastUsed(); _pool.push_back(conn); - } - else { + } else { conn->disconnect(); if (_current_size > 0) --_current_size; } @@ -97,7 +100,8 @@ void ConnectionPool::close() { try { conn->disconnect(); } catch (const std::exception& ex) { - LOG("ConnectionPool::close: disconnect failed: {}", ex.what()); + LOG("ConnectionPool::close: disconnect failed: {}", + ex.what()); } } } @@ -107,18 +111,21 @@ ConnectionPoolManager& ConnectionPoolManager::getInstance() { return manager; } -std::shared_ptr ConnectionPoolManager::acquireConnection(const std::wstring& connStr, const py::dict& attrs_before) { +std::shared_ptr ConnectionPoolManager::acquireConnection( + const std::wstring& connStr, const py::dict& attrs_before) { std::lock_guard lock(_manager_mutex); auto& pool = _pools[connStr]; if (!pool) { LOG("Creating new connection pool"); - pool = std::make_shared(_default_max_size, _default_idle_secs); + pool = std::make_shared(_default_max_size, + _default_idle_secs); } return pool->acquire(connStr, attrs_before); } -void ConnectionPoolManager::returnConnection(const std::wstring& conn_str, const std::shared_ptr conn) { +void ConnectionPoolManager::returnConnection( + const std::wstring& conn_str, const std::shared_ptr conn) { std::lock_guard lock(_manager_mutex); if (_pools.find(conn_str) != _pools.end()) { _pools[conn_str]->release((conn)); diff --git a/mssql_python/pybind/connection/connection_pool.h b/mssql_python/pybind/connection/connection_pool.h index dc2de5a8..7e1c315a 100644 --- a/mssql_python/pybind/connection/connection_pool.h +++ b/mssql_python/pybind/connection/connection_pool.h @@ -1,25 +1,28 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -// INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be -// taken up in future. +#ifndef MSSQL_PYTHON_CONNECTION_POOL_H_ +#define MSSQL_PYTHON_CONNECTION_POOL_H_ #pragma once +#include #include -#include #include #include #include -#include -#include "connection.h" +#include +#include "connection/connection.h" -// Manages a fixed-size pool of reusable database connections for a single connection string +// Manages a fixed-size pool of reusable database connections for a +// single connection string class ConnectionPool { -public: + public: ConnectionPool(size_t max_size, int idle_timeout_secs); // Acquires a connection from the pool or creates a new one if under limit - std::shared_ptr acquire(const std::wstring& connStr, const py::dict& attrs_before = py::dict()); + std::shared_ptr acquire( + const std::wstring& connStr, + const py::dict& attrs_before = py::dict()); // Returns a connection to the pool for reuse void release(std::shared_ptr conn); @@ -27,33 +30,36 @@ class ConnectionPool { // Closes all connections in the pool, releasing resources void close(); -private: - size_t _max_size; // Maximum number of connections allowed - int _idle_timeout_secs; // Idle time before connections are considered stale + private: + size_t _max_size; // Maximum number of connections allowed + int _idle_timeout_secs; // Idle time before connections are stale size_t _current_size = 0; std::deque> _pool; // Available connections - std::mutex _mutex; // Mutex for thread-safe access + std::mutex _mutex; // Mutex for thread-safe access }; // Singleton manager that handles multiple pools keyed by connection string class ConnectionPoolManager { -public: + public: // Returns the singleton instance of the manager static ConnectionPoolManager& getInstance(); void configure(int max_size, int idle_timeout); // Gets a connection from the appropriate pool (creates one if none exists) - std::shared_ptr acquireConnection(const std::wstring& conn_str, const py::dict& attrs_before = py::dict()); + std::shared_ptr acquireConnection( + const std::wstring& conn_str, + const py::dict& attrs_before = py::dict()); // Returns a connection to its original pool - void returnConnection(const std::wstring& conn_str, std::shared_ptr conn); + void returnConnection(const std::wstring& conn_str, + std::shared_ptr conn); // Closes all pools and their connections void closePools(); -private: - ConnectionPoolManager() = default; + private: + ConnectionPoolManager() = default; ~ConnectionPoolManager() = default; // Map from connection string to connection pool @@ -63,8 +69,10 @@ class ConnectionPoolManager { std::mutex _manager_mutex; size_t _default_max_size = 10; int _default_idle_secs = 300; - + // Prevent copying ConnectionPoolManager(const ConnectionPoolManager&) = delete; ConnectionPoolManager& operator=(const ConnectionPoolManager&) = delete; }; + +#endif // MSSQL_PYTHON_CONNECTION_POOL_H_ diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 176724a4..eeb5bb37 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -1,23 +1,23 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -// INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be -// taken up in future. +// INFO|TODO - Note that is file is Windows specific right now. Making it +// arch agnostic will be taken up in future. #pragma once -#include // pybind11.h must be the first include - https://pybind11.readthedocs.io/en/latest/basics.html#header-and-namespace-conventions +// pybind11.h must be the first include +#include #include #include #include #include // Add this line for datetime support #include -namespace py = pybind11; -using namespace pybind11::literals; - -#include #include -#include +#include +#include +namespace py = pybind11; +using py::literals::operator""_a; #ifdef _WIN32 // Windows-specific headers @@ -39,7 +39,8 @@ inline std::vector WStringToSQLWCHAR(const std::wstring& str) { return result; } -inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = SQL_NTS) { +inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, + size_t length = SQL_NTS) { if (!sqlwStr) return std::wstring(); if (length == SQL_NTS) { @@ -67,10 +68,12 @@ constexpr uint32_t UNICODE_REPLACEMENT_CHAR = 0xFFFD; // (excludes surrogate halves and values beyond U+10FFFF) inline bool IsValidUnicodeScalar(uint32_t cp) { return cp <= UNICODE_MAX_CODEPOINT && - !(cp >= UNICODE_SURROGATE_HIGH_START && cp <= UNICODE_SURROGATE_LOW_END); + !(cp >= UNICODE_SURROGATE_HIGH_START && + cp <= UNICODE_SURROGATE_LOW_END); } -inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = SQL_NTS) { +inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, + size_t length = SQL_NTS) { if (!sqlwStr) return std::wstring(); if (length == SQL_NTS) { size_t i = 0; @@ -80,27 +83,34 @@ inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = S std::wstring result; result.reserve(length); if constexpr (sizeof(SQLWCHAR) == 2) { - for (size_t i = 0; i < length; ) { // Use a manual increment to handle skipping + // Use a manual increment to handle skipping + for (size_t i = 0; i < length; ) { uint16_t wc = static_cast(sqlwStr[i]); // Check for high surrogate and valid low surrogate - if (wc >= UNICODE_SURROGATE_HIGH_START && wc <= UNICODE_SURROGATE_HIGH_END && (i + 1 < length)) { + if (wc >= UNICODE_SURROGATE_HIGH_START && + wc <= UNICODE_SURROGATE_HIGH_END && (i + 1 < length)) { uint16_t low = static_cast(sqlwStr[i + 1]); - if (low >= UNICODE_SURROGATE_LOW_START && low <= UNICODE_SURROGATE_LOW_END) { + if (low >= UNICODE_SURROGATE_LOW_START && + low <= UNICODE_SURROGATE_LOW_END) { // Combine into a single code point - uint32_t cp = (((wc - UNICODE_SURROGATE_HIGH_START) << 10) | (low - UNICODE_SURROGATE_LOW_START)) + 0x10000; + uint32_t cp = (((wc - UNICODE_SURROGATE_HIGH_START) << 10) | + (low - UNICODE_SURROGATE_LOW_START)) + + 0x10000; result.push_back(static_cast(cp)); - i += 2; // Move past both surrogates + i += 2; // Move past both surrogates continue; } } - // If we reach here, it's not a valid surrogate pair or is a BMP character. - // Check if it's a valid scalar and append, otherwise append replacement char. + // If we reach here, it's not a valid surrogate pair or is a BMP + // character. Check if it's a valid scalar and append, otherwise + // append replacement char. if (IsValidUnicodeScalar(wc)) { result.push_back(static_cast(wc)); } else { - result.push_back(static_cast(UNICODE_REPLACEMENT_CHAR)); + result.push_back( + static_cast(UNICODE_REPLACEMENT_CHAR)); } - ++i; // Move to the next code unit + ++i; // Move to the next code unit } } else { // SQLWCHAR is UTF-32, so just copy with validation @@ -109,7 +119,8 @@ inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = S if (IsValidUnicodeScalar(cp)) { result.push_back(static_cast(cp)); } else { - result.push_back(static_cast(UNICODE_REPLACEMENT_CHAR)); + result.push_back( + static_cast(UNICODE_REPLACEMENT_CHAR)); } } } @@ -132,8 +143,10 @@ inline std::vector WStringToSQLWCHAR(const std::wstring& str) { } else { // Encode as surrogate pair cp -= 0x10000; - SQLWCHAR high = static_cast((cp >> 10) + UNICODE_SURROGATE_HIGH_START); - SQLWCHAR low = static_cast((cp & 0x3FF) + UNICODE_SURROGATE_LOW_START); + SQLWCHAR high = static_cast( + (cp >> 10) + UNICODE_SURROGATE_HIGH_START); + SQLWCHAR low = static_cast( + (cp & 0x3FF) + UNICODE_SURROGATE_LOW_START); result.push_back(high); result.push_back(low); } @@ -145,18 +158,19 @@ inline std::vector WStringToSQLWCHAR(const std::wstring& str) { if (IsValidUnicodeScalar(cp)) { result.push_back(static_cast(cp)); } else { - result.push_back(static_cast(UNICODE_REPLACEMENT_CHAR)); + result.push_back( + static_cast(UNICODE_REPLACEMENT_CHAR)); } } } - result.push_back(0); // null terminator + result.push_back(0); // null terminator return result; } #endif #if defined(__APPLE__) || defined(__linux__) -#include "unix_utils.h" // For Unix-specific Unicode encoding fixes -#include "unix_buffers.h" // For Unix-specific buffer handling +#include "unix_utils.h" // Unix-specific fixes +#include "unix_buffers.h" // Unix-specific buffers #endif //------------------------------------------------------------------------------------------------- @@ -164,39 +178,63 @@ inline std::vector WStringToSQLWCHAR(const std::wstring& str) { //------------------------------------------------------------------------------------------------- // Handle APIs -typedef SQLRETURN (SQL_API* SQLAllocHandleFunc)(SQLSMALLINT, SQLHANDLE, SQLHANDLE*); -typedef SQLRETURN (SQL_API* SQLSetEnvAttrFunc)(SQLHANDLE, SQLINTEGER, SQLPOINTER, SQLINTEGER); -typedef SQLRETURN (SQL_API* SQLSetConnectAttrFunc)(SQLHDBC, SQLINTEGER, SQLPOINTER, SQLINTEGER); -typedef SQLRETURN (SQL_API* SQLSetStmtAttrFunc)(SQLHSTMT, SQLINTEGER, SQLPOINTER, SQLINTEGER); -typedef SQLRETURN (SQL_API* SQLGetConnectAttrFunc)(SQLHDBC, SQLINTEGER, SQLPOINTER, SQLINTEGER, SQLINTEGER*); +typedef SQLRETURN (SQL_API* SQLAllocHandleFunc)(SQLSMALLINT, SQLHANDLE, + SQLHANDLE*); +typedef SQLRETURN (SQL_API* SQLSetEnvAttrFunc)(SQLHANDLE, SQLINTEGER, + SQLPOINTER, SQLINTEGER); +typedef SQLRETURN (SQL_API* SQLSetConnectAttrFunc)(SQLHDBC, SQLINTEGER, + SQLPOINTER, SQLINTEGER); +typedef SQLRETURN (SQL_API* SQLSetStmtAttrFunc)(SQLHSTMT, SQLINTEGER, + SQLPOINTER, SQLINTEGER); +typedef SQLRETURN (SQL_API* SQLGetConnectAttrFunc)(SQLHDBC, SQLINTEGER, + SQLPOINTER, SQLINTEGER, + SQLINTEGER*); // Connection and Execution APIs -typedef SQLRETURN (SQL_API* SQLDriverConnectFunc)(SQLHANDLE, SQLHWND, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, - SQLSMALLINT, SQLSMALLINT*, SQLUSMALLINT); -typedef SQLRETURN (SQL_API* SQLExecDirectFunc)(SQLHANDLE, SQLWCHAR*, SQLINTEGER); -typedef SQLRETURN (SQL_API* SQLPrepareFunc)(SQLHANDLE, SQLWCHAR*, SQLINTEGER); -typedef SQLRETURN (SQL_API* SQLBindParameterFunc)(SQLHANDLE, SQLUSMALLINT, SQLSMALLINT, SQLSMALLINT, - SQLSMALLINT, SQLULEN, SQLSMALLINT, SQLPOINTER, SQLLEN, - SQLLEN*); +typedef SQLRETURN (SQL_API* SQLDriverConnectFunc)(SQLHANDLE, SQLHWND, + SQLWCHAR*, SQLSMALLINT, + SQLWCHAR*, SQLSMALLINT, + SQLSMALLINT*, + SQLUSMALLINT); +typedef SQLRETURN (SQL_API* SQLExecDirectFunc)(SQLHANDLE, SQLWCHAR*, + SQLINTEGER); +typedef SQLRETURN (SQL_API* SQLPrepareFunc)(SQLHANDLE, SQLWCHAR*, + SQLINTEGER); +typedef SQLRETURN (SQL_API* SQLBindParameterFunc)(SQLHANDLE, SQLUSMALLINT, + SQLSMALLINT, SQLSMALLINT, + SQLSMALLINT, SQLULEN, + SQLSMALLINT, SQLPOINTER, + SQLLEN, SQLLEN*); typedef SQLRETURN (SQL_API* SQLExecuteFunc)(SQLHANDLE); typedef SQLRETURN (SQL_API* SQLRowCountFunc)(SQLHSTMT, SQLLEN*); -typedef SQLRETURN (SQL_API* SQLSetDescFieldFunc)(SQLHDESC, SQLSMALLINT, SQLSMALLINT, SQLPOINTER, SQLINTEGER); -typedef SQLRETURN (SQL_API* SQLGetStmtAttrFunc)(SQLHSTMT, SQLINTEGER, SQLPOINTER, SQLINTEGER, SQLINTEGER*); +typedef SQLRETURN (SQL_API* SQLSetDescFieldFunc)(SQLHDESC, SQLSMALLINT, + SQLSMALLINT, SQLPOINTER, + SQLINTEGER); +typedef SQLRETURN (SQL_API* SQLGetStmtAttrFunc)(SQLHSTMT, SQLINTEGER, + SQLPOINTER, SQLINTEGER, + SQLINTEGER*); // Data retrieval APIs typedef SQLRETURN (SQL_API* SQLFetchFunc)(SQLHANDLE); -typedef SQLRETURN (SQL_API* SQLFetchScrollFunc)(SQLHANDLE, SQLSMALLINT, SQLLEN); -typedef SQLRETURN (SQL_API* SQLGetDataFunc)(SQLHANDLE, SQLUSMALLINT, SQLSMALLINT, SQLPOINTER, SQLLEN, - SQLLEN*); +typedef SQLRETURN (SQL_API* SQLFetchScrollFunc)(SQLHANDLE, SQLSMALLINT, + SQLLEN); +typedef SQLRETURN (SQL_API* SQLGetDataFunc)(SQLHANDLE, SQLUSMALLINT, + SQLSMALLINT, SQLPOINTER, + SQLLEN, SQLLEN*); typedef SQLRETURN (SQL_API* SQLNumResultColsFunc)(SQLHSTMT, SQLSMALLINT*); -typedef SQLRETURN (SQL_API* SQLBindColFunc)(SQLHSTMT, SQLUSMALLINT, SQLSMALLINT, SQLPOINTER, SQLLEN, - SQLLEN*); -typedef SQLRETURN (SQL_API* SQLDescribeColFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR*, SQLSMALLINT, - SQLSMALLINT*, SQLSMALLINT*, SQLULEN*, SQLSMALLINT*, - SQLSMALLINT*); +typedef SQLRETURN (SQL_API* SQLBindColFunc)(SQLHSTMT, SQLUSMALLINT, + SQLSMALLINT, SQLPOINTER, + SQLLEN, SQLLEN*); +typedef SQLRETURN (SQL_API* SQLDescribeColFunc)(SQLHSTMT, SQLUSMALLINT, + SQLWCHAR*, SQLSMALLINT, + SQLSMALLINT*, SQLSMALLINT*, + SQLULEN*, SQLSMALLINT*, + SQLSMALLINT*); typedef SQLRETURN (SQL_API* SQLMoreResultsFunc)(SQLHSTMT); -typedef SQLRETURN (SQL_API* SQLColAttributeFunc)(SQLHSTMT, SQLUSMALLINT, SQLUSMALLINT, SQLPOINTER, - SQLSMALLINT, SQLSMALLINT*, SQLPOINTER); +typedef SQLRETURN (SQL_API* SQLColAttributeFunc)(SQLHSTMT, SQLUSMALLINT, + SQLUSMALLINT, SQLPOINTER, + SQLSMALLINT, SQLSMALLINT*, + SQLPOINTER); typedef SQLRETURN (*SQLTablesFunc)( SQLHSTMT StatementHandle, SQLWCHAR* CatalogName, @@ -208,27 +246,45 @@ typedef SQLRETURN (*SQLTablesFunc)( SQLWCHAR* TableType, SQLSMALLINT NameLength4 ); - typedef SQLRETURN (SQL_API* SQLGetTypeInfoFunc)(SQLHSTMT, SQLSMALLINT); -typedef SQLRETURN (SQL_API* SQLProceduresFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, - SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); -typedef SQLRETURN (SQL_API* SQLForeignKeysFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, - SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, - SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); -typedef SQLRETURN (SQL_API* SQLPrimaryKeysFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, - SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); -typedef SQLRETURN (SQL_API* SQLSpecialColumnsFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR*, SQLSMALLINT, - SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, - SQLUSMALLINT, SQLUSMALLINT); -typedef SQLRETURN (SQL_API* SQLStatisticsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, - SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, - SQLUSMALLINT, SQLUSMALLINT); -typedef SQLRETURN (SQL_API* SQLColumnsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, - SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, - SQLWCHAR*, SQLSMALLINT); -typedef SQLRETURN (SQL_API* SQLGetInfoFunc)(SQLHDBC, SQLUSMALLINT, SQLPOINTER, SQLSMALLINT, SQLSMALLINT*); +typedef SQLRETURN (SQL_API* SQLGetTypeInfoFunc)(SQLHSTMT, SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLProceduresFunc)(SQLHSTMT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLForeignKeysFunc)(SQLHSTMT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLPrimaryKeysFunc)(SQLHSTMT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLSpecialColumnsFunc)(SQLHSTMT, SQLUSMALLINT, + SQLWCHAR*, SQLSMALLINT, + SQLWCHAR*, SQLSMALLINT, + SQLWCHAR*, SQLSMALLINT, + SQLUSMALLINT, + SQLUSMALLINT); +typedef SQLRETURN (SQL_API* SQLStatisticsFunc)(SQLHSTMT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLUSMALLINT, + SQLUSMALLINT); +typedef SQLRETURN (SQL_API* SQLColumnsFunc)(SQLHSTMT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLGetInfoFunc)(SQLHDBC, SQLUSMALLINT, + SQLPOINTER, SQLSMALLINT, + SQLSMALLINT*); // Transaction APIs -typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, + SQLSMALLINT); // Disconnect/free APIs typedef SQLRETURN (SQL_API* SQLFreeHandleFunc)(SQLSMALLINT, SQLHANDLE); @@ -236,10 +292,15 @@ typedef SQLRETURN (SQL_API* SQLDisconnectFunc)(SQLHDBC); typedef SQLRETURN (SQL_API* SQLFreeStmtFunc)(SQLHSTMT, SQLUSMALLINT); // Diagnostic APIs -typedef SQLRETURN (SQL_API* SQLGetDiagRecFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT, SQLWCHAR*, SQLINTEGER*, - SQLWCHAR*, SQLSMALLINT, SQLSMALLINT*); +typedef SQLRETURN (SQL_API* SQLGetDiagRecFunc)(SQLSMALLINT, SQLHANDLE, + SQLSMALLINT, SQLWCHAR*, + SQLINTEGER*, SQLWCHAR*, + SQLSMALLINT, SQLSMALLINT*); -typedef SQLRETURN (SQL_API* SQLDescribeParamFunc)(SQLHSTMT, SQLUSMALLINT, SQLSMALLINT*, SQLULEN*, SQLSMALLINT*, SQLSMALLINT*); +typedef SQLRETURN (SQL_API* SQLDescribeParamFunc)(SQLHSTMT, SQLUSMALLINT, + SQLSMALLINT*, SQLULEN*, + SQLSMALLINT*, + SQLSMALLINT*); // DAE APIs typedef SQLRETURN (SQL_API* SQLParamDataFunc)(SQLHSTMT, SQLPOINTER*); @@ -336,23 +397,25 @@ DriverHandle LoadDriverOrThrowException(); //------------------------------------------------------------------------------------------------- // DriverLoader (Singleton) // -// Ensures the ODBC driver and all function pointers are loaded exactly once across the process. -// This avoids redundant work and ensures thread-safe, centralized initialization. +// Ensures the ODBC driver and all function pointers are loaded exactly once +// across the process. +// This avoids redundant work and ensures thread-safe, centralized +// initialization. // // Not copyable or assignable. //------------------------------------------------------------------------------------------------- class DriverLoader { - public: - static DriverLoader& getInstance(); - void loadDriver(); - private: - DriverLoader(); - DriverLoader(const DriverLoader&) = delete; - DriverLoader& operator=(const DriverLoader&) = delete; - - bool m_driverLoaded; - std::once_flag m_onceFlag; - }; + public: + static DriverLoader& getInstance(); + void loadDriver(); + private: + DriverLoader(); + DriverLoader(const DriverLoader&) = delete; + DriverLoader& operator=(const DriverLoader&) = delete; + + bool m_driverLoaded; + std::once_flag m_onceFlag; +}; //------------------------------------------------------------------------------------------------- // SqlHandle @@ -361,40 +424,48 @@ class DriverLoader { // Use `std::shared_ptr` (alias: SqlHandlePtr) for shared ownership. //------------------------------------------------------------------------------------------------- class SqlHandle { - public: - SqlHandle(SQLSMALLINT type, SQLHANDLE rawHandle); - ~SqlHandle(); - SQLHANDLE get() const; - SQLSMALLINT type() const; - void free(); - private: - SQLSMALLINT _type; - SQLHANDLE _handle; - }; + public: + SqlHandle(SQLSMALLINT type, SQLHANDLE rawHandle); + ~SqlHandle(); + SQLHANDLE get() const; + SQLSMALLINT type() const; + void free(); + private: + SQLSMALLINT _type; + SQLHANDLE _handle; +}; using SqlHandlePtr = std::shared_ptr; -// This struct is used to relay error info obtained from SQLDiagRec API to the Python module +// This struct is used to relay error info obtained from SQLDiagRec API to the +// Python module struct ErrorInfo { std::wstring sqlState; std::wstring ddbcErrorMsg; }; -ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode); +ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, + SQLRETURN retcode); inline std::string WideToUTF8(const std::wstring& wstr) { if (wstr.empty()) return {}; #if defined(_WIN32) - int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), static_cast(wstr.size()), nullptr, 0, nullptr, nullptr); + int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), + static_cast(wstr.size()), + nullptr, 0, nullptr, nullptr); if (size_needed == 0) return {}; std::string result(size_needed, 0); - int converted = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), static_cast(wstr.size()), result.data(), size_needed, nullptr, nullptr); + int converted = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), + static_cast(wstr.size()), + result.data(), size_needed, + nullptr, nullptr); if (converted == 0) return {}; return result; #else // Manual UTF-32 to UTF-8 conversion for macOS/Linux std::string utf8_string; - utf8_string.reserve(wstr.size() * 4); // Reserve enough space for worst case (4 bytes per character) - + // Reserve enough space for worst case (4 bytes per character) + utf8_string.reserve(wstr.size() * 4); + for (wchar_t wc : wstr) { uint32_t code_point = static_cast(wc); @@ -407,14 +478,19 @@ inline std::string WideToUTF8(const std::wstring& wstr) { utf8_string += static_cast(0x80 | (code_point & 0x3F)); } else if (code_point <= 0xFFFF) { // 3-byte UTF-8 sequence - utf8_string += static_cast(0xE0 | ((code_point >> 12) & 0x0F)); - utf8_string += static_cast(0x80 | ((code_point >> 6) & 0x3F)); + utf8_string += static_cast(0xE0 | + ((code_point >> 12) & 0x0F)); + utf8_string += static_cast(0x80 | + ((code_point >> 6) & 0x3F)); utf8_string += static_cast(0x80 | (code_point & 0x3F)); } else if (code_point <= 0x10FFFF) { // 4-byte UTF-8 sequence for characters like emojis (e.g., U+1F604) - utf8_string += static_cast(0xF0 | ((code_point >> 18) & 0x07)); - utf8_string += static_cast(0x80 | ((code_point >> 12) & 0x3F)); - utf8_string += static_cast(0x80 | ((code_point >> 6) & 0x3F)); + utf8_string += static_cast(0xF0 | + ((code_point >> 18) & 0x07)); + utf8_string += static_cast(0x80 | + ((code_point >> 12) & 0x3F)); + utf8_string += static_cast(0x80 | + ((code_point >> 6) & 0x3F)); utf8_string += static_cast(0x80 | (code_point & 0x3F)); } } @@ -425,13 +501,17 @@ inline std::string WideToUTF8(const std::wstring& wstr) { inline std::wstring Utf8ToWString(const std::string& str) { if (str.empty()) return {}; #if defined(_WIN32) - int size_needed = MultiByteToWideChar(CP_UTF8, 0, str.data(), static_cast(str.size()), nullptr, 0); + int size_needed = MultiByteToWideChar(CP_UTF8, 0, str.data(), + static_cast(str.size()), + nullptr, 0); if (size_needed == 0) { LOG("MultiByteToWideChar failed."); return {}; } std::wstring result(size_needed, 0); - int converted = MultiByteToWideChar(CP_UTF8, 0, str.data(), static_cast(str.size()), result.data(), size_needed); + int converted = MultiByteToWideChar(CP_UTF8, 0, str.data(), + static_cast(str.size()), + result.data(), size_needed); if (converted == 0) return {}; return result; #else @@ -442,26 +522,26 @@ inline std::wstring Utf8ToWString(const std::string& str) { // Thread-safe decimal separator accessor class class ThreadSafeDecimalSeparator { -private: + private: std::string value; mutable std::mutex mutex; -public: + public: // Constructor with default value ThreadSafeDecimalSeparator() : value(".") {} - + // Set the decimal separator with thread safety void set(const std::string& separator) { std::lock_guard lock(mutex); value = separator; } - + // Get the decimal separator with thread safety std::string get() const { std::lock_guard lock(mutex); return value; } - + // Returns whether the current separator is different from the default "." bool isCustomSeparator() const { std::lock_guard lock(mutex); diff --git a/mssql_python/pybind/unix_buffers.h b/mssql_python/pybind/unix_buffers.h index 57039ac8..b130d23d 100644 --- a/mssql_python/pybind/unix_buffers.h +++ b/mssql_python/pybind/unix_buffers.h @@ -9,11 +9,11 @@ #pragma once -#include -#include -#include #include #include +#include +#include +#include namespace unix_buffers { @@ -26,15 +26,15 @@ constexpr size_t UCS_LENGTH = 2; * handling memory allocation and conversion to std::wstring. */ class SQLWCHARBuffer { -private: + private: std::unique_ptr buffer; size_t buffer_size; -public: + public: /** * Constructor allocates a buffer of the specified size */ - SQLWCHARBuffer(size_t size) : buffer_size(size) { + explicit SQLWCHARBuffer(size_t size) : buffer_size(size) { buffer = std::make_unique(size); // Initialize to zero for (size_t i = 0; i < size; i++) { @@ -62,7 +62,7 @@ class SQLWCHARBuffer { */ std::wstring toString(SQLSMALLINT length = -1) const { std::wstring result; - + // If length is provided, use it if (length > 0) { for (SQLSMALLINT i = 0; i < length; i++) { @@ -70,7 +70,7 @@ class SQLWCHARBuffer { } return result; } - + // Otherwise, read until null terminator for (size_t i = 0; i < buffer_size; i++) { if (buffer[i] == 0) { @@ -78,7 +78,7 @@ class SQLWCHARBuffer { } result.push_back(static_cast(buffer[i])); } - + return result; } }; @@ -88,53 +88,55 @@ class SQLWCHARBuffer { * Similar to the error list handling in the Python PoC _check_ret function */ class DiagnosticRecords { -private: + private: struct Record { std::wstring sqlState; std::wstring message; SQLINTEGER nativeError; }; - + std::vector records; -public: - void addRecord(const std::wstring& sqlState, const std::wstring& message, SQLINTEGER nativeError) { + public: + void addRecord(const std::wstring& sqlState, + const std::wstring& message, SQLINTEGER nativeError) { records.push_back({sqlState, message, nativeError}); } - + bool empty() const { return records.empty(); } - + std::wstring getSQLState() const { if (!records.empty()) { return records[0].sqlState; } - return L"HY000"; // General error + return L"HY000"; // General error } - + std::wstring getFirstErrorMessage() const { if (!records.empty()) { return records[0].message; } return L"Unknown error"; } - + std::wstring getFullErrorMessage() const { if (records.empty()) { return L"No error information available"; } - + std::wstring fullMessage = records[0].message; - + // Add additional error messages if there are any for (size_t i = 1; i < records.size(); i++) { - fullMessage += L"; [" + records[i].sqlState + L"] " + records[i].message; + fullMessage += L"; [" + records[i].sqlState + L"] " + + records[i].message; } - + return fullMessage; } - + size_t size() const { return records.size(); } @@ -147,23 +149,23 @@ class DiagnosticRecords { inline std::wstring UCS_dec(const SQLWCHAR* buffer, size_t maxLength = 0) { std::wstring result; size_t i = 0; - + while (true) { // Break if we've reached the maximum length if (maxLength > 0 && i >= maxLength) { break; } - + // Break if we've reached a null terminator if (buffer[i] == 0) { break; } - + result.push_back(static_cast(buffer[i])); i++; } - + return result; } -} // namespace unix_buffers +} // namespace unix_buffers diff --git a/mssql_python/pybind/unix_utils.cpp b/mssql_python/pybind/unix_utils.cpp index c98a9e09..3fd325bd 100644 --- a/mssql_python/pybind/unix_utils.cpp +++ b/mssql_python/pybind/unix_utils.cpp @@ -6,17 +6,24 @@ // between SQLWCHAR, std::wstring, and UTF-8 strings to bridge encoding // differences specific to macOS. +#include "unix_utils.h" +#include +#include +#include +#include + #if defined(__APPLE__) || defined(__linux__) // Constants for character encoding const char* kOdbcEncoding = "utf-16-le"; // ODBC uses UTF-16LE for SQLWCHAR -const size_t kUcsLength = 2; // SQLWCHAR is 2 bytes on all platforms +const size_t kUcsLength = 2; // SQLWCHAR is 2 bytes on all platforms -// TODO: Make Logger a separate module and import it across the project +// TODO(microsoft): Make Logger a separate module and import it across project template void LOG(const std::string& formatString, Args&&... args) { - py::gil_scoped_acquire gil; // <---- this ensures safe Python API usage + py::gil_scoped_acquire gil; // this ensures safe Python API usage - py::object logger = py::module_::import("mssql_python.logging_config").attr("get_logger")(); + py::object logger = py::module_::import("mssql_python.logging_config") + .attr("get_logger")(); if (py::isinstance(logger)) return; try { @@ -24,7 +31,8 @@ void LOG(const std::string& formatString, Args&&... args) { if constexpr (sizeof...(args) == 0) { logger.attr("debug")(py::str(ddbcFormatString)); } else { - py::str message = py::str(ddbcFormatString).format(std::forward(args)...); + py::str message = py::str(ddbcFormatString) + .format(std::forward(args)...); logger.attr("debug")(message); } } catch (const std::exception& e) { @@ -33,32 +41,38 @@ void LOG(const std::string& formatString, Args&&... args) { } // Function to convert SQLWCHAR strings to std::wstring on macOS -std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = SQL_NTS) { +std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, + size_t length = SQL_NTS) { if (!sqlwStr) return std::wstring(); - + if (length == SQL_NTS) { // Determine length if not provided size_t i = 0; while (sqlwStr[i] != 0) ++i; length = i; } - + // Create a UTF-16LE byte array from the SQLWCHAR array std::vector utf16Bytes(length * kUcsLength); for (size_t i = 0; i < length; ++i) { // Copy each SQLWCHAR (2 bytes) to the byte array memcpy(&utf16Bytes[i * kUcsLength], &sqlwStr[i], kUcsLength); } - + // Convert UTF-16LE to std::wstring (UTF-32 on macOS) try { // Use C++11 codecvt to convert between UTF-16LE and wstring - std::wstring_convert> converter; - return converter.from_bytes(reinterpret_cast(utf16Bytes.data()), - reinterpret_cast(utf16Bytes.data() + utf16Bytes.size())); + std::wstring_convert> + converter; + return converter.from_bytes( + reinterpret_cast(utf16Bytes.data()), + reinterpret_cast(utf16Bytes.data() + + utf16Bytes.size())); } catch (const std::exception& e) { // Log a warning about using fallback conversion - LOG("Warning: Using fallback string conversion on macOS. Character data might be inexact."); + LOG("Warning: Using fallback string conversion on macOS. " + "Character data might be inexact."); // Fallback to character-by-character conversion if codecvt fails std::wstring result; result.reserve(length); @@ -73,20 +87,25 @@ std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = SQL_NTS) std::vector WStringToSQLWCHAR(const std::wstring& str) { try { // Convert wstring (UTF-32 on macOS) to UTF-16LE bytes - std::wstring_convert> converter; + std::wstring_convert> + converter; std::string utf16Bytes = converter.to_bytes(str); - + // Convert the bytes to SQLWCHAR array - std::vector result(utf16Bytes.size() / kUcsLength + 1, 0); // +1 for null terminator + std::vector result(utf16Bytes.size() / kUcsLength + 1, + 0); // +1 for null terminator for (size_t i = 0; i < utf16Bytes.size() / kUcsLength; ++i) { memcpy(&result[i], &utf16Bytes[i * kUcsLength], kUcsLength); } return result; } catch (const std::exception& e) { // Log a warning about using fallback conversion - LOG("Warning: Using fallback conversion for std::wstring to SQLWCHAR on macOS. Character data might be inexact."); + LOG("Warning: Using fallback conversion for std::wstring to " + "SQLWCHAR on macOS. Character data might be inexact."); // Fallback to simple casting if codecvt fails - std::vector result(str.size() + 1, 0); // +1 for null terminator + std::vector result(str.size() + 1, + 0); // +1 for null terminator for (size_t i = 0; i < str.size(); ++i) { result[i] = static_cast(str[i]); } @@ -98,7 +117,7 @@ std::vector WStringToSQLWCHAR(const std::wstring& str) { // based on your ctypes UCS_dec implementation std::string SQLWCHARToUTF8String(const SQLWCHAR* buffer) { if (!buffer) return ""; - + std::vector utf16Bytes; size_t i = 0; while (buffer[i] != 0) { @@ -108,14 +127,19 @@ std::string SQLWCHARToUTF8String(const SQLWCHAR* buffer) { utf16Bytes.push_back(bytes[1]); i++; } - + try { - std::wstring_convert> converter; - return converter.to_bytes(reinterpret_cast(utf16Bytes.data()), - reinterpret_cast(utf16Bytes.data() + utf16Bytes.size())); + std::wstring_convert> + converter; + return converter.to_bytes( + reinterpret_cast(utf16Bytes.data()), + reinterpret_cast(utf16Bytes.data() + + utf16Bytes.size())); } catch (const std::exception& e) { // Log a warning about using fallback conversion - LOG("Warning: Using fallback conversion for SQLWCHAR to UTF-8 on macOS. Character data might be inexact."); + LOG("Warning: Using fallback conversion for SQLWCHAR to UTF-8 " + "on macOS. Character data might be inexact."); // Simple fallback conversion std::string result; for (size_t j = 0; j < i; ++j) { diff --git a/mssql_python/pybind/unix_utils.h b/mssql_python/pybind/unix_utils.h index cad35e74..61994ced 100644 --- a/mssql_python/pybind/unix_utils.h +++ b/mssql_python/pybind/unix_utils.h @@ -8,13 +8,13 @@ #pragma once -#include -#include -#include -#include +#include #include #include -#include +#include +#include +#include +#include namespace py = pybind11; diff --git a/tests/test_010_pybind_functions.py b/tests/test_010_pybind_functions.py new file mode 100644 index 00000000..e41246dd --- /dev/null +++ b/tests/test_010_pybind_functions.py @@ -0,0 +1,669 @@ +""" +This file contains tests for the pybind C++ functions in ddbc_bindings module. +These tests exercise the C++ code paths without mocking to provide real code coverage. + +Functions tested: +- Architecture and module info +- Utility functions (GetDriverPathCpp, ThrowStdException) +- Data structures (ParamInfo, NumericData, ErrorInfo, DateTimeOffset) +- SQL functions (DDBCSQLExecDirect, DDBCSQLExecute, etc.) +- Connection pooling functions +- Error handling functions +- Threading safety tests +- Unix-specific utility functions (when available) +""" + +import pytest +import platform +import threading +import os + +# Import ddbc_bindings with error handling +try: + import mssql_python.ddbc_bindings as ddbc + DDBC_AVAILABLE = True +except ImportError as e: + print(f"Warning: ddbc_bindings not available: {e}") + DDBC_AVAILABLE = False + ddbc = None + +from mssql_python.exceptions import ( + InterfaceError, ProgrammingError, DatabaseError, + OperationalError, DataError +) + + +@pytest.mark.skipif(not DDBC_AVAILABLE, reason="ddbc_bindings not available") +class TestPybindModuleInfo: + """Test module information and architecture detection.""" + + def test_module_architecture_attribute(self): + """Test that the module exposes architecture information.""" + assert hasattr(ddbc, 'ARCHITECTURE') + + arch = getattr(ddbc, 'ARCHITECTURE') + assert isinstance(arch, str) + assert len(arch) > 0 + + def test_architecture_consistency(self): + """Test that architecture attributes are consistent.""" + arch = getattr(ddbc, 'ARCHITECTURE') + # Valid architectures for Windows, Linux, and macOS + valid_architectures = [ + 'x64', 'x86', 'arm64', 'win64', # Windows + 'x86_64', 'i386', 'aarch64', # Linux + 'arm64', 'x86_64', 'universal2' # macOS (arm64/Intel/Universal) + ] + assert arch in valid_architectures, f"Unknown architecture: {arch}" + + def test_module_docstring(self): + """Test that the module has proper documentation.""" + # Module may not have __doc__ attribute set, which is acceptable + doc = getattr(ddbc, '__doc__', None) + if doc is not None: + assert isinstance(doc, str) + # Just verify the module loads and has expected attributes + assert hasattr(ddbc, 'ARCHITECTURE') + + +@pytest.mark.skipif(not DDBC_AVAILABLE, reason="ddbc_bindings not available") +class TestUtilityFunctions: + """Test C++ utility functions exposed to Python.""" + + def test_get_driver_path_cpp(self): + """Test GetDriverPathCpp function.""" + try: + # Function requires a driver name argument + driver_path = ddbc.GetDriverPathCpp("ODBC Driver 18 for SQL Server") + assert isinstance(driver_path, str) + # Driver path should not be empty if found + if driver_path: + assert len(driver_path) > 0 + except Exception as e: + # On some systems, driver might not be available + error_msg = str(e).lower() + assert any(keyword in error_msg for keyword in [ + "driver not found", "cannot find", "not available", + "incompatible", "not supported" + ]) + + def test_throw_std_exception(self): + """Test ThrowStdException function.""" + with pytest.raises(RuntimeError): + ddbc.ThrowStdException("Test exception message") + + +@pytest.mark.skipif(not DDBC_AVAILABLE, reason="ddbc_bindings not available") +class TestDataStructures: + """Test C++ data structures exposed to Python.""" + + def test_param_info_creation(self): + """Test ParamInfo structure creation and access.""" + param = ddbc.ParamInfo() + + # Test that object was created successfully + assert param is not None + + # Test basic attributes that should be accessible + try: + param.inputOutputType = 1 + assert param.inputOutputType == 1 + except (AttributeError, TypeError): + # Some attributes might not be directly accessible + pass + + try: + param.paramCType = 2 + assert param.paramCType == 2 + except (AttributeError, TypeError): + pass + + try: + param.paramSQLType = 3 + assert param.paramSQLType == 3 + except (AttributeError, TypeError): + pass + + # Test that the object has the expected type + assert str(type(param)) == "" + + def test_numeric_data_creation(self): + """Test NumericData structure creation and manipulation.""" + # Test default constructor + num1 = ddbc.NumericData() + assert hasattr(num1, 'precision') + assert hasattr(num1, 'scale') + assert hasattr(num1, 'sign') + assert hasattr(num1, 'val') + + # Test parameterized constructor + test_bytes = b'\\x12\\x34\\x00\\x00' # Sample binary data + num2 = ddbc.NumericData(18, 2, 1, test_bytes.decode('latin-1')) + + assert num2.precision == 18 + assert num2.scale == 2 + assert num2.sign == 1 + assert len(num2.val) == 16 # SQL_MAX_NUMERIC_LEN + + # Test setting values + num1.precision = 10 + num1.scale = 3 + num1.sign = 0 + + assert num1.precision == 10 + assert num1.scale == 3 + assert num1.sign == 0 + + def test_error_info_structure(self): + """Test ErrorInfo structure.""" + # ErrorInfo might not have a default constructor, so just test that the class exists + assert hasattr(ddbc, 'ErrorInfo') + + # Test that it's a valid class type + ErrorInfoClass = getattr(ddbc, 'ErrorInfo') + assert callable(ErrorInfoClass) or hasattr(ErrorInfoClass, '__name__') + + +@pytest.mark.skipif(not DDBC_AVAILABLE, reason="ddbc_bindings not available") +class TestConnectionFunctions: + """Test connection-related pybind functions.""" + + @pytest.fixture + def db_connection(self): + """Provide a database connection for testing.""" + try: + conn_str = os.getenv("DB_CONNECTION_STRING") + conn = ddbc.Connection(conn_str, False, {}) + yield conn + try: + conn.close() + except: + pass + except Exception: + pytest.skip("Database connection not available for testing") + + def test_connection_creation(self): + """Test Connection class creation.""" + try: + conn_str = os.getenv("DB_CONNECTION_STRING") + conn = ddbc.Connection(conn_str, False, {}) + + assert conn is not None + + # Test basic methods exist + assert hasattr(conn, 'close') + assert hasattr(conn, 'commit') + assert hasattr(conn, 'rollback') + assert hasattr(conn, 'set_autocommit') + assert hasattr(conn, 'get_autocommit') + assert hasattr(conn, 'alloc_statement_handle') + + conn.close() + + except Exception as e: + if "driver not found" in str(e).lower(): + pytest.skip(f"ODBC driver not available: {e}") + else: + raise + + def test_connection_with_attrs_before(self): + """Test Connection creation with attrs_before parameter.""" + try: + conn_str = os.getenv("DB_CONNECTION_STRING") + attrs = {"SQL_ATTR_CONNECTION_TIMEOUT": 30} + conn = ddbc.Connection(conn_str, False, attrs) + + assert conn is not None + conn.close() + + except Exception as e: + if "driver not found" in str(e).lower(): + pytest.skip(f"ODBC driver not available: {e}") + else: + raise + + +@pytest.mark.skipif(not DDBC_AVAILABLE, reason="ddbc_bindings not available") +class TestPoolingFunctions: + """Test connection pooling functionality.""" + + def test_enable_pooling(self): + """Test enabling connection pooling.""" + try: + ddbc.enable_pooling() + # Should not raise an exception + except Exception as e: + # Some environments might not support pooling + assert "pooling" in str(e).lower() or "not supported" in str(e).lower() + + def test_close_pooling(self): + """Test closing connection pools.""" + try: + ddbc.close_pooling() + # Should not raise an exception + except Exception as e: + # Acceptable if pooling wasn't enabled + pass + + +@pytest.mark.skipif(not DDBC_AVAILABLE, reason="ddbc_bindings not available") +class TestSQLFunctions: + """Test SQL execution functions.""" + + @pytest.fixture + def statement_handle(self, db_connection): + """Provide a statement handle for testing.""" + try: + stmt = db_connection.alloc_statement_handle() + yield stmt + try: + ddbc.DDBCSQLFreeHandle(2, stmt) # SQL_HANDLE_STMT = 2 + except: + pass + except Exception: + pytest.skip("Cannot create statement handle") + + def test_sql_exec_direct_simple(self, statement_handle): + """Test DDBCSQLExecDirect with a simple query.""" + try: + result = ddbc.DDBCSQLExecDirect(statement_handle, "SELECT 1 as test_col") + # SQL_SUCCESS = 0, SQL_SUCCESS_WITH_INFO = 1 + assert result in [0, 1] + except Exception as e: + if "connection" in str(e).lower(): + pytest.skip(f"Database connection issue: {e}") + else: + raise + + def test_sql_num_result_cols(self, statement_handle): + """Test DDBCSQLNumResultCols function.""" + try: + # First execute a query + ddbc.DDBCSQLExecDirect(statement_handle, "SELECT 1 as col1, 'test' as col2") + + # Then get number of columns + num_cols = ddbc.DDBCSQLNumResultCols(statement_handle) + assert num_cols == 2 + + except Exception as e: + if "connection" in str(e).lower(): + pytest.skip(f"Database connection issue: {e}") + else: + raise + + def test_sql_describe_col(self, statement_handle): + """Test DDBCSQLDescribeCol function.""" + try: + # Execute a query first + ddbc.DDBCSQLExecDirect(statement_handle, "SELECT 'test' as test_column") + + # Describe the first column + col_info = ddbc.DDBCSQLDescribeCol(statement_handle, 1) + + assert isinstance(col_info, tuple) + assert len(col_info) >= 6 # Should return column name, type, etc. + + except Exception as e: + if "connection" in str(e).lower(): + pytest.skip(f"Database connection issue: {e}") + else: + raise + + def test_sql_fetch(self, statement_handle): + """Test DDBCSQLFetch function.""" + try: + # Execute a query + ddbc.DDBCSQLExecDirect(statement_handle, "SELECT 1") + + # Fetch the row + result = ddbc.DDBCSQLFetch(statement_handle) + # SQL_SUCCESS = 0, SQL_NO_DATA = 100 + assert result in [0, 100] + + except Exception as e: + if "connection" in str(e).lower(): + pytest.skip(f"Database connection issue: {e}") + else: + raise + + +@pytest.mark.skipif(not DDBC_AVAILABLE, reason="ddbc_bindings not available") +class TestErrorHandling: + """Test error handling functions.""" + + def test_sql_check_error_type_validation(self): + """Test DDBCSQLCheckError input validation.""" + # Test that function exists and can handle type errors gracefully + assert hasattr(ddbc, 'DDBCSQLCheckError') + + # Test with obviously wrong parameter types to check input validation + with pytest.raises((TypeError, AttributeError)): + ddbc.DDBCSQLCheckError("invalid", "invalid", "invalid") + + +@pytest.mark.skipif(not DDBC_AVAILABLE, reason="ddbc_bindings not available") +class TestDecimalSeparator: + """Test decimal separator functionality.""" + + def test_set_decimal_separator(self): + """Test DDBCSetDecimalSeparator function.""" + try: + # Test setting different decimal separators + ddbc.DDBCSetDecimalSeparator(".") + ddbc.DDBCSetDecimalSeparator(",") + + # Should not raise exceptions for valid separators + except Exception as e: + # Some implementations might not support this + assert "not supported" in str(e).lower() or "invalid" in str(e).lower() + + +@pytest.mark.skipif(platform.system() not in ['Linux', 'Darwin'], + reason="Unix-specific tests only run on Linux/macOS") +class TestUnixSpecificFunctions: + """Test Unix-specific functionality when available.""" + + def test_unix_utils_availability(self): + """Test that Unix utils are available on Unix systems.""" + # These functions are in unix_utils.h/cpp and should be available + # through the pybind module on Unix systems + + # Check if any Unix-specific functionality is exposed + # This tests that the conditional compilation worked correctly + module_attrs = dir(ddbc) + + # The module should at least have the basic functions + assert 'GetDriverPathCpp' in module_attrs + assert 'Connection' in module_attrs + + +@pytest.mark.skipif(not DDBC_AVAILABLE, reason="ddbc_bindings not available") +class TestThreadSafety: + """Test thread safety of pybind functions.""" + + def test_concurrent_driver_path_access(self): + """Test concurrent access to GetDriverPathCpp.""" + results = [] + exceptions = [] + + def get_driver_path(): + try: + path = ddbc.GetDriverPathCpp() + results.append(path) + except Exception as e: + exceptions.append(e) + + threads = [] + for _ in range(5): + thread = threading.Thread(target=get_driver_path) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # Either all should succeed with same result, or all should fail consistently + if results: + # All successful results should be the same + assert all(r == results[0] for r in results) + + # Should not have mixed success/failure without consistent error types + if exceptions and results: + # This would indicate a thread safety issue + pytest.fail("Mixed success/failure in concurrent access suggests thread safety issue") + + +@pytest.mark.skipif(not DDBC_AVAILABLE, reason="ddbc_bindings not available") +class TestMemoryManagement: + """Test memory management in pybind functions.""" + + def test_multiple_param_info_creation(self): + """Test creating multiple ParamInfo objects.""" + params = [] + for i in range(100): + param = ddbc.ParamInfo() + param.inputOutputType = i + param.dataPtr = f"data_{i}" + params.append(param) + + # Verify all objects maintain their data correctly + for i, param in enumerate(params): + assert param.inputOutputType == i + assert param.dataPtr == f"data_{i}" + + def test_multiple_numeric_data_creation(self): + """Test creating multiple NumericData objects.""" + numerics = [] + for i in range(50): + numeric = ddbc.NumericData(10 + i, 2, 1, f"test_{i}".encode('latin-1').decode('latin-1')) + numerics.append(numeric) + + # Verify all objects maintain their data correctly + for i, numeric in enumerate(numerics): + assert numeric.precision == 10 + i + assert numeric.scale == 2 + assert numeric.sign == 1 + + +@pytest.mark.skipif(not DDBC_AVAILABLE, reason="ddbc_bindings not available") +class TestEdgeCases: + """Test edge cases and boundary conditions.""" + + def test_numeric_data_max_length(self): + """Test NumericData with maximum length value.""" + # SQL_MAX_NUMERIC_LEN is 16 + max_data = b'\\x00' * 16 + try: + numeric = ddbc.NumericData(38, 0, 1, max_data.decode('latin-1')) + assert len(numeric.val) == 16 + except Exception as e: + # Should either work or give a clear error about length + assert "length" in str(e).lower() or "size" in str(e).lower() + + def test_numeric_data_oversized_value(self): + """Test NumericData with oversized value.""" + oversized_data = b'\\x00' * 20 # Larger than SQL_MAX_NUMERIC_LEN + with pytest.raises((RuntimeError, ValueError)): + ddbc.NumericData(38, 0, 1, oversized_data.decode('latin-1')) + + def test_param_info_extreme_values(self): + """Test ParamInfo with extreme values.""" + param = ddbc.ParamInfo() + + # Test with very large values + param.columnSize = 2**31 - 1 # Max SQLULEN + param.strLenOrInd = -(2**31) # Min SQLLEN + + assert param.columnSize == 2**31 - 1 + assert param.strLenOrInd == -(2**31) + + +@pytest.mark.skipif(not DDBC_AVAILABLE, reason="ddbc_bindings not available") +class TestAdditionalPybindFunctions: + """Test additional pybind functions to increase coverage.""" + + def test_all_exposed_functions_exist(self): + """Test that all expected C++ functions are exposed.""" + expected_functions = [ + 'GetDriverPathCpp', 'ThrowStdException', 'enable_pooling', 'close_pooling', + 'DDBCSetDecimalSeparator', 'DDBCSQLExecDirect', 'DDBCSQLExecute', + 'DDBCSQLRowCount', 'DDBCSQLFetch', 'DDBCSQLNumResultCols', 'DDBCSQLDescribeCol', + 'DDBCSQLGetData', 'DDBCSQLMoreResults', 'DDBCSQLFetchOne', 'DDBCSQLFetchMany', + 'DDBCSQLFetchAll', 'DDBCSQLFreeHandle', 'DDBCSQLCheckError', 'DDBCSQLTables', + 'DDBCSQLFetchScroll', 'DDBCSQLSetStmtAttr', 'DDBCSQLGetTypeInfo' + ] + + for func_name in expected_functions: + assert hasattr(ddbc, func_name), f"Function {func_name} not found in ddbc_bindings" + func = getattr(ddbc, func_name) + assert callable(func), f"{func_name} is not callable" + + def test_all_exposed_classes_exist(self): + """Test that all expected C++ classes are exposed.""" + expected_classes = ['ParamInfo', 'NumericData', 'ErrorInfo', 'SqlHandle', 'Connection'] + + for class_name in expected_classes: + assert hasattr(ddbc, class_name), f"Class {class_name} not found in ddbc_bindings" + cls = getattr(ddbc, class_name) + # Check that it's a class/type + assert hasattr(cls, '__name__') or str(type(cls)).find('class') != -1 + + def test_numeric_data_with_various_inputs(self): + """Test NumericData with various input combinations.""" + # Test different precision and scale combinations + test_cases = [ + (10, 0, 1, b'\\x12\\x34'), + (18, 2, 0, b'\\x00\\x01'), + (38, 10, 1, b'\\xFF\\xEE\\xDD'), + ] + + for precision, scale, sign, data in test_cases: + try: + numeric = ddbc.NumericData(precision, scale, sign, data.decode('latin-1')) + assert numeric.precision == precision + assert numeric.scale == scale + assert numeric.sign == sign + assert len(numeric.val) == 16 # SQL_MAX_NUMERIC_LEN + except Exception as e: + # Some combinations might not be valid, which is acceptable + assert "length" in str(e).lower() or "size" in str(e).lower() or "runtime" in str(e).lower() + + def test_connection_pooling_workflow(self): + """Test the complete connection pooling workflow.""" + try: + # Test enabling pooling multiple times (should be safe) + ddbc.enable_pooling() + ddbc.enable_pooling() + + # Test closing pools + ddbc.close_pooling() + ddbc.close_pooling() # Should be safe to call multiple times + + except Exception as e: + # Pooling might not be supported in all environments + error_msg = str(e).lower() + assert any(keyword in error_msg for keyword in [ + "not supported", "not available", "pooling" + ]) + + def test_decimal_separator_variations(self): + """Test decimal separator with different inputs.""" + separators_to_test = [".", ",", ";"] + + for sep in separators_to_test: + try: + ddbc.DDBCSetDecimalSeparator(sep) + # If successful, test that we can set it back + ddbc.DDBCSetDecimalSeparator(".") + except Exception as e: + # Some separators might not be supported + error_msg = str(e).lower() + assert any(keyword in error_msg for keyword in [ + "invalid", "not supported", "separator" + ]) + + def test_driver_path_with_different_drivers(self): + """Test GetDriverPathCpp with different driver names.""" + driver_names = [ + "ODBC Driver 18 for SQL Server", + "ODBC Driver 17 for SQL Server", + "SQL Server", + "NonExistentDriver" + ] + + for driver_name in driver_names: + try: + path = ddbc.GetDriverPathCpp(driver_name) + if path: # If a path is returned + assert isinstance(path, str) + assert len(path) > 0 + except Exception as e: + # Driver not found is acceptable + error_msg = str(e).lower() + assert any(keyword in error_msg for keyword in [ + "not found", "cannot find", "not available", "driver" + ]) + + def test_function_signature_validation(self): + """Test that functions properly validate their input parameters.""" + + # Test ThrowStdException with different message types + test_messages = ["Test message", "", "Unicode: こんにちは"] + for msg in test_messages: + with pytest.raises(RuntimeError): + ddbc.ThrowStdException(msg) + + # Test parameter validation for other functions + with pytest.raises(TypeError): + ddbc.DDBCSetDecimalSeparator(123) # Should be string + + with pytest.raises(TypeError): + ddbc.GetDriverPathCpp(None) # Should be string + + +@pytest.mark.skipif(not DDBC_AVAILABLE, reason="ddbc_bindings not available") +class TestPybindErrorScenarios: + """Test error scenarios and edge cases in pybind functions.""" + + def test_invalid_parameter_types(self): + """Test functions with invalid parameter types.""" + + # Test various functions with wrong parameter types + test_cases = [ + (ddbc.GetDriverPathCpp, [None, 123, []]), + (ddbc.ThrowStdException, [None, 123, []]), + (ddbc.DDBCSetDecimalSeparator, [None, 123, []]), + ] + + for func, invalid_params in test_cases: + for param in invalid_params: + with pytest.raises(TypeError): + func(param) + + def test_boundary_conditions(self): + """Test functions with boundary condition inputs.""" + + # Test with very long strings + long_string = "A" * 10000 + try: + ddbc.ThrowStdException(long_string) + assert False, "Should have raised RuntimeError" + except RuntimeError: + pass # Expected + except Exception as e: + # Might fail with different error for very long strings + assert "length" in str(e).lower() or "size" in str(e).lower() + + # Test with empty string + with pytest.raises(RuntimeError): + ddbc.ThrowStdException("") + + def test_unicode_handling(self): + """Test Unicode string handling in pybind functions.""" + + unicode_strings = [ + "Hello, 世界", # Chinese + "Привет, мир", # Russian + "مرحبا بالعالم", # Arabic + "🌍🌎🌏", # Emojis + ] + + for unicode_str in unicode_strings: + try: + with pytest.raises(RuntimeError): + ddbc.ThrowStdException(unicode_str) + except UnicodeError: + # Some Unicode might not be handled properly, which is acceptable + pass + + try: + ddbc.GetDriverPathCpp(unicode_str) + # Might succeed or fail depending on system + except Exception: + # Unicode driver names likely don't exist + pass + + +if __name__ == "__main__": + # Run tests when executed directly + pytest.main([__file__, "-v"]) \ No newline at end of file