Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 45 additions & 11 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ def _get_numeric_data(self, param):
the numeric data.
"""
decimal_as_tuple = param.as_tuple()
num_digits = len(decimal_as_tuple.digits)
digits_tuple = decimal_as_tuple.digits
num_digits = len(digits_tuple)
exponent = decimal_as_tuple.exponent

# Calculate the SQL precision & scale
Expand All @@ -215,12 +216,11 @@ def _get_numeric_data(self, param):
precision = exponent * -1
scale = exponent * -1

# TODO: Revisit this check, do we want this restriction?
if precision > 15:
if precision > 38:
raise ValueError(
"Precision of the numeric value is too high - "
+ str(param)
+ ". Should be less than or equal to 15"
+ ". Should be less than or equal to 38"
)
Numeric_Data = ddbc_bindings.NumericData
numeric_data = Numeric_Data()
Expand All @@ -229,12 +229,26 @@ def _get_numeric_data(self, param):
numeric_data.sign = 1 if decimal_as_tuple.sign == 0 else 0
# strip decimal point from param & convert the significant digits to integer
# Ex: 12.34 ---> 1234
val = str(param)
if "." in val or "-" in val:
val = val.replace(".", "")
val = val.replace("-", "")
val = int(val)
numeric_data.val = val
int_str = ''.join(str(d) for d in digits_tuple)
if exponent > 0:
int_str = int_str + ('0' * exponent)
elif exponent < 0:
if -exponent > num_digits:
int_str = ('0' * (-exponent - num_digits)) + int_str

if int_str == '':
int_str = '0'

# Convert decimal base-10 string to python int, then to 16 little-endian bytes
big_int = int(int_str)
byte_array = bytearray(16) # SQL_MAX_NUMERIC_LEN
for i in range(16):
byte_array[i] = big_int & 0xFF
big_int >>= 8
if big_int == 0:
break

numeric_data.val = bytes(byte_array)
return numeric_data

def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None):
Expand Down Expand Up @@ -307,7 +321,27 @@ def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None):
)

if isinstance(param, decimal.Decimal):
# Detect MONEY / SMALLMONEY range
# First check precision limit for all decimal values
decimal_as_tuple = param.as_tuple()
digits_tuple = decimal_as_tuple.digits
num_digits = len(digits_tuple)
exponent = decimal_as_tuple.exponent

# Calculate the SQL precision (same logic as _get_numeric_data)
if exponent >= 0:
precision = num_digits + exponent
elif (-1 * exponent) <= num_digits:
precision = num_digits
else:
precision = exponent * -1

if precision > 38:
raise ValueError(
f"Precision of the numeric value is too high. "
f"The maximum precision supported by SQL Server is 38, but got {precision}."
)

# Detect MONEY / SMALLMONEY range
if SMALLMONEY_MIN <= param <= SMALLMONEY_MAX:
# smallmoney
parameters_list[i] = str(param)
Expand Down
44 changes: 27 additions & 17 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#define SQL_SS_TIMESTAMPOFFSET (-155)
#define SQL_C_SS_TIMESTAMPOFFSET (0x4001)
#define MAX_DIGITS_IN_NUMERIC 64
#define SQL_MAX_NUMERIC_LEN 16
#define SQL_SS_XML (-152)

#define STRINGIFY_FOR_CASE(x) \
Expand Down Expand Up @@ -57,12 +58,18 @@ struct NumericData {
SQLCHAR precision;
SQLSCHAR scale;
SQLCHAR sign; // 1=pos, 0=neg
std::uint64_t val; // 123.45 -> 12345
std::string val; // 123.45 -> 12345

NumericData() : precision(0), scale(0), sign(0), val(0) {}
NumericData() : precision(0), scale(0), sign(0), val(SQL_MAX_NUMERIC_LEN, '\0') {}

NumericData(SQLCHAR precision, SQLSCHAR scale, SQLCHAR sign, std::uint64_t value)
: precision(precision), scale(scale), sign(sign), val(value) {}
NumericData(SQLCHAR precision, SQLSCHAR scale, SQLCHAR sign, const std::string& valueBytes)
: precision(precision), scale(scale), sign(sign), val(SQL_MAX_NUMERIC_LEN, '\0') {
if (valueBytes.size() > SQL_MAX_NUMERIC_LEN) {
throw std::runtime_error("NumericData valueBytes size exceeds SQL_MAX_NUMERIC_LEN (16)");
}
// Copy binary data to buffer, remaining bytes stay zero-padded
std::memcpy(&val[0], valueBytes.data(), valueBytes.size());
}
};

// Struct to hold the DateTimeOffset structure
Expand Down Expand Up @@ -558,9 +565,10 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
decimalPtr->sign = decimalParam.sign;
// Convert the integer decimalParam.val to char array
std::memset(static_cast<void*>(decimalPtr->val), 0, sizeof(decimalPtr->val));
std::memcpy(static_cast<void*>(decimalPtr->val),
reinterpret_cast<char*>(&decimalParam.val),
sizeof(decimalParam.val));
size_t copyLen = std::min(decimalParam.val.size(), sizeof(decimalPtr->val));
if (copyLen > 0) {
std::memcpy(decimalPtr->val, decimalParam.val.data(), copyLen);
}
dataPtr = static_cast<void*>(decimalPtr);
break;
}
Expand Down Expand Up @@ -2051,15 +2059,17 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt,
throw std::runtime_error(MakeParamMismatchErrorStr(info.paramCType, paramIndex));
}
NumericData decimalParam = element.cast<NumericData>();
LOG("Received numeric parameter at [%zu]: precision=%d, scale=%d, sign=%d, val=%lld",
i, decimalParam.precision, decimalParam.scale, decimalParam.sign, decimalParam.val);
numericArray[i].precision = decimalParam.precision;
numericArray[i].scale = decimalParam.scale;
numericArray[i].sign = decimalParam.sign;
std::memset(numericArray[i].val, 0, sizeof(numericArray[i].val));
std::memcpy(numericArray[i].val,
reinterpret_cast<const char*>(&decimalParam.val),
std::min(sizeof(decimalParam.val), sizeof(numericArray[i].val)));
LOG("Received numeric parameter at [%zu]: precision=%d, scale=%d, sign=%d, val=%s",
i, decimalParam.precision, decimalParam.scale, decimalParam.sign, decimalParam.val.c_str());
SQL_NUMERIC_STRUCT& target = numericArray[i];
std::memset(&target, 0, sizeof(SQL_NUMERIC_STRUCT));
target.precision = decimalParam.precision;
target.scale = decimalParam.scale;
target.sign = decimalParam.sign;
size_t copyLen = std::min(decimalParam.val.size(), sizeof(target.val));
if (copyLen > 0) {
std::memcpy(target.val, decimalParam.val.data(), copyLen);
}
strLenOrIndArray[i] = sizeof(SQL_NUMERIC_STRUCT);
}
dataPtr = numericArray;
Expand Down Expand Up @@ -3800,7 +3810,7 @@ PYBIND11_MODULE(ddbc_bindings, m) {
// Define numeric data class
py::class_<NumericData>(m, "NumericData")
.def(py::init<>())
.def(py::init<SQLCHAR, SQLSCHAR, SQLCHAR, std::uint64_t>())
.def(py::init<SQLCHAR, SQLSCHAR, SQLCHAR, const std::string&>())
.def_readwrite("precision", &NumericData::precision)
.def_readwrite("scale", &NumericData::scale)
.def_readwrite("sign", &NumericData::sign)
Expand Down
Loading