fix: pass scale to DF round in spark_round #1341
@@ -85,9 +85,10 @@ pub fn spark_round(
                 let (precision, scale) = get_precision_scale(data_type);
                 make_decimal_array(array, precision, scale, &f)
             }
-            DataType::Float32 | DataType::Float64 => {
-                Ok(ColumnarValue::Array(round(&[Arc::clone(array)])?))
-            }
+            DataType::Float32 | DataType::Float64 => Ok(ColumnarValue::Array(round(&[
+                Arc::clone(array),
+                args[1].to_array(array.len())?,
+            ])?)),
             dt => exec_err!("Not supported datatype for ROUND: {dt}"),
         },
         ColumnarValue::Scalar(a) => match a {
@@ -109,7 +110,7 @@ pub fn spark_round(
                 make_decimal_scalar(a, precision, scale, &f)
             }
             ScalarValue::Float32(_) | ScalarValue::Float64(_) => Ok(ColumnarValue::Scalar(
-                ScalarValue::try_from_array(&round(&[a.to_array()?])?, 0)?,
+                ScalarValue::try_from_array(&round(&[a.to_array()?, args[1].to_array(1)?])?, 0)?,
             )),
             dt => exec_err!("Not supported datatype for ROUND: {dt}"),
         },
@@ -135,3 +136,50 @@ fn decimal_round_f(scale: &i8, point: &i64) -> Box<dyn Fn(i128) -> i128> {
         Box::new(move |x: i128| (x + x.signum() * half) / div)
     }
 }
+
+#[cfg(test)]
+mod test {
+    use std::sync::Arc;
+
+    use crate::spark_round;
+
+    use arrow::array::{Float32Array, Float64Array};
+    use arrow_schema::DataType;
+    use datafusion_common::cast::{as_float32_array, as_float64_array};
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::ColumnarValue;
+
+    #[test]
+    fn test_round_f32() -> Result<()> {
+        let args = vec![
+            ColumnarValue::Array(Arc::new(Float32Array::from(vec![
+                125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123,
+            ]))),
+            ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
+        ];
+        let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float32)? else {
+            unreachable!()
+        };
+        let floats = as_float32_array(&result)?;
+        let expected = Float32Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
+        assert_eq!(floats, &expected);
+        Ok(())
+    }
+
+    #[test]
+    fn test_round_f64() -> Result<()> {
+        let args = vec![
+            ColumnarValue::Array(Arc::new(Float64Array::from(vec![
+                125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123,
+            ]))),
+            ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
+        ];
+        let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float64)? else {
+            unreachable!()
+        };
+        let floats = as_float64_array(&result)?;
+        let expected = Float64Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
+        assert_eq!(floats, &expected);
+        Ok(())
+    }
+}

Review comments on `let args = vec![` in `test_round_f32`:

Ideally we should have scalar test cases as well as random value tests, and ideally we need expected values from Spark results.

Can we get it merged, or do you want more tests? Also, how would we use Spark results in a test?

Thanks @cht42, merged.
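
To illustrate why forwarding the scale matters, here is a minimal sketch over plain slices. It is not DataFusion's or Comet's code: `round_array` and `round_scalar` are hypothetical helpers, and the multiply, round, divide formula is only an assumption about how a float rounding kernel typically behaves. The scalar helper mirrors the shape of the scalar branch in the diff: wrap the value in a one-element array, call the array kernel, and read back element 0.

```rust
/// Rounds each value to the number of decimal places given at the same index.
/// Illustrative only: a plain-slice stand-in for an Arrow array kernel.
fn round_array(values: &[f64], scales: &[i32]) -> Vec<f64> {
    values
        .iter()
        .zip(scales)
        .map(|(v, s)| {
            // Shift the decimal point, round to the nearest integer, shift back.
            let factor = 10f64.powi(*s);
            (v * factor).round() / factor
        })
        .collect()
}

/// Scalar path: wrap the value in a one-element array, reuse the array
/// kernel, then read back element 0.
fn round_scalar(value: f64, scale: i32) -> f64 {
    round_array(&[value], &[scale])[0]
}
```

In this sketch, `round_scalar(125.2345, 2)` gives `125.23`, while a scale of 0 gives `125.0`, which is presumably the kind of result produced when no scale reaches the rounding kernel.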
FYI, round for float is disabled (https://github.com/apache/datafusion-comet/blob/main/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala#L1739). BTW, shouldn't we be able to use `point` instead of `args[1]`?
`point` is set to the scalar value on this line, and we want to pass the columnar value to DF.
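
The distinction in this reply, an extracted scalar versus a columnar argument that can be materialized as an array, can be sketched as follows. The `Columnar` enum and its `to_array` method are hypothetical, simplified stand-ins written for this illustration; they only imitate the idea behind calls such as `args[1].to_array(array.len())` in the diff and are not the DataFusion types.

```rust
/// Hypothetical, simplified stand-in for a columnar value holding i64s.
#[derive(Debug, Clone, PartialEq)]
enum Columnar {
    Scalar(i64),
    Array(Vec<i64>),
}

impl Columnar {
    /// Materializes the value as an array of `num_rows` elements:
    /// a scalar is repeated for every row, an array is returned as-is.
    fn to_array(&self, num_rows: usize) -> Vec<i64> {
        match self {
            Columnar::Scalar(v) => vec![*v; num_rows],
            Columnar::Array(a) => a.clone(),
        }
    }
}
```

Under these assumptions, a scale of `Columnar::Scalar(2)` paired with a three-row value column becomes `[2, 2, 2]`, so an array-based kernel sees one scale per row.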
In my use case, I'm not using Comet's query planning, so it should be fine.
Since this isn't being tested end-to-end, could you add a Rust unit test so that we protect against regressions in the future?
Still, for the reason mentioned in the link, it may not work in some cases.
Added tests for floats.
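
The review also raised random value tests. As a rough sketch of what such a check could look like without external crates, the code below draws pseudo-random inputs and counts how many violate the bound |round(x) - x| <= 0.5 * 10^-scale. Everything here is an assumption for illustration: `round_to_scale` is a hypothetical stand-in for the float rounding kernel (not `spark_round` itself), `Lcg` is a hand-rolled generator, and the check does not compare against Spark results.

```rust
/// Illustrative stand-in for a float rounding kernel (not DataFusion's code).
fn round_to_scale(x: f64, scale: i32) -> f64 {
    let factor = 10f64.powi(scale);
    (x * factor).round() / factor
}

/// Tiny linear congruential generator so the sketch needs no external crates.
struct Lcg(u64);

impl Lcg {
    /// Returns a pseudo-random f64 in [0, 1).
    fn next_f64(&mut self) -> f64 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Keep the top 53 bits so the value fits an f64 mantissa exactly.
        (self.0 >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Checks `samples` random inputs in [-1000, 1000) and returns how many
/// violate |round(x) - x| <= 0.5 * 10^-scale (plus a small tolerance).
fn count_violations(seed: u64, samples: usize, scale: i32) -> usize {
    let mut rng = Lcg(seed);
    let half_unit = 0.5 / 10f64.powi(scale);
    (0..samples)
        .filter(|_| {
            let x = rng.next_f64() * 2000.0 - 1000.0;
            (round_to_scale(x, scale) - x).abs() > half_unit + 1e-9
        })
        .count()
}
```

A property check like this guards against the scale being ignored, but it cannot confirm agreement with Spark; that would still need expected values taken from Spark output.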