Skip to content

Commit f3d5366

Browse files
authored
set keys to null where applicable in dictionary-encoded results (#46)
1 parent 4a33b9f commit f3d5366

File tree

5 files changed

+148
-47
lines changed

5 files changed

+148
-47
lines changed

src/common.rs

Lines changed: 76 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
use std::str::Utf8Error;
2+
use std::sync::Arc;
23

34
use datafusion::arrow::array::{
4-
Array, ArrayRef, AsArray, Int64Array, LargeStringArray, StringArray, StringViewArray, UInt64Array,
5+
Array, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray, StringArray,
6+
StringViewArray, UInt64Array, UnionArray,
57
};
6-
use datafusion::arrow::datatypes::DataType;
8+
use datafusion::arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType};
9+
use datafusion::arrow::downcast_dictionary_array;
710
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
811
use datafusion::logical_expr::ColumnarValue;
912
use jiter::{Jiter, JiterError, Peek};
1013

11-
use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array};
14+
use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array, TYPE_ID_NULL};
1215

1316
/// General implementation of `ScalarUDFImpl::return_type`.
1417
///
@@ -164,21 +167,32 @@ fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Optio
164167
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
165168
object_lookup: bool,
166169
) -> DataFusionResult<ArrayRef> {
167-
if let Some(d) = json_array.as_any_dictionary_opt() {
168-
let a = zip_apply(d.values(), path_array, to_array, jiter_find, object_lookup)?;
169-
return Ok(d.with_values(a).into());
170-
}
171-
let c = if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
172-
zip_apply_iter(string_array.iter(), path_array, jiter_find)
173-
} else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() {
174-
zip_apply_iter(large_string_array.iter(), path_array, jiter_find)
175-
} else if let Some(string_view) = json_array.as_any().downcast_ref::<StringViewArray>() {
176-
zip_apply_iter(string_view.iter(), path_array, jiter_find)
177-
} else if let Some(string_array) = nested_json_array(json_array, object_lookup) {
178-
zip_apply_iter(string_array.iter(), path_array, jiter_find)
179-
} else {
180-
return exec_err!("unexpected json array type {:?}", json_array.data_type());
181-
};
170+
// The `use ... as arrow_schema` below is a workaround for https://github.com/apache/arrow-rs/issues/6400#issue-2528388332
171+
use datafusion::arrow::datatypes as arrow_schema;
172+
173+
let c = downcast_dictionary_array!(
174+
json_array => {
175+
let values = zip_apply(json_array.values(), path_array, to_array, jiter_find, object_lookup)?;
176+
if !is_json_union(values.data_type()) {
177+
return Ok(Arc::new(json_array.with_values(values)));
178+
}
179+
// JSON union: post-process the array to set keys to null where the union member is null
180+
let type_ids = values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
181+
return Ok(Arc::new(DictionaryArray::new(
182+
mask_dictionary_keys(json_array.keys(), type_ids),
183+
values,
184+
)));
185+
}
186+
DataType::Utf8 => zip_apply_iter(json_array.as_string::<i32>().iter(), path_array, jiter_find),
187+
DataType::LargeUtf8 => zip_apply_iter(json_array.as_string::<i64>().iter(), path_array, jiter_find),
188+
DataType::Utf8View => zip_apply_iter(json_array.as_string_view().iter(), path_array, jiter_find),
189+
other => if let Some(string_array) = nested_json_array(json_array, object_lookup) {
190+
zip_apply_iter(string_array.iter(), path_array, jiter_find)
191+
} else {
192+
return exec_err!("unexpected json array type {:?}", other);
193+
}
194+
);
195+
182196
to_array(c)
183197
}
184198

@@ -229,22 +243,31 @@ fn scalar_apply<C: FromIterator<Option<I>>, I>(
229243
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
230244
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
231245
) -> DataFusionResult<ArrayRef> {
232-
if let Some(d) = json_array.as_any_dictionary_opt() {
233-
let a = scalar_apply(d.values(), path, to_array, jiter_find)?;
234-
return Ok(d.with_values(a).into());
235-
}
246+
// The `use ... as arrow_schema` below is a workaround for https://github.com/apache/arrow-rs/issues/6400#issue-2528388332
247+
use datafusion::arrow::datatypes as arrow_schema;
236248

237-
let c = if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
238-
scalar_apply_iter(string_array.iter(), path, jiter_find)
239-
} else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() {
240-
scalar_apply_iter(large_string_array.iter(), path, jiter_find)
241-
} else if let Some(string_view_array) = json_array.as_any().downcast_ref::<StringViewArray>() {
242-
scalar_apply_iter(string_view_array.iter(), path, jiter_find)
243-
} else if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) {
244-
scalar_apply_iter(string_array.iter(), path, jiter_find)
245-
} else {
246-
return exec_err!("unexpected json array type {:?}", json_array.data_type());
247-
};
249+
let c = downcast_dictionary_array!(
250+
json_array => {
251+
let values = scalar_apply(json_array.values(), path, to_array, jiter_find)?;
252+
if !is_json_union(values.data_type()) {
253+
return Ok(Arc::new(json_array.with_values(values)));
254+
}
255+
// JSON union: post-process the array to set keys to null where the union member is null
256+
let type_ids = values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
257+
return Ok(Arc::new(DictionaryArray::new(
258+
mask_dictionary_keys(json_array.keys(), type_ids),
259+
values,
260+
)));
261+
}
262+
DataType::Utf8 => scalar_apply_iter(json_array.as_string::<i32>().iter(), path, jiter_find),
263+
DataType::LargeUtf8 => scalar_apply_iter(json_array.as_string::<i64>().iter(), path, jiter_find),
264+
DataType::Utf8View => scalar_apply_iter(json_array.as_string_view().iter(), path, jiter_find),
265+
other => if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) {
266+
scalar_apply_iter(string_array.iter(), path, jiter_find)
267+
} else {
268+
return exec_err!("unexpected json array type {:?}", other);
269+
}
270+
);
248271

249272
to_array(c)
250273
}
@@ -319,3 +342,23 @@ impl From<Utf8Error> for GetError {
319342
GetError
320343
}
321344
}
345+
346+
/// Return a copy of `keys` in which an entry is null if the key is null or refers to a null union member.
347+
///
348+
/// This is a workaround for <https://github.com/apache/arrow-rs/issues/6017#issuecomment-2352756753>,
349+
/// i.e. a null dictionary entry is most reliably represented by a null key.
350+
///
351+
/// That said, doing this might also be an optimization for cases like null-checking, which then
352+
/// does not need to check the value union array.
///
/// * `keys` - the dictionary keys; a non-null key is used as an index into `type_ids`.
/// * `type_ids` - type ids of the union values, compared against `TYPE_ID_NULL`.
///
/// # Panics
///
/// Panics if a non-null key is out of bounds for `type_ids`.
353+
fn mask_dictionary_keys<K: ArrowPrimitiveType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
354+
// one flag per key, `true` meaning "keep as-is"; start with everything kept
let mut null_mask = vec![true; keys.len()];
355+
for (i, k) in keys.iter().enumerate() {
356+
match k {
357+
// key is non-null and refers to a non-null union member: leave the flag set
358+
Some(k) if type_ids[k.as_usize()] != TYPE_ID_NULL => {}
359+
// otherwise the key is null or its union member is null: mask the entry out
360+
_ => null_mask[i] = false,
361+
}
362+
}
363+
// same key values as the input; only the null mask is replaced
PrimitiveArray::new(keys.values().clone(), Some(null_mask.into()))
364+
}

src/common_union.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ pub(crate) enum JsonUnionField {
141141
Object(String),
142142
}
143143

144-
const TYPE_ID_NULL: i8 = 0;
144+
pub(crate) const TYPE_ID_NULL: i8 = 0;
145145
const TYPE_ID_BOOL: i8 = 1;
146146
const TYPE_ID_INT: i8 = 2;
147147
const TYPE_ID_FLOAT: i8 = 3;

src/rewrite.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option<Transformed<Expr>> {
7878
fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> {
7979
match expr {
8080
Expr::ScalarFunction(func) => Some(func),
81-
Expr::Alias(alias) => extract_scalar_function(&*alias.expr),
81+
Expr::Alias(alias) => extract_scalar_function(&alias.expr),
8282
_ => None,
8383
}
8484
}

tests/main.rs

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use datafusion::common::ScalarValue;
44
use datafusion::logical_expr::ColumnarValue;
55

66
use datafusion_functions_json::udfs::json_get_str_udf;
7-
use utils::{display_val, logical_plan, run_query, run_query_large, run_query_params};
7+
use utils::{display_val, logical_plan, run_query, run_query_dict, run_query_large, run_query_params};
88

99
mod utils;
1010

@@ -1072,6 +1072,28 @@ async fn test_arrow_union_is_null() {
10721072
assert_batches_eq!(expected, &batches);
10731073
}
10741074

1075+
// Runs `(json_data->'foo') is null` through `run_query_dict` and compares the rendered batches
// with `expected`: only `object_foo`, `object_foo_array` and `object_foo_obj` are non-null.
#[tokio::test]
1076+
async fn test_arrow_union_is_null_dict_encoded() {
1077+
let batches = run_query_dict("select name, (json_data->'foo') is null from test")
1078+
.await
1079+
.unwrap();
1080+
1081+
let expected = [
1082+
"+------------------+---------------------------------------+",
1083+
"| name | test.json_data -> Utf8(\"foo\") IS NULL |",
1084+
"+------------------+---------------------------------------+",
1085+
"| object_foo | false |",
1086+
"| object_foo_array | false |",
1087+
"| object_foo_obj | false |",
1088+
"| object_foo_null | true |",
1089+
"| object_bar | true |",
1090+
"| list_foo | true |",
1091+
"| invalid_json | true |",
1092+
"+------------------+---------------------------------------+",
1093+
];
1094+
assert_batches_eq!(expected, &batches);
1095+
}
1096+
10751097
#[tokio::test]
10761098
async fn test_arrow_union_is_not_null() {
10771099
let batches = run_query("select name, (json_data->'foo') is not null from test")
@@ -1094,6 +1116,28 @@ async fn test_arrow_union_is_not_null() {
10941116
assert_batches_eq!(expected, &batches);
10951117
}
10961118

1119+
// Runs `(json_data->'foo') is not null` through `run_query_dict` and compares the rendered batches
// with `expected`: only `object_foo`, `object_foo_array` and `object_foo_obj` are non-null.
#[tokio::test]
1120+
async fn test_arrow_union_is_not_null_dict_encoded() {
1121+
let batches = run_query_dict("select name, (json_data->'foo') is not null from test")
1122+
.await
1123+
.unwrap();
1124+
1125+
let expected = [
1126+
"+------------------+-------------------------------------------+",
1127+
"| name | test.json_data -> Utf8(\"foo\") IS NOT NULL |",
1128+
"+------------------+-------------------------------------------+",
1129+
"| object_foo | true |",
1130+
"| object_foo_array | true |",
1131+
"| object_foo_obj | true |",
1132+
"| object_foo_null | false |",
1133+
"| object_bar | false |",
1134+
"| list_foo | false |",
1135+
"| invalid_json | false |",
1136+
"+------------------+-------------------------------------------+",
1137+
];
1138+
assert_batches_eq!(expected, &batches);
1139+
}
1140+
10971141
#[tokio::test]
10981142
async fn test_arrow_scalar_union_is_null() {
10991143
let batches = run_query(
@@ -1147,8 +1191,8 @@ async fn test_dict_haystack() {
11471191
"| v |",
11481192
"+-----------------------+",
11491193
"| {object={\"bar\": [0]}} |",
1150-
"| {null=} |",
1151-
"| {null=} |",
1194+
"| |",
1195+
"| |",
11521196
"+-----------------------+",
11531197
];
11541198

@@ -1164,8 +1208,8 @@ async fn test_dict_haystack_needle() {
11641208
"| v |",
11651209
"+-------------+",
11661210
"| {array=[0]} |",
1167-
"| {null=} |",
1168-
"| {null=} |",
1211+
"| |",
1212+
"| |",
11691213
"+-------------+",
11701214
];
11711215

tests/utils/mod.rs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
use std::sync::Arc;
33

44
use datafusion::arrow::array::{
5-
ArrayRef, DictionaryArray, Int64Array, StringViewArray, UInt32Array, UInt64Array, UInt8Array,
5+
ArrayRef, DictionaryArray, Int32Array, Int64Array, StringViewArray, UInt32Array, UInt64Array, UInt8Array,
66
};
7-
use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, UInt32Type, UInt8Type};
7+
use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Int64Type, Schema, UInt32Type, UInt8Type};
88
use datafusion::arrow::util::display::{ArrayFormatter, FormatOptions};
99
use datafusion::arrow::{array::LargeStringArray, array::StringArray, record_batch::RecordBatch};
1010
use datafusion::common::ParamValues;
@@ -13,7 +13,7 @@ use datafusion::execution::context::SessionContext;
1313
use datafusion::prelude::SessionConfig;
1414
use datafusion_functions_json::register_all;
1515

16-
async fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
16+
async fn create_test_table(large_utf8: bool, dict_encoded: bool) -> Result<SessionContext> {
1717
let config = SessionConfig::new().set_str("datafusion.sql_parser.dialect", "postgres");
1818
let mut ctx = SessionContext::new_with_config(config);
1919
register_all(&mut ctx)?;
@@ -28,11 +28,20 @@ async fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
2828
("invalid_json", "is not json"),
2929
];
3030
let json_values = test_data.iter().map(|(_, json)| *json).collect::<Vec<_>>();
31-
let (json_data_type, json_array): (DataType, ArrayRef) = if large_utf8 {
31+
let (mut json_data_type, mut json_array): (DataType, ArrayRef) = if large_utf8 {
3232
(DataType::LargeUtf8, Arc::new(LargeStringArray::from(json_values)))
3333
} else {
3434
(DataType::Utf8, Arc::new(StringArray::from(json_values)))
3535
};
36+
37+
if dict_encoded {
38+
json_data_type = DataType::Dictionary(DataType::Int32.into(), json_data_type.into());
39+
json_array = Arc::new(DictionaryArray::<Int32Type>::new(
40+
Int32Array::from_iter_values(0..(json_array.len() as i32)),
41+
json_array,
42+
));
43+
}
44+
3645
let test_batch = RecordBatch::try_new(
3746
Arc::new(Schema::new(vec![
3847
Field::new("name", DataType::Utf8, false),
@@ -178,12 +187,17 @@ async fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
178187
}
179188

180189
pub async fn run_query(sql: &str) -> Result<Vec<RecordBatch>> {
181-
let ctx = create_test_table(false).await?;
190+
let ctx = create_test_table(false, false).await?;
182191
ctx.sql(sql).await?.collect().await
183192
}
184193

185194
pub async fn run_query_large(sql: &str) -> Result<Vec<RecordBatch>> {
186-
let ctx = create_test_table(true).await?;
195+
let ctx = create_test_table(true, false).await?;
196+
ctx.sql(sql).await?.collect().await
197+
}
198+
199+
/// Run `sql` against the test table created with `large_utf8 = false` and `dict_encoded = true`,
/// returning the collected record batches.
pub async fn run_query_dict(sql: &str) -> Result<Vec<RecordBatch>> {
200+
let ctx = create_test_table(false, true).await?;
187201
ctx.sql(sql).await?.collect().await
188202
}
189203

@@ -192,7 +206,7 @@ pub async fn run_query_params(
192206
large_utf8: bool,
193207
query_values: impl Into<ParamValues>,
194208
) -> Result<Vec<RecordBatch>> {
195-
let ctx = create_test_table(large_utf8).await?;
209+
let ctx = create_test_table(large_utf8, false).await?;
196210
ctx.sql(sql).await?.with_param_values(query_values)?.collect().await
197211
}
198212

0 commit comments

Comments
 (0)