|
1 | 1 | use std::str::Utf8Error;
|
| 2 | +use std::sync::Arc; |
2 | 3 |
|
3 | 4 | use datafusion::arrow::array::{
|
4 |
| - Array, ArrayRef, AsArray, Int64Array, LargeStringArray, StringArray, StringViewArray, UInt64Array, |
| 5 | + Array, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray, StringArray, |
| 6 | + StringViewArray, UInt64Array, UnionArray, |
5 | 7 | };
|
6 |
| -use datafusion::arrow::datatypes::DataType; |
| 8 | +use datafusion::arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}; |
| 9 | +use datafusion::arrow::downcast_dictionary_array; |
7 | 10 | use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
|
8 | 11 | use datafusion::logical_expr::ColumnarValue;
|
9 | 12 | use jiter::{Jiter, JiterError, Peek};
|
10 | 13 |
|
11 |
| -use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array}; |
| 14 | +use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array, TYPE_ID_NULL}; |
12 | 15 |
|
13 | 16 | /// General implementation of `ScalarUDFImpl::return_type`.
|
14 | 17 | ///
|
@@ -164,21 +167,32 @@ fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Optio
|
164 | 167 | jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
|
165 | 168 | object_lookup: bool,
|
166 | 169 | ) -> DataFusionResult<ArrayRef> {
|
167 |
| - if let Some(d) = json_array.as_any_dictionary_opt() { |
168 |
| - let a = zip_apply(d.values(), path_array, to_array, jiter_find, object_lookup)?; |
169 |
| - return Ok(d.with_values(a).into()); |
170 |
| - } |
171 |
| - let c = if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() { |
172 |
| - zip_apply_iter(string_array.iter(), path_array, jiter_find) |
173 |
| - } else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() { |
174 |
| - zip_apply_iter(large_string_array.iter(), path_array, jiter_find) |
175 |
| - } else if let Some(string_view) = json_array.as_any().downcast_ref::<StringViewArray>() { |
176 |
| - zip_apply_iter(string_view.iter(), path_array, jiter_find) |
177 |
| - } else if let Some(string_array) = nested_json_array(json_array, object_lookup) { |
178 |
| - zip_apply_iter(string_array.iter(), path_array, jiter_find) |
179 |
| - } else { |
180 |
| - return exec_err!("unexpected json array type {:?}", json_array.data_type()); |
181 |
| - }; |
| 170 | + // arrow_schema "use" is workaround for https://github.com/apache/arrow-rs/issues/6400#issue-2528388332 |
| 171 | + use datafusion::arrow::datatypes as arrow_schema; |
| 172 | + |
| 173 | + let c = downcast_dictionary_array!( |
| 174 | + json_array => { |
| 175 | + let values = zip_apply(json_array.values(), path_array, to_array, jiter_find, object_lookup)?; |
| 176 | + if !is_json_union(values.data_type()) { |
| 177 | + return Ok(Arc::new(json_array.with_values(values))); |
| 178 | + } |
| 179 | + // JSON union: post-process the array to set keys to null where the union member is null |
| 180 | + let type_ids = values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids(); |
| 181 | + return Ok(Arc::new(DictionaryArray::new( |
| 182 | + mask_dictionary_keys(json_array.keys(), type_ids), |
| 183 | + values, |
| 184 | + ))); |
| 185 | + } |
| 186 | + DataType::Utf8 => zip_apply_iter(json_array.as_string::<i32>().iter(), path_array, jiter_find), |
| 187 | + DataType::LargeUtf8 => zip_apply_iter(json_array.as_string::<i64>().iter(), path_array, jiter_find), |
| 188 | + DataType::Utf8View => zip_apply_iter(json_array.as_string_view().iter(), path_array, jiter_find), |
| 189 | + other => if let Some(string_array) = nested_json_array(json_array, object_lookup) { |
| 190 | + zip_apply_iter(string_array.iter(), path_array, jiter_find) |
| 191 | + } else { |
| 192 | + return exec_err!("unexpected json array type {:?}", other); |
| 193 | + } |
| 194 | + ); |
| 195 | + |
182 | 196 | to_array(c)
|
183 | 197 | }
|
184 | 198 |
|
@@ -229,22 +243,31 @@ fn scalar_apply<C: FromIterator<Option<I>>, I>(
|
229 | 243 | to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
|
230 | 244 | jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
|
231 | 245 | ) -> DataFusionResult<ArrayRef> {
|
232 |
| - if let Some(d) = json_array.as_any_dictionary_opt() { |
233 |
| - let a = scalar_apply(d.values(), path, to_array, jiter_find)?; |
234 |
| - return Ok(d.with_values(a).into()); |
235 |
| - } |
| 246 | + // arrow_schema "use" is workaround for https://github.com/apache/arrow-rs/issues/6400#issue-2528388332 |
| 247 | + use datafusion::arrow::datatypes as arrow_schema; |
236 | 248 |
|
237 |
| - let c = if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() { |
238 |
| - scalar_apply_iter(string_array.iter(), path, jiter_find) |
239 |
| - } else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() { |
240 |
| - scalar_apply_iter(large_string_array.iter(), path, jiter_find) |
241 |
| - } else if let Some(string_view_array) = json_array.as_any().downcast_ref::<StringViewArray>() { |
242 |
| - scalar_apply_iter(string_view_array.iter(), path, jiter_find) |
243 |
| - } else if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) { |
244 |
| - scalar_apply_iter(string_array.iter(), path, jiter_find) |
245 |
| - } else { |
246 |
| - return exec_err!("unexpected json array type {:?}", json_array.data_type()); |
247 |
| - }; |
| 249 | + let c = downcast_dictionary_array!( |
| 250 | + json_array => { |
| 251 | + let values = scalar_apply(json_array.values(), path, to_array, jiter_find)?; |
| 252 | + if !is_json_union(values.data_type()) { |
| 253 | + return Ok(Arc::new(json_array.with_values(values))); |
| 254 | + } |
| 255 | + // JSON union: post-process the array to set keys to null where the union member is null |
| 256 | + let type_ids = values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids(); |
| 257 | + return Ok(Arc::new(DictionaryArray::new( |
| 258 | + mask_dictionary_keys(json_array.keys(), type_ids), |
| 259 | + values, |
| 260 | + ))); |
| 261 | + } |
| 262 | + DataType::Utf8 => scalar_apply_iter(json_array.as_string::<i32>().iter(), path, jiter_find), |
| 263 | + DataType::LargeUtf8 => scalar_apply_iter(json_array.as_string::<i64>().iter(), path, jiter_find), |
| 264 | + DataType::Utf8View => scalar_apply_iter(json_array.as_string_view().iter(), path, jiter_find), |
| 265 | + other => if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) { |
| 266 | + scalar_apply_iter(string_array.iter(), path, jiter_find) |
| 267 | + } else { |
| 268 | + return exec_err!("unexpected json array type {:?}", other); |
| 269 | + } |
| 270 | + ); |
248 | 271 |
|
249 | 272 | to_array(c)
|
250 | 273 | }
|
@@ -319,3 +342,23 @@ impl From<Utf8Error> for GetError {
|
319 | 342 | GetError
|
320 | 343 | }
|
321 | 344 | }
|
| 345 | + |
| 346 | +/// Set keys to null where the union member is null. |
| 347 | +/// |
| 348 | +/// This is a workaround to <https://github.com/apache/arrow-rs/issues/6017#issuecomment-2352756753> |
| 349 | +/// - i.e. that dictionary null is most reliably done if the keys are null. |
| 350 | +/// |
| 351 | +/// That said, doing this might also be an optimization for cases like null-checking without needing |
| 352 | +/// to check the value union array. |
| 353 | +fn mask_dictionary_keys<K: ArrowPrimitiveType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> { |
| 354 | + let mut null_mask = vec![true; keys.len()]; |
| 355 | + for (i, k) in keys.iter().enumerate() { |
| 356 | + match k { |
| 357 | + // if the key is non-null and value is non-null, don't mask it out |
| 358 | + Some(k) if type_ids[k.as_usize()] != TYPE_ID_NULL => {} |
| 359 | + // i.e. key is null or value is null here |
| 360 | + _ => null_mask[i] = false, |
| 361 | + } |
| 362 | + } |
| 363 | + PrimitiveArray::new(keys.values().clone(), Some(null_mask.into())) |
| 364 | +} |
0 commit comments