Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 68 additions & 4 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,49 @@ def __init__(self, connection, timeout: int = 0) -> None:

self.messages = [] # Store diagnostic messages

def _get_encoding_settings(self):
"""
Get the encoding settings from the connection.

Returns:
dict: A dictionary with 'encoding' and 'ctype' keys, or default settings if not available
"""
if hasattr(self._connection, 'getencoding'):
try:
return self._connection.getencoding()
except:
# Return default encoding settings if there's an error
return {
'encoding': 'utf-16le',
'ctype': ddbc_sql_const.SQL_WCHAR.value
}
# Return default encoding settings if getencoding is not available
return {
'encoding': 'utf-16le',
'ctype': ddbc_sql_const.SQL_WCHAR.value
}

def _get_decoding_settings(self, sql_type):
"""
Get decoding settings for a specific SQL type.

Args:
sql_type: SQL type constant (SQL_CHAR, SQL_WCHAR, etc.)

Returns:
Dictionary containing the decoding settings.
"""
try:
# Get decoding settings from connection for this SQL type
return self._connection.getdecoding(sql_type)
except Exception as e:
# If anything goes wrong, return default settings
log('warning', f"Failed to get decoding settings for SQL type {sql_type}: {e}")
if sql_type == SQL_WCHAR:
return {'encoding': 'utf-16le', 'ctype': SQL_WCHAR}
else:
return {'encoding': 'utf-8', 'ctype': SQL_CHAR}

def _is_unicode_string(self, param):
"""
Check if a string contains non-ASCII characters.
Expand Down Expand Up @@ -966,6 +1009,8 @@ def execute(
parameters_type[i].decimalDigits,
parameters_type[i].inputOutputType,
)

encoding_settings = self._get_encoding_settings()

ret = ddbc_bindings.DDBCSQLExecute(
self.hstmt,
Expand All @@ -974,6 +1019,8 @@ def execute(
parameters_type,
self.is_stmt_prepared,
use_prepare,
encoding_settings.get('encoding'),
encoding_settings.get('ctype')
)
# Check return code
try:
Expand Down Expand Up @@ -1666,12 +1713,16 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
len(seq_of_parameters), "\n".join(f" {i+1}: {tuple(p) if isinstance(p, (list, tuple)) else p}" for i, p in enumerate(seq_of_parameters[:5])) # Limit to first 5 rows for large batches
)

encoding_settings = self._get_encoding_settings()

ret = ddbc_bindings.SQLExecuteMany(
self.hstmt,
operation,
columnwise_params,
parameters_type,
row_count
row_count,
encoding_settings.get('encoding'),
encoding_settings.get('ctype')
)

# Capture any diagnostic messages after execution
Expand Down Expand Up @@ -1703,10 +1754,14 @@ def fetchone(self) -> Union[None, Row]:
"""
self._check_closed() # Check if the cursor is closed

# Get decoding settings for character data
char_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_CHAR.value)
wchar_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_WCHAR.value)

# Fetch raw data
row_data = []
try:
ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data)
ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data, char_decoding.get('encoding', 'utf-8'), wchar_decoding.get('encoding', 'utf-16le'))

if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
Expand Down Expand Up @@ -1753,11 +1808,16 @@ def fetchmany(self, size: int = None) -> List[Row]:

if size <= 0:
return []

# Get decoding settings for character data
char_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_CHAR.value)
wchar_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_WCHAR.value)

# Fetch raw data
rows_data = []
try:
ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size)
ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size, char_decoding.get('encoding', 'utf-8'), wchar_decoding.get('encoding', 'utf-16le'))


if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
Expand Down Expand Up @@ -1793,10 +1853,14 @@ def fetchall(self) -> List[Row]:
if not self._has_result_set and self.description:
self._reset_rownumber()

# Get decoding settings for character data
char_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_CHAR.value)
wchar_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_WCHAR.value)

# Fetch raw data
rows_data = []
try:
ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)
ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data, char_decoding.get('encoding', 'utf-8'), wchar_decoding.get('encoding', 'utf-16le'))

if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
Expand Down
Loading
Loading