@@ -193,14 +193,12 @@ pub mod async_writer;
193193mod record_reader;
194194experimental ! ( mod schema) ;
195195
196- use std:: sync:: Arc ;
197-
198196pub use self :: arrow_writer:: ArrowWriter ;
199197#[ cfg( feature = "async" ) ]
200198pub use self :: async_reader:: ParquetRecordBatchStreamBuilder ;
201199#[ cfg( feature = "async" ) ]
202200pub use self :: async_writer:: AsyncArrowWriter ;
203- use crate :: schema:: types:: { SchemaDescriptor , Type } ;
201+ use crate :: schema:: types:: SchemaDescriptor ;
204202use arrow_schema:: { FieldRef , Schema } ;
205203
206204pub use self :: schema:: {
@@ -318,21 +316,6 @@ impl ProjectionMask {
318316 Self { mask : Some ( mask) }
319317 }
320318
321- // Depth-first walk starting at `root`: builds the dot-joined path of each node (`parent` path, ".", node name) and appends the path of every non-group node to `paths`.
322- fn find_leaves ( root : & Arc < Type > , parent : Option < & String > , paths : & mut Vec < String > ) {
323- let path = parent
324- . map ( |p| [ p, root. name ( ) ] . join ( "." ) )
325- . unwrap_or ( root. name ( ) . to_string ( ) ) ;
326- if root. is_group ( ) {
327- for child in root. get_fields ( ) {
328- Self :: find_leaves ( child, Some ( & path) , paths) ;
329- }
330- } else {
331- // Not a group, so this node is a leaf: record its full dotted path in `paths`
332- paths. push ( path) ;
333- }
334- }
335-
336319 /// Create a [`ProjectionMask`] which selects only the named columns
337320 ///
338321 /// All leaf columns that fall below a given name will be selected. For example, given
@@ -360,21 +343,24 @@ impl ProjectionMask {
360343 /// Note: repeated or out of order indices will not impact the final mask.
361344 ///
362345 /// i.e. `["b", "c"]` will construct the same mask as `["c", "b", "c"]`.
346+ ///
347+ /// Note: each name is split on '.' to form a path, so a column whose own name contains a '.'
348+ /// cannot be selected correctly by this method. Use [`Self::leaves`] or [`Self::roots`] in that case.
363349 pub fn columns < ' a > (
364350 schema : & SchemaDescriptor ,
365351 names : impl IntoIterator < Item = & ' a str > ,
366352 ) -> Self {
367- // first make vector of paths for leaf columns
368- let mut paths: Vec < String > = vec ! [ ] ;
369- for root in schema. root_schema ( ) . get_fields ( ) {
370- Self :: find_leaves ( root, None , & mut paths) ;
371- }
372- assert_eq ! ( paths. len( ) , schema. num_columns( ) ) ;
373-
374353 let mut mask = vec ! [ false ; schema. num_columns( ) ] ;
375354 for name in names {
376- for idx in 0 ..schema. num_columns ( ) {
377- if paths[ idx] . starts_with ( name) {
355+ let name_path: Vec < & str > = name. split ( '.' ) . collect ( ) ;
356+ for ( idx, col) in schema. columns ( ) . iter ( ) . enumerate ( ) {
357+ let path = col. path ( ) . parts ( ) ;
358+ // searching for "a.b.c" cannot match "a.b"
359+ if name_path. len ( ) > path. len ( ) {
360+ continue ;
361+ }
362+ // now path >= name_path, so check that each element in name_path matches
363+ if name_path. iter ( ) . zip ( path. iter ( ) ) . all ( |( a, b) | a == b) {
378364 mask[ idx] = true ;
379365 }
380366 }
@@ -698,6 +684,18 @@ mod test {
698684
699685 let mask = ProjectionMask :: columns ( & schema, [ "a" , "e" ] ) ;
700686 assert_eq ! ( mask. mask. unwrap( ) , [ true , false , true , false , true ] ) ;
687+
688+ let message_type = "
689+ message test_schema {
690+ OPTIONAL INT32 a;
691+ OPTIONAL INT32 aa;
692+ }
693+ " ;
694+ let parquet_group_type = parse_message_type ( message_type) . unwrap ( ) ;
695+ let schema = SchemaDescriptor :: new ( Arc :: new ( parquet_group_type) ) ;
696+
697+ let mask = ProjectionMask :: columns ( & schema, [ "a" ] ) ;
698+ assert_eq ! ( mask. mask. unwrap( ) , [ true , false ] ) ;
701699 }
702700
703701 #[ test]
0 commit comments