Skip to content

Commit d5f4e13

Browse files
committed
fixes precision loss in numereic datatype
1 parent 3e54515 commit d5f4e13

File tree

2 files changed

+232
-86
lines changed

2 files changed

+232
-86
lines changed

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,7 @@ struct NumericData {
6262
NumericData() : precision(0), scale(0), sign(0), val(SQL_MAX_NUMERIC_LEN, '\0') {}
6363

6464
NumericData(SQLCHAR precision, SQLSCHAR scale, SQLCHAR sign, const std::string& valueBytes)
65-
: precision(precision), scale(scale), sign(sign) {
66-
val = valueBytes;
67-
// Ensure val is always exactly SQL_MAX_NUMERIC_LEN bytes
68-
val.resize(SQL_MAX_NUMERIC_LEN, '\0');
69-
}
65+
: precision(precision), scale(scale), sign(sign), val(valueBytes) {}
7066
};
7167

7268
// Struct to hold the DateTimeOffset structure
@@ -562,21 +558,10 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
562558
decimalPtr->sign = decimalParam.sign;
563559
// Convert the integer decimalParam.val to char array
564560
std::memset(static_cast<void*>(decimalPtr->val), 0, sizeof(decimalPtr->val));
565-
// std::memcpy(static_cast<void*>(decimalPtr->val),
566-
// reinterpret_cast<char*>(&decimalParam.val),
567-
// sizeof(decimalParam.val));
568-
size_t src_len = decimalParam.val.size();
569-
if (src_len > sizeof(decimalPtr->val)) {
570-
// Defensive: should never happen if Python side ensures 16 bytes; but guard anyway
571-
ThrowStdException("Numeric value byte buffer too large for SQL_NUMERIC_STRUCT (paramIndex " + std::to_string(paramIndex) + ")");
572-
}
573-
if (src_len > 0) {
574-
std::memcpy(static_cast<void*>(decimalPtr->val),
575-
static_cast<const void*>(decimalParam.val.data()),
576-
src_len);
577-
}
578-
//print the data received from python
579-
LOG("Numeric parameter val bytes: {}", decimalPtr->val);
561+
size_t copyLen = std::min(decimalParam.val.size(), sizeof(decimalPtr->val));
562+
if (copyLen > 0) {
563+
std::memcpy(decimalPtr->val, decimalParam.val.data(), copyLen);
564+
}
580565
dataPtr = static_cast<void*>(decimalPtr);
581566
break;
582567
}
@@ -2069,13 +2054,15 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt,
20692054
NumericData decimalParam = element.cast<NumericData>();
20702055
LOG("Received numeric parameter at [%zu]: precision=%d, scale=%d, sign=%d, val=%lld",
20712056
i, decimalParam.precision, decimalParam.scale, decimalParam.sign, decimalParam.val);
2072-
numericArray[i].precision = decimalParam.precision;
2073-
numericArray[i].scale = decimalParam.scale;
2074-
numericArray[i].sign = decimalParam.sign;
2075-
std::memset(numericArray[i].val, 0, sizeof(numericArray[i].val));
2076-
std::memcpy(numericArray[i].val,
2077-
reinterpret_cast<const char*>(&decimalParam.val),
2078-
std::min(sizeof(decimalParam.val), sizeof(numericArray[i].val)));
2057+
SQL_NUMERIC_STRUCT& target = numericArray[i];
2058+
std::memset(&target, 0, sizeof(SQL_NUMERIC_STRUCT));
2059+
target.precision = decimalParam.precision;
2060+
target.scale = decimalParam.scale;
2061+
target.sign = decimalParam.sign;
2062+
size_t copyLen = std::min(decimalParam.val.size(), sizeof(target.val));
2063+
if (copyLen > 0) {
2064+
std::memcpy(target.val, decimalParam.val.data(), copyLen);
2065+
}
20792066
strLenOrIndArray[i] = sizeof(SQL_NUMERIC_STRUCT);
20802067
}
20812068
dataPtr = numericArray;

tests/test_004_cursor.py

Lines changed: 218 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,22 +1640,6 @@ def test_parse_datetime2(cursor, db_connection):
16401640
cursor.execute("DROP TABLE #pytest_datetime2_test")
16411641
db_connection.commit()
16421642

1643-
def test_get_numeric_data(cursor, db_connection):
1644-
"""Test _get_numeric_data"""
1645-
try:
1646-
cursor.execute("CREATE TABLE #pytest_numeric_test (numeric_column DECIMAL(10, 2))")
1647-
db_connection.commit()
1648-
cursor.execute("INSERT INTO #pytest_numeric_test (numeric_column) VALUES (?)", [decimal.Decimal('123.45')])
1649-
db_connection.commit()
1650-
cursor.execute("SELECT numeric_column FROM #pytest_numeric_test")
1651-
row = cursor.fetchone()
1652-
assert row[0] == decimal.Decimal('123.45'), "Numeric data parsing failed"
1653-
except Exception as e:
1654-
pytest.fail(f"Numeric data parsing test failed: {e}")
1655-
finally:
1656-
cursor.execute("DROP TABLE #pytest_numeric_test")
1657-
db_connection.commit()
1658-
16591643
def test_none(cursor, db_connection):
16601644
"""Test None"""
16611645
try:
@@ -1721,48 +1705,6 @@ def test_sql_varchar(cursor, db_connection):
17211705
cursor.execute("DROP TABLE #pytest_varchar_test")
17221706
db_connection.commit()
17231707

1724-
def test_numeric_precision_scale_positive_exponent(cursor, db_connection):
1725-
"""Test precision and scale for numeric values with positive exponent"""
1726-
try:
1727-
cursor.execute("CREATE TABLE #pytest_numeric_test (numeric_column DECIMAL(10, 2))")
1728-
db_connection.commit()
1729-
cursor.execute("INSERT INTO #pytest_numeric_test (numeric_column) VALUES (?)", [decimal.Decimal('31400')])
1730-
db_connection.commit()
1731-
cursor.execute("SELECT numeric_column FROM #pytest_numeric_test")
1732-
row = cursor.fetchone()
1733-
assert row[0] == decimal.Decimal('31400'), "Numeric data parsing failed"
1734-
# Check precision and scale
1735-
precision = 5 # 31400 has 5 significant digits
1736-
scale = 0 # No digits after the decimal point
1737-
assert precision == 5, "Precision calculation failed"
1738-
assert scale == 0, "Scale calculation failed"
1739-
except Exception as e:
1740-
pytest.fail(f"Numeric precision and scale test failed: {e}")
1741-
finally:
1742-
cursor.execute("DROP TABLE #pytest_numeric_test")
1743-
db_connection.commit()
1744-
1745-
def test_numeric_precision_scale_negative_exponent(cursor, db_connection):
1746-
"""Test precision and scale for numeric values with negative exponent"""
1747-
try:
1748-
cursor.execute("CREATE TABLE #pytest_numeric_test (numeric_column DECIMAL(10, 5))")
1749-
db_connection.commit()
1750-
cursor.execute("INSERT INTO #pytest_numeric_test (numeric_column) VALUES (?)", [decimal.Decimal('0.03140')])
1751-
db_connection.commit()
1752-
cursor.execute("SELECT numeric_column FROM #pytest_numeric_test")
1753-
row = cursor.fetchone()
1754-
assert row[0] == decimal.Decimal('0.03140'), "Numeric data parsing failed"
1755-
# Check precision and scale
1756-
precision = 5 # 0.03140 has 5 significant digits
1757-
scale = 5 # 5 digits after the decimal point
1758-
assert precision == 5, "Precision calculation failed"
1759-
assert scale == 5, "Scale calculation failed"
1760-
except Exception as e:
1761-
pytest.fail(f"Numeric precision and scale test failed: {e}")
1762-
finally:
1763-
cursor.execute("DROP TABLE #pytest_numeric_test")
1764-
db_connection.commit()
1765-
17661708
def test_row_attribute_access(cursor, db_connection):
17671709
"""Test accessing row values by column name as attributes"""
17681710
try:
@@ -11402,7 +11344,224 @@ def test_datetime_string_parameter_binding(cursor, db_connection):
1140211344
finally:
1140311345
drop_table_if_exists(cursor, table_name)
1140411346
db_connection.commit()
11405-
11347+
11348+
# ---------------------------------------------------------
11349+
# Test 1: Basic numeric insertion and fetch roundtrip
11350+
# ---------------------------------------------------------
11351+
@pytest.mark.parametrize("precision, scale, value", [
11352+
(10, 2, decimal.Decimal("12345.67")),
11353+
(10, 4, decimal.Decimal("12.3456")),
11354+
(10, 0, decimal.Decimal("1234567890")),
11355+
])
11356+
def test_numeric_basic_roundtrip(cursor, db_connection, precision, scale, value):
11357+
"""Verify simple numeric values roundtrip correctly"""
11358+
table_name = f"#pytest_numeric_basic_{precision}_{scale}"
11359+
try:
11360+
cursor.execute(f"CREATE TABLE {table_name} (val NUMERIC({precision}, {scale}))")
11361+
cursor.execute(f"INSERT INTO {table_name} (val) VALUES (?)", (value,))
11362+
db_connection.commit()
11363+
11364+
cursor.execute(f"SELECT val FROM {table_name}")
11365+
row = cursor.fetchone()
11366+
assert row is not None, "Expected one row to be returned"
11367+
fetched = row[0]
11368+
11369+
expected = value.quantize(decimal.Decimal(f"1e-{scale}")) if scale > 0 else value
11370+
assert fetched == expected, f"Expected {expected}, got {fetched}"
11371+
11372+
finally:
11373+
cursor.execute(f"DROP TABLE {table_name}")
11374+
db_connection.commit()
11375+
11376+
# ---------------------------------------------------------
11377+
# Test 2: High precision numeric values (near SQL Server max)
11378+
# ---------------------------------------------------------
11379+
@pytest.mark.parametrize("value", [
11380+
decimal.Decimal("99999999999999999999999999999999999999"), # 38 digits
11381+
decimal.Decimal("12345678901234567890.1234567890"), # high precision
11382+
])
11383+
def test_numeric_high_precision_roundtrip(cursor, db_connection, value):
11384+
"""Verify high-precision NUMERIC values roundtrip without precision loss"""
11385+
precision, scale = 38, max(0, -value.as_tuple().exponent)
11386+
table_name = "#pytest_numeric_high_precision"
11387+
try:
11388+
cursor.execute(f"CREATE TABLE {table_name} (val NUMERIC({precision}, {scale}))")
11389+
cursor.execute(f"INSERT INTO {table_name} (val) VALUES (?)", (value,))
11390+
db_connection.commit()
11391+
11392+
cursor.execute(f"SELECT val FROM {table_name}")
11393+
row = cursor.fetchone()
11394+
assert row is not None
11395+
assert row[0] == value, f"High-precision roundtrip failed. Expected {value}, got {row[0]}"
11396+
11397+
finally:
11398+
cursor.execute(f"DROP TABLE {table_name}")
11399+
db_connection.commit()
11400+
11401+
# ---------------------------------------------------------
11402+
# Test 3: Negative, zero, and small fractional values
11403+
# ---------------------------------------------------------
11404+
@pytest.mark.parametrize("value", [
11405+
decimal.Decimal("-98765.43210"),
11406+
decimal.Decimal("-99999999999999999999.9999999999"),
11407+
decimal.Decimal("0"),
11408+
decimal.Decimal("0.00001"),
11409+
])
11410+
def test_numeric_negative_and_small_values(cursor, db_connection, value):
11411+
precision, scale = 38, max(0, -value.as_tuple().exponent)
11412+
table_name = "#pytest_numeric_neg_small"
11413+
try:
11414+
cursor.execute(f"CREATE TABLE {table_name} (val NUMERIC({precision}, {scale}))")
11415+
cursor.execute(f"INSERT INTO {table_name} (val) VALUES (?)", (value,))
11416+
db_connection.commit()
11417+
11418+
cursor.execute(f"SELECT val FROM {table_name}")
11419+
row = cursor.fetchone()
11420+
assert row[0] == value, f"Expected {value}, got {row[0]}"
11421+
11422+
finally:
11423+
cursor.execute(f"DROP TABLE {table_name}")
11424+
db_connection.commit()
11425+
11426+
# ---------------------------------------------------------
11427+
# Test 4: NULL handling and multiple inserts
11428+
# ---------------------------------------------------------
11429+
def test_numeric_null_and_multiple_rows(cursor, db_connection):
11430+
table_name = "#pytest_numeric_nulls"
11431+
try:
11432+
cursor.execute(f"CREATE TABLE {table_name} (val NUMERIC(20,5))")
11433+
11434+
values = [decimal.Decimal("123.45678"), None, decimal.Decimal("-999.99999")]
11435+
for v in values:
11436+
cursor.execute(f"INSERT INTO {table_name} (val) VALUES (?)", (v,))
11437+
db_connection.commit()
11438+
11439+
cursor.execute(f"SELECT val FROM {table_name} ORDER BY val ASC")
11440+
rows = [r[0] for r in cursor.fetchall()]
11441+
11442+
non_null_expected = sorted([v for v in values if v is not None])
11443+
non_null_actual = sorted([v for v in rows if v is not None])
11444+
11445+
assert non_null_actual == non_null_expected, f"Expected {non_null_expected}, got {non_null_actual}"
11446+
assert any(r is None for r in rows), "Expected one NULL value in result set"
11447+
11448+
finally:
11449+
cursor.execute(f"DROP TABLE {table_name}")
11450+
db_connection.commit()
11451+
11452+
# ---------------------------------------------------------
11453+
# Test 5: Boundary precision values (max precision / scale)
11454+
# ---------------------------------------------------------
11455+
def test_numeric_boundary_precision(cursor, db_connection):
11456+
table_name = "#pytest_numeric_boundary"
11457+
precision, scale = 38, 37
11458+
value = decimal.Decimal("0." + "9" * 37) # 0.999... up to 37 digits
11459+
try:
11460+
cursor.execute(f"CREATE TABLE {table_name} (val NUMERIC({precision},{scale}))")
11461+
cursor.execute(f"INSERT INTO {table_name} (val) VALUES (?)", (value,))
11462+
db_connection.commit()
11463+
11464+
cursor.execute(f"SELECT val FROM {table_name}")
11465+
row = cursor.fetchone()
11466+
assert row[0] == value, f"Boundary precision mismatch: expected {value}, got {row[0]}"
11467+
11468+
finally:
11469+
cursor.execute(f"DROP TABLE {table_name}")
11470+
db_connection.commit()
11471+
11472+
# ---------------------------------------------------------
11473+
# Test 6: Precision/scale positive exponent (corner case)
11474+
# ---------------------------------------------------------
11475+
def test_numeric_precision_scale_positive_exponent(cursor, db_connection):
11476+
try:
11477+
cursor.execute("CREATE TABLE #pytest_numeric_test (numeric_column DECIMAL(10, 2))")
11478+
db_connection.commit()
11479+
cursor.execute("INSERT INTO #pytest_numeric_test (numeric_column) VALUES (?)", [decimal.Decimal('31400')])
11480+
db_connection.commit()
11481+
cursor.execute("SELECT numeric_column FROM #pytest_numeric_test")
11482+
row = cursor.fetchone()
11483+
assert row[0] == decimal.Decimal('31400'), "Numeric data parsing failed"
11484+
11485+
precision = 5
11486+
scale = 0
11487+
assert precision == 5, "Precision calculation failed"
11488+
assert scale == 0, "Scale calculation failed"
11489+
11490+
finally:
11491+
cursor.execute("DROP TABLE #pytest_numeric_test")
11492+
db_connection.commit()
11493+
11494+
# ---------------------------------------------------------
11495+
# Test 7: Precision/scale negative exponent (corner case)
11496+
# ---------------------------------------------------------
11497+
def test_numeric_precision_scale_negative_exponent(cursor, db_connection):
11498+
try:
11499+
cursor.execute("CREATE TABLE #pytest_numeric_test (numeric_column DECIMAL(10, 5))")
11500+
db_connection.commit()
11501+
cursor.execute("INSERT INTO #pytest_numeric_test (numeric_column) VALUES (?)", [decimal.Decimal('0.03140')])
11502+
db_connection.commit()
11503+
cursor.execute("SELECT numeric_column FROM #pytest_numeric_test")
11504+
row = cursor.fetchone()
11505+
assert row[0] == decimal.Decimal('0.03140'), "Numeric data parsing failed"
11506+
11507+
precision = 5
11508+
scale = 5
11509+
assert precision == 5, "Precision calculation failed"
11510+
assert scale == 5, "Scale calculation failed"
11511+
11512+
finally:
11513+
cursor.execute("DROP TABLE #pytest_numeric_test")
11514+
db_connection.commit()
11515+
11516+
# ---------------------------------------------------------
11517+
# Test 8: fetchmany for numeric values
11518+
# ---------------------------------------------------------
11519+
@pytest.mark.parametrize("values", [[
11520+
decimal.Decimal("11.11"), decimal.Decimal("22.22"), decimal.Decimal("33.33")
11521+
]])
11522+
def test_numeric_fetchmany(cursor, db_connection, values):
11523+
table_name = "#pytest_numeric_fetchmany"
11524+
try:
11525+
cursor.execute(f"CREATE TABLE {table_name} (val NUMERIC(10,2))")
11526+
for v in values:
11527+
cursor.execute(f"INSERT INTO {table_name} (val) VALUES (?)", (v,))
11528+
db_connection.commit()
11529+
11530+
cursor.execute(f"SELECT val FROM {table_name} ORDER BY val")
11531+
rows1 = cursor.fetchmany(2)
11532+
rows2 = cursor.fetchmany(2)
11533+
all_rows = [r[0] for r in rows1 + rows2]
11534+
11535+
assert all_rows == sorted(values), f"fetchmany mismatch: expected {sorted(values)}, got {all_rows}"
11536+
11537+
finally:
11538+
cursor.execute(f"DROP TABLE {table_name}")
11539+
db_connection.commit()
11540+
11541+
# ---------------------------------------------------------
11542+
# Test 9: executemany for numeric values
11543+
# ---------------------------------------------------------
11544+
@pytest.mark.parametrize("values", [[
11545+
decimal.Decimal("111.1111"), decimal.Decimal("222.2222"), decimal.Decimal("333.3333"),
11546+
]])
11547+
def test_numeric_executemany(cursor, db_connection, values):
11548+
precision, scale = 38, 10
11549+
table_name = "#pytest_numeric_executemany"
11550+
try:
11551+
cursor.execute(f"CREATE TABLE {table_name} (val NUMERIC({precision},{scale}))")
11552+
11553+
params = [(v,) for v in values]
11554+
cursor.executemany(f"INSERT INTO {table_name} (val) VALUES (?)", params)
11555+
db_connection.commit()
11556+
11557+
cursor.execute(f"SELECT val FROM {table_name} ORDER BY val")
11558+
rows = [r[0] for r in cursor.fetchall()]
11559+
assert rows == sorted(values), f"executemany() mismatch: expected {sorted(values)}, got {rows}"
11560+
11561+
finally:
11562+
cursor.execute(f"DROP TABLE {table_name}")
11563+
db_connection.commit()
11564+
1140611565
def test_close(db_connection):
1140711566
"""Test closing the cursor"""
1140811567
try:

0 commit comments

Comments
 (0)