Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support lower_bound&&upper_bound for parquet writer #383

Merged
merged 6 commits into from
Jul 2, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
use Datum instead of Literal
ZENOTME committed Jun 21, 2024

Verified

This commit was signed with the committer’s verified signature.
scala-steward Scala Steward
commit 36f8569f145d85e3be65367120a89bd6ba7b2a3d
223 changes: 120 additions & 103 deletions crates/iceberg/src/writer/file_writer/parquet_writer.rs
Original file line number Diff line number Diff line change
@@ -19,8 +19,8 @@

use crate::arrow::DEFAULT_MAP_FIELD_NAME;
use crate::spec::{
visit_schema, ListType, Literal, MapType, NestedFieldRef, PrimitiveType, Schema, SchemaRef,
SchemaVisitor, StructType,
visit_schema, Datum, ListType, MapType, NestedFieldRef, PrimitiveLiteral, PrimitiveType,
Schema, SchemaRef, SchemaVisitor, StructType,
};
use crate::ErrorKind;
use crate::{io::FileIO, io::FileWrite, Result};
@@ -225,8 +225,8 @@ pub struct ParquetWriter {

/// Used to aggregate min and max value of each column.
struct MinMaxColAggregator {
lower_bounds: HashMap<i32, Literal>,
upper_bounds: HashMap<i32, Literal>,
lower_bounds: HashMap<i32, Datum>,
upper_bounds: HashMap<i32, Datum>,
schema: SchemaRef,
}

@@ -259,36 +259,6 @@ impl MinMaxColAggregator {
};

macro_rules! update_stat {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBH, I'm not a big fan of macros. We can extract the convert_func into a method which converts
Statistics to Datum, and then update this map.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. A plain function is much easier to maintain.

($self:ident, $stat:ident, $convert_func:expr) => {
if $stat.min_is_exact() {
let val = $convert_func($stat.min().clone());
match $self.lower_bounds.entry(col_id) {
std::collections::hash_map::Entry::Occupied(mut entry) => {
if entry.get() > &val {
entry.insert(val);
}
}
std::collections::hash_map::Entry::Vacant(entry) => {
entry.insert(val);
}
}
}
if $stat.max_is_exact() {
let val = $convert_func($stat.max().clone());
match $self.upper_bounds.entry(col_id) {
std::collections::hash_map::Entry::Occupied(mut entry) => {
if entry.get() < &val {
entry.insert(val);
}
}
std::collections::hash_map::Entry::Vacant(entry) => {
entry.insert(val);
}
}
}
};
}
macro_rules! update_stat_with_err {
($self:ident, $stat:ident, $convert_func:expr) => {
if $stat.min_is_exact() {
let val = $convert_func($stat.min().clone())?;
@@ -321,41 +291,53 @@ impl MinMaxColAggregator {

match (ty, value) {
(PrimitiveType::Boolean, Statistics::Boolean(stat)) => {
update_stat!(self, stat, Literal::bool);
let convert_func = |v: bool| Result::<Datum>::Ok(Datum::bool(v));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Int, Statistics::Int32(stat)) => {
update_stat!(self, stat, Literal::int);
let convert_func = |v: i32| Result::<Datum>::Ok(Datum::int(v));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Long, Statistics::Int64(stat)) => {
update_stat!(self, stat, Literal::long);
let convert_func = |v: i64| Result::<Datum>::Ok(Datum::long(v));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Float, Statistics::Float(stat)) => {
update_stat!(self, stat, Literal::float);
let convert_func = |v: f32| Result::<Datum>::Ok(Datum::float(v));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Double, Statistics::Double(stat)) => {
update_stat!(self, stat, Literal::double);
let convert_func = |v: f64| Result::<Datum>::Ok(Datum::double(v));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::String, Statistics::ByteArray(stat)) => {
let convert_func = |v: ByteArray| -> Result<Literal> {
Ok(Literal::string(v.as_utf8()?.to_string()))
let convert_func = |v: ByteArray| {
Result::<Datum>::Ok(Datum::string(
String::from_utf8(v.data().to_vec()).unwrap(),
))
};
update_stat_with_err!(self, stat, convert_func);
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Binary, Statistics::ByteArray(stat)) => {
let convert_func = |v: ByteArray| Literal::binary(v.data().to_vec());
let convert_func =
|v: ByteArray| Result::<Datum>::Ok(Datum::binary(v.data().to_vec()));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Date, Statistics::Int32(stat)) => {
update_stat!(self, stat, Literal::date);
let convert_func = |v: i32| Result::<Datum>::Ok(Datum::date(v));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Time, Statistics::Int64(stat)) => {
update_stat!(self, stat, Literal::time);
let convert_func = |v: i64| Datum::time_micros(v);
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Timestamp, Statistics::Int64(stat)) => {
update_stat!(self, stat, Literal::timestamp);
let convert_func = |v: i64| Result::<Datum>::Ok(Datum::timestamp_micros(v));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Timestamptz, Statistics::Int64(stat)) => {
update_stat!(self, stat, Literal::timestamptz);
let convert_func = |v: i64| Result::<Datum>::Ok(Datum::timestamptz_micros(v));
update_stat!(self, stat, convert_func);
}
(
PrimitiveType::Decimal {
@@ -364,10 +346,15 @@ impl MinMaxColAggregator {
},
Statistics::ByteArray(stat),
) => {
let convert_func = |v: ByteArray| -> Result<Literal> {
Ok(Literal::decimal(i128::from_be_bytes(v.data().try_into()?)))
let convert_func = |v: ByteArray| -> Result<Datum> {
Result::<Datum>::Ok(Datum::new(
ty.clone(),
PrimitiveLiteral::Decimal(i128::from_le_bytes(
v.data().try_into().unwrap(),
)),
))
};
update_stat_with_err!(self, stat, convert_func);
update_stat!(self, stat, convert_func);
}
(
PrimitiveType::Decimal {
@@ -376,7 +363,12 @@ impl MinMaxColAggregator {
},
Statistics::Int32(stat),
) => {
let convert_func = |v: i32| Literal::decimal(v as i128);
let convert_func = |v: i32| {
Result::<Datum>::Ok(Datum::new(
ty.clone(),
PrimitiveLiteral::Decimal(i128::from(v)),
))
};
update_stat!(self, stat, convert_func);
}
(
@@ -386,7 +378,12 @@ impl MinMaxColAggregator {
},
Statistics::Int64(stat),
) => {
let convert_func = |v: i64| Literal::decimal(v as i128);
let convert_func = |v: i64| {
Result::<Datum>::Ok(Datum::new(
ty.clone(),
PrimitiveLiteral::Decimal(i128::from(v)),
))
};
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Uuid, Statistics::FixedLenByteArray(stat)) => {
@@ -397,11 +394,11 @@ impl MinMaxColAggregator {
"Invalid length of uuid bytes.",
));
}
Ok(Literal::uuid(Uuid::from_bytes(
Ok(Datum::uuid(Uuid::from_bytes(
v.data()[..16].try_into().unwrap(),
)))
};
update_stat_with_err!(self, stat, convert_func);
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Fixed(len), Statistics::FixedLenByteArray(stat)) => {
let convert_func = |v: FixedLenByteArray| {
@@ -411,9 +408,9 @@ impl MinMaxColAggregator {
"Invalid length of fixed bytes.",
));
}
Ok(Literal::fixed(v.data().to_vec()))
Ok(Datum::fixed(v.data().to_vec()))
};
update_stat_with_err!(self, stat, convert_func);
update_stat!(self, stat, convert_func);
}
(ty, value) => {
return Err(Error::new(
@@ -425,7 +422,7 @@ impl MinMaxColAggregator {
Ok(())
}

fn produce(self) -> (HashMap<i32, Literal>, HashMap<i32, Literal>) {
fn produce(self) -> (HashMap<i32, Datum>, HashMap<i32, Datum>) {
(self.lower_bounds, self.upper_bounds)
}
}
@@ -936,11 +933,11 @@ mod tests {
assert_eq!(*data_file.value_counts(), HashMap::from([(0, 2048)]));
assert_eq!(
*data_file.lower_bounds(),
HashMap::from([(0, Literal::long(0))])
HashMap::from([(0, Datum::long(0))])
);
assert_eq!(
*data_file.upper_bounds(),
HashMap::from([(0, Literal::long(1023))])
HashMap::from([(0, Datum::long(1023))])
);
assert_eq!(*data_file.null_value_counts(), HashMap::from([(0, 1024)]));

@@ -1145,27 +1142,27 @@ mod tests {
assert_eq!(
*data_file.lower_bounds(),
HashMap::from([
(0, Literal::long(0)),
(5, Literal::long(0)),
(6, Literal::long(0)),
(2, Literal::string("0")),
(7, Literal::long(0)),
(9, Literal::long(0)),
(11, Literal::string("0")),
(13, Literal::long(0))
(0, Datum::long(0)),
(5, Datum::long(0)),
(6, Datum::long(0)),
(2, Datum::string("0")),
(7, Datum::long(0)),
(9, Datum::long(0)),
(11, Datum::string("0")),
(13, Datum::long(0))
])
);
assert_eq!(
*data_file.upper_bounds(),
HashMap::from([
(0, Literal::long(1023)),
(5, Literal::long(1023)),
(6, Literal::long(1023)),
(2, Literal::string("999")),
(7, Literal::long(1023)),
(9, Literal::long(1023)),
(11, Literal::string("999")),
(13, Literal::long(1023))
(0, Datum::long(1023)),
(5, Datum::long(1023)),
(6, Datum::long(1023)),
(2, Datum::string("999")),
(7, Datum::long(1023)),
(9, Datum::long(1023)),
(11, Datum::string("999")),
(13, Datum::long(1023))
])
);

@@ -1315,41 +1312,61 @@ mod tests {
assert_eq!(
*data_file.lower_bounds(),
HashMap::from([
(0, Literal::bool(false)),
(1, Literal::int(1)),
(2, Literal::long(1)),
(3, Literal::float(0.5)),
(4, Literal::double(0.5)),
(5, Literal::string("a")),
(6, Literal::binary(vec![])),
(7, Literal::date(0)),
(8, Literal::time(0)),
(9, Literal::timestamp(0)),
(10, Literal::timestamptz(0)),
(11, Literal::decimal(1)),
(12, Literal::uuid(Uuid::from_u128(0))),
(13, Literal::fixed(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10])),
(0, Datum::bool(false)),
(1, Datum::int(1)),
(2, Datum::long(1)),
(3, Datum::float(0.5)),
(4, Datum::double(0.5)),
(5, Datum::string("a")),
(6, Datum::binary(vec![])),
(7, Datum::date(0)),
(8, Datum::time_micros(0).unwrap()),
(9, Datum::timestamp_micros(0)),
(10, Datum::timestamptz_micros(0)),
(
11,
Datum::new(
PrimitiveType::Decimal {
precision: 10,
scale: 5
},
PrimitiveLiteral::Decimal(1)
)
),
(12, Datum::uuid(Uuid::from_u128(0))),
(13, Datum::fixed(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10])),
(12, Datum::uuid(Uuid::from_u128(0))),
(13, Datum::fixed(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10])),
])
);
assert_eq!(
*data_file.upper_bounds(),
HashMap::from([
(0, Literal::bool(true)),
(1, Literal::int(4)),
(2, Literal::long(4)),
(3, Literal::float(3.5)),
(4, Literal::double(3.5)),
(5, Literal::string("d")),
(6, Literal::binary(vec![122, 122, 122, 122])),
(7, Literal::date(3)),
(8, Literal::time(3)),
(9, Literal::timestamp(3)),
(10, Literal::timestamptz(3)),
(11, Literal::decimal(100)),
(12, Literal::uuid(Uuid::from_u128(3))),
(0, Datum::bool(true)),
(1, Datum::int(4)),
(2, Datum::long(4)),
(3, Datum::float(3.5)),
(4, Datum::double(3.5)),
(5, Datum::string("d")),
(6, Datum::binary(vec![122, 122, 122, 122])),
(7, Datum::date(3)),
(8, Datum::time_micros(3).unwrap()),
(9, Datum::timestamp_micros(3)),
(10, Datum::timestamptz_micros(3)),
(
11,
Datum::new(
PrimitiveType::Decimal {
precision: 10,
scale: 5
},
PrimitiveLiteral::Decimal(100)
)
),
(12, Datum::uuid(Uuid::from_u128(3))),
(
13,
Literal::fixed(vec![21, 22, 23, 24, 25, 26, 27, 28, 29, 30])
Datum::fixed(vec![21, 22, 23, 24, 25, 26, 27, 28, 29, 30])
),
])
);