Skip to content

Commit 4b62c80

Browse files
authored
Fix for column name based projection mask creation (#8447)
# Which issue does this PR close? - Closes #8443. # Rationale for this change See issue. # What changes are included in this PR? Replaces `starts_with` on concatenated paths with an element-by-element search of path components. # Are these changes tested? Yes, added a new test based on the issue # Are there any user-facing changes? No, just a bug fix
1 parent 60ce764 commit 4b62c80

File tree

1 file changed

+25
-27
lines changed

1 file changed

+25
-27
lines changed

parquet/src/arrow/mod.rs

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -193,14 +193,12 @@ pub mod async_writer;
193193
mod record_reader;
194194
experimental!(mod schema);
195195

196-
use std::sync::Arc;
197-
198196
pub use self::arrow_writer::ArrowWriter;
199197
#[cfg(feature = "async")]
200198
pub use self::async_reader::ParquetRecordBatchStreamBuilder;
201199
#[cfg(feature = "async")]
202200
pub use self::async_writer::AsyncArrowWriter;
203-
use crate::schema::types::{SchemaDescriptor, Type};
201+
use crate::schema::types::SchemaDescriptor;
204202
use arrow_schema::{FieldRef, Schema};
205203

206204
pub use self::schema::{
@@ -318,21 +316,6 @@ impl ProjectionMask {
318316
Self { mask: Some(mask) }
319317
}
320318

321-
// Given a starting point in the schema, do a DFS for that node adding leaf paths to `paths`.
322-
fn find_leaves(root: &Arc<Type>, parent: Option<&String>, paths: &mut Vec<String>) {
323-
let path = parent
324-
.map(|p| [p, root.name()].join("."))
325-
.unwrap_or(root.name().to_string());
326-
if root.is_group() {
327-
for child in root.get_fields() {
328-
Self::find_leaves(child, Some(&path), paths);
329-
}
330-
} else {
331-
// Reached a leaf, add to paths
332-
paths.push(path);
333-
}
334-
}
335-
336319
/// Create a [`ProjectionMask`] which selects only the named columns
337320
///
338321
/// All leaf columns that fall below a given name will be selected. For example, given
@@ -360,21 +343,24 @@ impl ProjectionMask {
360343
/// Note: repeated or out of order indices will not impact the final mask.
361344
///
362345
/// i.e. `["b", "c"]` will construct the same mask as `["c", "b", "c"]`.
346+
///
347+
/// Also, this will not produce the desired results if a column contains a '.' in its name.
348+
/// Use [`Self::leaves`] or [`Self::roots`] in that case.
363349
pub fn columns<'a>(
364350
schema: &SchemaDescriptor,
365351
names: impl IntoIterator<Item = &'a str>,
366352
) -> Self {
367-
// first make vector of paths for leaf columns
368-
let mut paths: Vec<String> = vec![];
369-
for root in schema.root_schema().get_fields() {
370-
Self::find_leaves(root, None, &mut paths);
371-
}
372-
assert_eq!(paths.len(), schema.num_columns());
373-
374353
let mut mask = vec![false; schema.num_columns()];
375354
for name in names {
376-
for idx in 0..schema.num_columns() {
377-
if paths[idx].starts_with(name) {
355+
let name_path: Vec<&str> = name.split('.').collect();
356+
for (idx, col) in schema.columns().iter().enumerate() {
357+
let path = col.path().parts();
358+
// searching for "a.b.c" cannot match "a.b"
359+
if name_path.len() > path.len() {
360+
continue;
361+
}
362+
// now path >= name_path, so check that each element in name_path matches
363+
if name_path.iter().zip(path.iter()).all(|(a, b)| a == b) {
378364
mask[idx] = true;
379365
}
380366
}
@@ -698,6 +684,18 @@ mod test {
698684

699685
let mask = ProjectionMask::columns(&schema, ["a", "e"]);
700686
assert_eq!(mask.mask.unwrap(), [true, false, true, false, true]);
687+
688+
let message_type = "
689+
message test_schema {
690+
OPTIONAL INT32 a;
691+
OPTIONAL INT32 aa;
692+
}
693+
";
694+
let parquet_group_type = parse_message_type(message_type).unwrap();
695+
let schema = SchemaDescriptor::new(Arc::new(parquet_group_type));
696+
697+
let mask = ProjectionMask::columns(&schema, ["a"]);
698+
assert_eq!(mask.mask.unwrap(), [true, false]);
701699
}
702700

703701
#[test]

0 commit comments

Comments
 (0)