Skip to content

Commit 454db7e

Browse files
authored
Introduce Scalar type for ColumnarValue (#12536)
* Introduce `Scalar` type for ColumnarValue
* Add constructor constraints for `Scalar`
* Add rustdoc for `Scalar`
* Add TODO note on `ColumnarValue::cast_to`
* Add more `Scalar` rustdoc
1 parent 23d7fff commit 454db7e

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

86 files changed

+1300
-1036
lines changed

datafusion-examples/examples/advanced_udf.rs

+31-31
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ impl ScalarUDFImpl for PowUdf {
9696
// function, but we check again to make sure
9797
assert_eq!(args.len(), 2);
9898
let (base, exp) = (&args[0], &args[1]);
99-
assert_eq!(base.data_type(), DataType::Float64);
100-
assert_eq!(exp.data_type(), DataType::Float64);
99+
assert_eq!(base.data_type(), &DataType::Float64);
100+
assert_eq!(exp.data_type(), &DataType::Float64);
101101

102102
match (base, exp) {
103103
// For demonstration purposes we also implement the scalar / scalar
@@ -108,28 +108,31 @@ impl ScalarUDFImpl for PowUdf {
108108
// the DataFusion expression simplification logic will often invoke
109109
// this path once during planning, and simply use the result during
110110
// execution.
111-
(
112-
ColumnarValue::Scalar(ScalarValue::Float64(base)),
113-
ColumnarValue::Scalar(ScalarValue::Float64(exp)),
114-
) => {
115-
// compute the output. Note DataFusion treats `None` as NULL.
116-
let res = match (base, exp) {
117-
(Some(base), Some(exp)) => Some(base.powf(*exp)),
118-
// one or both arguments were NULL
119-
_ => None,
120-
};
121-
Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
111+
(ColumnarValue::Scalar(base), ColumnarValue::Scalar(exp)) => {
112+
match (base.value(), exp.value()) {
113+
(ScalarValue::Float64(base), ScalarValue::Float64(exp)) => {
114+
// compute the output. Note DataFusion treats `None` as NULL.
115+
let res = match (base, exp) {
116+
(Some(base), Some(exp)) => Some(base.powf(*exp)),
117+
// one or both arguments were NULL
118+
_ => None,
119+
};
120+
Ok(ColumnarValue::from(ScalarValue::from(res)))
121+
}
122+
_ => {
123+
internal_err!("Invalid argument types to pow function")
124+
}
125+
}
122126
}
123127
// special case if the exponent is a constant
124-
(
125-
ColumnarValue::Array(base_array),
126-
ColumnarValue::Scalar(ScalarValue::Float64(exp)),
127-
) => {
128-
let result_array = match exp {
128+
(ColumnarValue::Array(base_array), ColumnarValue::Scalar(exp)) => {
129+
let result_array = match exp.value() {
129130
// a ^ null = null
130-
None => new_null_array(base_array.data_type(), base_array.len()),
131+
ScalarValue::Float64(None) => {
132+
new_null_array(base_array.data_type(), base_array.len())
133+
}
131134
// a ^ exp
132-
Some(exp) => {
135+
ScalarValue::Float64(Some(exp)) => {
133136
// DataFusion has ensured both arguments are Float64:
134137
let base_array = base_array.as_primitive::<Float64Type>();
135138
// calculate the result for every row. The `unary`
@@ -139,24 +142,25 @@ impl ScalarUDFImpl for PowUdf {
139142
compute::unary(base_array, |base| base.powf(*exp));
140143
Arc::new(res)
141144
}
145+
_ => return internal_err!("Invalid argument types to pow function"),
142146
};
143147
Ok(ColumnarValue::Array(result_array))
144148
}
145149

146150
// special case if the base is a constant (note this code is quite
147151
// similar to the previous case, so we omit comments)
148-
(
149-
ColumnarValue::Scalar(ScalarValue::Float64(base)),
150-
ColumnarValue::Array(exp_array),
151-
) => {
152-
let res = match base {
153-
None => new_null_array(exp_array.data_type(), exp_array.len()),
154-
Some(base) => {
152+
(ColumnarValue::Scalar(base), ColumnarValue::Array(exp_array)) => {
153+
let res = match base.value() {
154+
ScalarValue::Float64(None) => {
155+
new_null_array(exp_array.data_type(), exp_array.len())
156+
}
157+
ScalarValue::Float64(Some(base)) => {
155158
let exp_array = exp_array.as_primitive::<Float64Type>();
156159
let res: Float64Array =
157160
compute::unary(exp_array, |exp| base.powf(exp));
158161
Arc::new(res)
159162
}
163+
_ => return internal_err!("Invalid argument types to pow function"),
160164
};
161165
Ok(ColumnarValue::Array(res))
162166
}
@@ -169,10 +173,6 @@ impl ScalarUDFImpl for PowUdf {
169173
)?;
170174
Ok(ColumnarValue::Array(Arc::new(res)))
171175
}
172-
// if the types were not float, it is a bug in DataFusion
173-
_ => {
174-
internal_err!("Invalid argument types to pow function")
175-
}
176176
}
177177
}
178178

datafusion-examples/examples/optimizer_rule.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ impl ScalarUDFImpl for MyEq {
208208
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
209209
// this example simply returns "true" which is not what a real
210210
// implementation would do.
211-
Ok(ColumnarValue::Scalar(ScalarValue::from(true)))
211+
Ok(ColumnarValue::from(ScalarValue::from(true)))
212212
}
213213
}
214214

datafusion/common/src/scalar/mod.rs

+23-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ use crate::hash_utils::create_hashes;
4141
use crate::utils::{
4242
array_into_fixed_size_list_array, array_into_large_list_array, array_into_list_array,
4343
};
44-
use arrow::compute::kernels::numeric::*;
44+
use arrow::compute::kernels::{self, numeric::*};
4545
use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions};
4646
use arrow::{
4747
array::*,
@@ -1704,6 +1704,18 @@ impl ScalarValue {
17041704
Some(sv) => sv.data_type(),
17051705
};
17061706

1707+
Self::iter_to_array_of_type(scalars, &data_type)
1708+
}
1709+
1710+
/// Same as [`Self::iter_to_array`] but the target `data_type` can be
1711+
/// manually specified instead of being implicitly derived from the type of
1712+
/// the first value of the iterator.
1713+
pub fn iter_to_array_of_type(
1714+
scalars: impl IntoIterator<Item = ScalarValue>,
1715+
data_type: &DataType,
1716+
) -> Result<ArrayRef> {
1717+
let mut scalars = scalars.into_iter().peekable();
1718+
17071719
/// Creates an array of $ARRAY_TY by unpacking values of
17081720
/// SCALAR_TY for primitive types
17091721
macro_rules! build_array_primitive {
@@ -2179,6 +2191,16 @@ impl ScalarValue {
21792191
Arc::new(array_into_large_list_array(values))
21802192
}
21812193

2194+
pub fn to_array_of_size_and_type(
2195+
&self,
2196+
size: usize,
2197+
target_type: &DataType,
2198+
) -> Result<ArrayRef> {
2199+
let array = self.to_array_of_size(size)?;
2200+
let cast_array = kernels::cast::cast(&array, target_type)?;
2201+
Ok(cast_array)
2202+
}
2203+
21822204
/// Converts a scalar value into an array of `size` rows.
21832205
///
21842206
/// # Errors

datafusion/core/src/physical_optimizer/pruning.rs

+10-8
Original file line numberDiff line numberDiff line change
@@ -687,14 +687,16 @@ impl BoolVecBuilder {
687687
ColumnarValue::Array(array) => {
688688
self.combine_array(array.as_boolean());
689689
}
690-
ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))) => {
691-
// False means all containers can not pass the predicate
692-
self.inner = vec![false; self.inner.len()];
693-
}
694-
_ => {
695-
// Null or true means the rows in container may pass this
696-
// conjunct so we can't prune any containers based on that
697-
}
690+
ColumnarValue::Scalar(scalar) => match scalar.value() {
691+
ScalarValue::Boolean(Some(false)) => {
692+
// False means all containers can not pass the predicate
693+
self.inner = vec![false; self.inner.len()];
694+
}
695+
_ => {
696+
// Null or true means the rows in container may pass this
697+
// conjunct so we can't prune any containers based on that
698+
}
699+
},
698700
}
699701
}
700702

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

+13-9
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF {
212212
}
213213

214214
fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
215-
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
215+
Ok(ColumnarValue::from(ScalarValue::Int32(Some(100))))
216216
}
217217
}
218218

@@ -323,7 +323,7 @@ async fn scalar_udf_override_built_in_scalar_function() -> Result<()> {
323323
vec![DataType::Int32],
324324
DataType::Int32,
325325
Volatility::Immutable,
326-
Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(1))))),
326+
Arc::new(move |_| Ok(ColumnarValue::from(ScalarValue::Int32(Some(1))))),
327327
));
328328

329329
// Make sure that the UDF is used instead of the built-in function
@@ -669,7 +669,10 @@ impl ScalarUDFImpl for TakeUDF {
669669
// The actual implementation
670670
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
671671
let take_idx = match &args[2] {
672-
ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize,
672+
ColumnarValue::Scalar(scalar) => match scalar.value() {
673+
ScalarValue::Int64(Some(v)) if v < &2 => *v as usize,
674+
_ => unreachable!(),
675+
},
673676
_ => unreachable!(),
674677
};
675678
match &args[take_idx] {
@@ -1070,19 +1073,20 @@ impl ScalarUDFImpl for MyRegexUdf {
10701073

10711074
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
10721075
match args {
1073-
[ColumnarValue::Scalar(ScalarValue::Utf8(value))] => {
1074-
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(
1075-
self.matches(value.as_deref()),
1076-
)))
1077-
}
1076+
[ColumnarValue::Scalar(scalar)] => match scalar.value() {
1077+
ScalarValue::Utf8(value) => Ok(ColumnarValue::from(
1078+
ScalarValue::Boolean(self.matches(value.as_deref())),
1079+
)),
1080+
_ => exec_err!("regex_udf only accepts a Utf8 arguments"),
1081+
},
10781082
[ColumnarValue::Array(values)] => {
10791083
let mut builder = BooleanBuilder::with_capacity(values.len());
10801084
for value in values.as_string::<i32>() {
10811085
builder.append_option(self.matches(value))
10821086
}
10831087
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
10841088
}
1085-
_ => exec_err!("regex_udf only accepts a Utf8 arguments"),
1089+
_ => unreachable!(),
10861090
}
10871091
}
10881092

datafusion/expr-common/src/columnar_value.rs

+16-11
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ use datafusion_common::format::DEFAULT_CAST_OPTIONS;
2525
use datafusion_common::{internal_err, Result, ScalarValue};
2626
use std::sync::Arc;
2727

28+
use crate::scalar::Scalar;
29+
2830
/// The result of evaluating an expression.
2931
///
3032
/// [`ColumnarValue::Scalar`] represents a single value repeated any number of
@@ -89,7 +91,7 @@ pub enum ColumnarValue {
8991
/// Array of values
9092
Array(ArrayRef),
9193
/// A single value
92-
Scalar(ScalarValue),
94+
Scalar(Scalar),
9395
}
9496

9597
impl From<ArrayRef> for ColumnarValue {
@@ -100,14 +102,14 @@ impl From<ArrayRef> for ColumnarValue {
100102

101103
impl From<ScalarValue> for ColumnarValue {
102104
fn from(value: ScalarValue) -> Self {
103-
ColumnarValue::Scalar(value)
105+
ColumnarValue::Scalar(value.into())
104106
}
105107
}
106108

107109
impl ColumnarValue {
108-
pub fn data_type(&self) -> DataType {
110+
pub fn data_type(&self) -> &DataType {
109111
match self {
110-
ColumnarValue::Array(array_value) => array_value.data_type().clone(),
112+
ColumnarValue::Array(array_value) => array_value.data_type(),
111113
ColumnarValue::Scalar(scalar_value) => scalar_value.data_type(),
112114
}
113115
}
@@ -195,9 +197,12 @@ impl ColumnarValue {
195197
kernels::cast::cast_with_options(array, cast_type, &cast_options)?,
196198
)),
197199
ColumnarValue::Scalar(scalar) => {
200+
// TODO(@notfilippo, logical vs physical): if `scalar.data_type` is *logically equivalent*
201+
// to `cast_type` then skip the kernel cast and only change the `data_type` of the scalar.
202+
198203
let scalar_array =
199204
if cast_type == &DataType::Timestamp(TimeUnit::Nanosecond, None) {
200-
if let ScalarValue::Float64(Some(float_ts)) = scalar {
205+
if let ScalarValue::Float64(Some(float_ts)) = scalar.value() {
201206
ScalarValue::Int64(Some(
202207
(float_ts * 1_000_000_000_f64).trunc() as i64,
203208
))
@@ -213,7 +218,7 @@ impl ColumnarValue {
213218
cast_type,
214219
&cast_options,
215220
)?;
216-
let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?;
221+
let cast_scalar = Scalar::try_from_array(&cast_array, 0)?;
217222
Ok(ColumnarValue::Scalar(cast_scalar))
218223
}
219224
}
@@ -250,7 +255,7 @@ mod tests {
250255
TestCase {
251256
input: vec![
252257
ColumnarValue::Array(make_array(1, 3)),
253-
ColumnarValue::Scalar(ScalarValue::Int32(Some(100))),
258+
ColumnarValue::from(ScalarValue::Int32(Some(100))),
254259
],
255260
expected: vec![
256261
make_array(1, 3),
@@ -260,7 +265,7 @@ mod tests {
260265
// scalar and array
261266
TestCase {
262267
input: vec![
263-
ColumnarValue::Scalar(ScalarValue::Int32(Some(100))),
268+
ColumnarValue::from(ScalarValue::Int32(Some(100))),
264269
ColumnarValue::Array(make_array(1, 3)),
265270
],
266271
expected: vec![
@@ -271,9 +276,9 @@ mod tests {
271276
// multiple scalars and array
272277
TestCase {
273278
input: vec![
274-
ColumnarValue::Scalar(ScalarValue::Int32(Some(100))),
279+
ColumnarValue::from(ScalarValue::Int32(Some(100))),
275280
ColumnarValue::Array(make_array(1, 3)),
276-
ColumnarValue::Scalar(ScalarValue::Int32(Some(200))),
281+
ColumnarValue::from(ScalarValue::Int32(Some(200))),
277282
],
278283
expected: vec![
279284
make_array(100, 3), // scalar is expanded
@@ -306,7 +311,7 @@ mod tests {
306311
fn values_to_arrays_mixed_length_and_scalar() {
307312
ColumnarValue::values_to_arrays(&[
308313
ColumnarValue::Array(make_array(1, 3)),
309-
ColumnarValue::Scalar(ScalarValue::Int32(Some(100))),
314+
ColumnarValue::from(ScalarValue::Int32(Some(100))),
310315
ColumnarValue::Array(make_array(2, 7)),
311316
])
312317
.unwrap();

datafusion/expr-common/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub mod columnar_value;
3131
pub mod groups_accumulator;
3232
pub mod interval_arithmetic;
3333
pub mod operator;
34+
pub mod scalar;
3435
pub mod signature;
3536
pub mod sort_properties;
3637
pub mod type_coercion;

0 commit comments

Comments
 (0)