Skip to content

Commit

Permalink
Make it easier to create a ScalarValue representing typed null (#14548
Browse files Browse the repository at this point in the history
…) (#14558)

* Make it easier to create a ScalarValue representing typed null (#14548)

* Fix issues causing GitHub checks to fail

* Fix issues causing GitHub checks to fail

* change TryFrom to call try_new_null to avoid the duplication

* Fix issues causing GitHub checks to fail

* Trigger GitHub Actions rerun

---------

Co-authored-by: Sergey Zhukov <[email protected]>
  • Loading branch information
cj-zhukov and Sergey Zhukov authored Feb 11, 2025
1 parent 036c8f2 commit 4f71e1c
Showing 1 changed file with 208 additions and 109 deletions.
317 changes: 208 additions & 109 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,129 @@ impl ScalarValue {
)
}

/// Create a null instance of [`ScalarValue`] for the given `data_type`.
///
/// The returned scalar carries the type information of `data_type`
/// (precision/scale, timezone, fixed sizes, nested fields, ...) but holds
/// no value.
///
/// # Errors
///
/// Returns a not-implemented error when `data_type` is not covered by any
/// of the arms below. For `DataType::Dictionary`, an error produced while
/// building the null scalar of the value type is propagated.
///
/// # Example
/// ```
/// use datafusion_common::ScalarValue;
/// use arrow::datatypes::DataType;
///
/// let scalar = ScalarValue::try_new_null(&DataType::Int32).unwrap();
/// assert!(scalar.is_null());
/// assert_eq!(scalar.data_type(), DataType::Int32);
/// ```
pub fn try_new_null(data_type: &DataType) -> Result<Self> {
    Ok(match data_type {
        DataType::Boolean => ScalarValue::Boolean(None),
        DataType::Float16 => ScalarValue::Float16(None),
        DataType::Float64 => ScalarValue::Float64(None),
        DataType::Float32 => ScalarValue::Float32(None),
        DataType::Int8 => ScalarValue::Int8(None),
        DataType::Int16 => ScalarValue::Int16(None),
        DataType::Int32 => ScalarValue::Int32(None),
        DataType::Int64 => ScalarValue::Int64(None),
        DataType::UInt8 => ScalarValue::UInt8(None),
        DataType::UInt16 => ScalarValue::UInt16(None),
        DataType::UInt32 => ScalarValue::UInt32(None),
        DataType::UInt64 => ScalarValue::UInt64(None),
        DataType::Decimal128(precision, scale) => {
            ScalarValue::Decimal128(None, *precision, *scale)
        }
        DataType::Decimal256(precision, scale) => {
            ScalarValue::Decimal256(None, *precision, *scale)
        }
        DataType::Utf8 => ScalarValue::Utf8(None),
        DataType::LargeUtf8 => ScalarValue::LargeUtf8(None),
        DataType::Utf8View => ScalarValue::Utf8View(None),
        DataType::Binary => ScalarValue::Binary(None),
        DataType::BinaryView => ScalarValue::BinaryView(None),
        DataType::FixedSizeBinary(len) => ScalarValue::FixedSizeBinary(*len, None),
        DataType::LargeBinary => ScalarValue::LargeBinary(None),
        DataType::Date32 => ScalarValue::Date32(None),
        DataType::Date64 => ScalarValue::Date64(None),
        DataType::Time32(TimeUnit::Second) => ScalarValue::Time32Second(None),
        DataType::Time32(TimeUnit::Millisecond) => {
            ScalarValue::Time32Millisecond(None)
        }
        DataType::Time64(TimeUnit::Microsecond) => {
            ScalarValue::Time64Microsecond(None)
        }
        DataType::Time64(TimeUnit::Nanosecond) => ScalarValue::Time64Nanosecond(None),
        DataType::Timestamp(TimeUnit::Second, tz_opt) => {
            ScalarValue::TimestampSecond(None, tz_opt.clone())
        }
        DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => {
            ScalarValue::TimestampMillisecond(None, tz_opt.clone())
        }
        DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
            ScalarValue::TimestampMicrosecond(None, tz_opt.clone())
        }
        DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
            ScalarValue::TimestampNanosecond(None, tz_opt.clone())
        }
        DataType::Interval(IntervalUnit::YearMonth) => {
            ScalarValue::IntervalYearMonth(None)
        }
        DataType::Interval(IntervalUnit::DayTime) => {
            ScalarValue::IntervalDayTime(None)
        }
        DataType::Interval(IntervalUnit::MonthDayNano) => {
            ScalarValue::IntervalMonthDayNano(None)
        }
        DataType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(None),
        DataType::Duration(TimeUnit::Millisecond) => {
            ScalarValue::DurationMillisecond(None)
        }
        DataType::Duration(TimeUnit::Microsecond) => {
            ScalarValue::DurationMicrosecond(None)
        }
        DataType::Duration(TimeUnit::Nanosecond) => {
            ScalarValue::DurationNanosecond(None)
        }
        // Recurse directly for the value type rather than going through
        // `TryInto`, so this function does not rely on the conversion impl.
        DataType::Dictionary(index_type, value_type) => ScalarValue::Dictionary(
            index_type.clone(),
            Box::new(Self::try_new_null(value_type)?),
        ),
        // `ScalarValue::List` contains single element `ListArray`.
        DataType::List(field_ref) => ScalarValue::List(Arc::new(
            GenericListArray::new_null(Arc::clone(field_ref), 1),
        )),
        // `ScalarValue::LargeList` contains single element `LargeListArray`.
        DataType::LargeList(field_ref) => ScalarValue::LargeList(Arc::new(
            GenericListArray::new_null(Arc::clone(field_ref), 1),
        )),
        // `ScalarValue::FixedSizeList` contains single element `FixedSizeList`.
        DataType::FixedSizeList(field_ref, fixed_length) => {
            ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::new_null(
                Arc::clone(field_ref),
                *fixed_length,
                1,
            )))
        }
        // `data_type` already is the struct type, so it is passed as-is
        // instead of being rebuilt from cloned fields.
        DataType::Struct(_) => ScalarValue::Struct(
            new_null_array(data_type, 1).as_struct().to_owned().into(),
        ),
        // Same as for `Struct`: reuse `data_type` for the one-row null array.
        DataType::Map(_, _) => ScalarValue::Map(
            new_null_array(data_type, 1).as_map().to_owned().into(),
        ),
        DataType::Union(fields, mode) => {
            ScalarValue::Union(None, fields.clone(), *mode)
        }
        DataType::Null => ScalarValue::Null,
        _ => {
            return _not_impl_err!(
                "Can't create a null scalar from data_type \"{data_type:?}\""
            );
        }
    })
}

/// Returns a [`ScalarValue::Utf8`] representing `val`
pub fn new_utf8(val: impl Into<String>) -> Self {
ScalarValue::from(val.into())
Expand Down Expand Up @@ -3457,115 +3580,7 @@ impl TryFrom<&DataType> for ScalarValue {

/// Create a Null instance of ScalarValue for this datatype
///
/// This conversion is a thin wrapper around [`ScalarValue::try_new_null`],
/// so the per-type mapping lives in one place; any error returned by that
/// function is returned unchanged.
fn try_from(data_type: &DataType) -> Result<Self> {
    Self::try_new_null(data_type)
}
}

Expand Down Expand Up @@ -7269,4 +7284,88 @@ mod tests {
let dictionary_array = dictionary_scalar.to_array().unwrap();
assert!(dictionary_array.is_null(0));
}

/// Checks that `ScalarValue::try_new_null` succeeds and yields a scalar
/// reporting `is_null()` for primitive, temporal and nested data types.
#[test]
fn test_scalar_value_try_new_null() {
    // Non-nested types.
    let scalars = vec![
        ScalarValue::try_new_null(&DataType::Boolean).unwrap(),
        ScalarValue::try_new_null(&DataType::Int8).unwrap(),
        ScalarValue::try_new_null(&DataType::Int16).unwrap(),
        ScalarValue::try_new_null(&DataType::Int32).unwrap(),
        ScalarValue::try_new_null(&DataType::Int64).unwrap(),
        ScalarValue::try_new_null(&DataType::UInt8).unwrap(),
        ScalarValue::try_new_null(&DataType::UInt16).unwrap(),
        ScalarValue::try_new_null(&DataType::UInt32).unwrap(),
        ScalarValue::try_new_null(&DataType::UInt64).unwrap(),
        ScalarValue::try_new_null(&DataType::Float16).unwrap(),
        ScalarValue::try_new_null(&DataType::Float32).unwrap(),
        ScalarValue::try_new_null(&DataType::Float64).unwrap(),
        ScalarValue::try_new_null(&DataType::Decimal128(42, 42)).unwrap(),
        ScalarValue::try_new_null(&DataType::Decimal256(42, 42)).unwrap(),
        ScalarValue::try_new_null(&DataType::Utf8).unwrap(),
        ScalarValue::try_new_null(&DataType::LargeUtf8).unwrap(),
        ScalarValue::try_new_null(&DataType::Utf8View).unwrap(),
        ScalarValue::try_new_null(&DataType::Binary).unwrap(),
        ScalarValue::try_new_null(&DataType::BinaryView).unwrap(),
        ScalarValue::try_new_null(&DataType::FixedSizeBinary(42)).unwrap(),
        ScalarValue::try_new_null(&DataType::LargeBinary).unwrap(),
        ScalarValue::try_new_null(&DataType::Date32).unwrap(),
        ScalarValue::try_new_null(&DataType::Date64).unwrap(),
        ScalarValue::try_new_null(&DataType::Time32(TimeUnit::Second)).unwrap(),
        ScalarValue::try_new_null(&DataType::Time32(TimeUnit::Millisecond)).unwrap(),
        ScalarValue::try_new_null(&DataType::Time64(TimeUnit::Microsecond)).unwrap(),
        ScalarValue::try_new_null(&DataType::Time64(TimeUnit::Nanosecond)).unwrap(),
        ScalarValue::try_new_null(&DataType::Timestamp(TimeUnit::Second, None))
            .unwrap(),
        ScalarValue::try_new_null(&DataType::Timestamp(TimeUnit::Millisecond, None))
            .unwrap(),
        ScalarValue::try_new_null(&DataType::Timestamp(TimeUnit::Microsecond, None))
            .unwrap(),
        ScalarValue::try_new_null(&DataType::Timestamp(TimeUnit::Nanosecond, None))
            .unwrap(),
        ScalarValue::try_new_null(&DataType::Interval(IntervalUnit::YearMonth))
            .unwrap(),
        ScalarValue::try_new_null(&DataType::Interval(IntervalUnit::DayTime))
            .unwrap(),
        ScalarValue::try_new_null(&DataType::Interval(IntervalUnit::MonthDayNano))
            .unwrap(),
        ScalarValue::try_new_null(&DataType::Duration(TimeUnit::Second)).unwrap(),
        // Millisecond was previously the only Duration unit left untested.
        ScalarValue::try_new_null(&DataType::Duration(TimeUnit::Millisecond))
            .unwrap(),
        ScalarValue::try_new_null(&DataType::Duration(TimeUnit::Microsecond))
            .unwrap(),
        ScalarValue::try_new_null(&DataType::Duration(TimeUnit::Nanosecond)).unwrap(),
        ScalarValue::try_new_null(&DataType::Null).unwrap(),
    ];
    assert!(scalars.iter().all(|s| s.is_null()));

    // Nested types built on top of a shared nullable Int32 field.
    let field_ref = Arc::new(Field::new("foo", DataType::Int32, true));
    let map_field_ref = Arc::new(Field::new(
        "foo",
        DataType::Struct(Fields::from(vec![
            Field::new("bar", DataType::Utf8, true),
            Field::new("baz", DataType::Int32, true),
        ])),
        true,
    ));
    let scalars = vec![
        ScalarValue::try_new_null(&DataType::List(Arc::clone(&field_ref))).unwrap(),
        ScalarValue::try_new_null(&DataType::LargeList(Arc::clone(&field_ref)))
            .unwrap(),
        ScalarValue::try_new_null(&DataType::FixedSizeList(
            Arc::clone(&field_ref),
            42,
        ))
        .unwrap(),
        ScalarValue::try_new_null(&DataType::Struct(
            vec![Arc::clone(&field_ref)].into(),
        ))
        .unwrap(),
        ScalarValue::try_new_null(&DataType::Map(map_field_ref, false)).unwrap(),
        ScalarValue::try_new_null(&DataType::Union(
            UnionFields::new(vec![42], vec![field_ref]),
            UnionMode::Dense,
        ))
        .unwrap(),
    ];
    assert!(scalars.iter().all(|s| s.is_null()));
}
}

0 comments on commit 4f71e1c

Please sign in to comment.