diff --git a/crates/integration_tests/tests/datafusion.rs b/crates/integration_tests/tests/datafusion.rs
index 9eeae5e6b..93fb1f679 100644
--- a/crates/integration_tests/tests/datafusion.rs
+++ b/crates/integration_tests/tests/datafusion.rs
@@ -24,6 +24,7 @@ use datafusion::assert_batches_eq;
 use datafusion::catalog::TableProvider;
 use datafusion::common::stats::Precision;
 use datafusion::common::{ColumnStatistics, ScalarValue, Statistics};
+use datafusion::logical_expr::{col, lit};
 use datafusion::prelude::SessionContext;
 use iceberg::{Catalog, Result, TableIdent};
 use iceberg_datafusion::IcebergTableProvider;
@@ -38,16 +39,11 @@ async fn test_basic_queries() -> Result<()> {
 
     let table = catalog
         .load_table(&TableIdent::from_strs(["default", "types_test"]).unwrap())
-        .await
-        .unwrap();
+        .await?;
 
     let ctx = SessionContext::new();
 
-    let table_provider = Arc::new(
-        IcebergTableProvider::try_new_from_table(table)
-            .await
-            .unwrap(),
-    );
+    let table_provider = Arc::new(IcebergTableProvider::try_new_from_table(table).await?);
 
     let schema = table_provider.schema();
 
@@ -146,22 +142,23 @@ async fn test_statistics() -> Result<()> {
     let catalog = fixture.rest_catalog;
 
+    // Test table statistics
     let table = catalog
-        .load_table(
-            &TableIdent::from_strs(["default", "test_positional_merge_on_read_double_deletes"])
-                .unwrap(),
-        )
-        .await
-        .unwrap();
+        .load_table(&TableIdent::from_strs([
+            "default",
+            "test_positional_merge_on_read_double_deletes",
+        ])?)
+        .await?;
 
-    let stats = IcebergTableProvider::try_new_from_table(table)
+    let table_provider = IcebergTableProvider::try_new_from_table(table)
         .await?
         .with_computed_statistics()
-        .await
-        .statistics();
+        .await;
+
+    let table_stats = table_provider.statistics();
 
     assert_eq!(
-        stats,
+        table_stats,
         Some(Statistics {
             num_rows: Precision::Inexact(12),
             total_byte_size: Precision::Absent,
@@ -188,5 +185,32 @@ async fn test_statistics() -> Result<()> {
         })
     );
 
+    // Test plan statistics with filtering
+    let ctx = SessionContext::new();
+    let scan = table_provider
+        .scan(
+            &ctx.state(),
+            Some(&vec![1]),
+            &[col("number").gt(lit(4))],
+            None,
+        )
+        .await
+        .unwrap();
+
+    let plan_stats = scan.statistics().unwrap();
+
+    // The estimated row count and the column's min value both change in
+    // response to the filter
+    assert_eq!(plan_stats, Statistics {
+        num_rows: Precision::Inexact(8),
+        total_byte_size: Precision::Absent,
+        column_statistics: vec![ColumnStatistics {
+            null_count: Precision::Inexact(0),
+            max_value: Precision::Inexact(ScalarValue::Int32(Some(12))),
+            min_value: Precision::Inexact(ScalarValue::Int32(Some(5))),
+            distinct_count: Precision::Absent,
+        },],
+    });
+
     Ok(())
 }
diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs
index e4bfd42e3..6a88c8b62 100644
--- a/crates/integrations/datafusion/src/physical_plan/scan.rs
+++ b/crates/integrations/datafusion/src/physical_plan/scan.rs
@@ -22,10 +22,11 @@ use std::vec;
 
 use datafusion::arrow::array::RecordBatch;
 use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
-use datafusion::common::Statistics;
+use datafusion::common::{Statistics, ToDFSchema};
 use datafusion::error::Result as DFResult;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
-use datafusion::physical_expr::EquivalenceProperties;
+use datafusion::logical_expr::utils::conjunction;
+use datafusion::physical_expr::{create_physical_expr, EquivalenceProperties};
 use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::{DisplayAs, ExecutionPlan, Partitioning, PlanProperties};
@@ -33,9 +34,10 @@ use datafusion::prelude::Expr;
 use futures::{Stream, TryStreamExt};
 use iceberg::expr::Predicate;
 use iceberg::table::Table;
+use log::warn;
 
 use super::expr_to_predicate::convert_filters_to_predicate;
-use crate::to_datafusion_error;
+use crate::{apply_bounds, to_datafusion_error};
 
 /// Manages the scanning process of an Iceberg [`Table`], encapsulating the
 /// necessary details and computed properties required for execution planning.
@@ -63,7 +65,7 @@ impl IcebergTableScan {
         table: Table,
         snapshot_id: Option<i64>,
         schema: ArrowSchemaRef,
-        statistics: Statistics,
+        statistics: Option<Statistics>,
         projection: Option<&Vec<usize>>,
         filters: &[Expr],
     ) -> Self {
@@ -71,6 +73,26 @@ impl IcebergTableScan {
             None => schema.clone(),
             Some(projection) => Arc::new(schema.project(projection).unwrap()),
         };
+
+        let statistics = statistics
+            .map(|stats| {
+                let stats = match projection {
+                    None => stats,
+                    Some(projection) => stats.project(Some(projection)),
+                };
+                Self::bound_statistics(stats.clone(), filters, output_schema.clone())
+            })
+            .transpose()
+            .inspect_err(|err| {
+                warn!(
+                    "Failed to bound input statistics, defaulting to none: {:?}",
+                    err
+                )
+            })
+            .ok()
+            .flatten()
+            .unwrap_or(Statistics::new_unknown(output_schema.as_ref()));
+
         let plan_properties = Self::compute_properties(output_schema.clone());
         let projection = get_column_names(schema.clone(), projection);
         let predicates = convert_filters_to_predicate(filters);
@@ -97,6 +119,23 @@ impl IcebergTableScan {
             Boundedness::Bounded,
         )
     }
+
+    /// Estimate the effective bounded statistics corresponding to the provided filter expressions
+    fn bound_statistics(
+        input_stats: Statistics,
+        filters: &[Expr],
+        schema: ArrowSchemaRef,
+    ) -> DFResult<Statistics> {
+        Ok(if let Some(filters) = conjunction(filters.to_vec()) {
+            let schema = schema.clone();
+            let df_schema = schema.clone().to_dfschema()?;
+            let predicate = create_physical_expr(&filters, &df_schema, &Default::default())?;
+
+            apply_bounds(input_stats, &predicate, schema)?
+        } else {
+            input_stats
+        })
+    }
 }
 
 impl ExecutionPlan for IcebergTableScan {
diff --git a/crates/integrations/datafusion/src/statistics.rs b/crates/integrations/datafusion/src/statistics.rs
index c54dcdfe9..0a7db6ff4 100644
--- a/crates/integrations/datafusion/src/statistics.rs
+++ b/crates/integrations/datafusion/src/statistics.rs
@@ -16,9 +16,13 @@
 // under the License.
 
 use std::collections::HashMap;
+use std::sync::Arc;
 
+use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::common::stats::Precision;
 use datafusion::common::{ColumnStatistics, Statistics};
+use datafusion::error::Result as DFResult;
+use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries, PhysicalExpr};
 use iceberg::spec::{DataContentType, ManifestStatus};
 use iceberg::table::Table;
 use iceberg::Result;
@@ -113,3 +117,58 @@ pub async fn compute_statistics(table: &Table, snapshot_id: Option<i64>) -> Result<Statistics>
         column_statistics: col_stats,
     })
 }
+
+// Apply bounds to the provided input statistics.
+//
+// Adapted from `FilterExec::statistics_helper` in DataFusion.
+pub fn apply_bounds(
+    input_stats: Statistics,
+    predicate: &Arc<dyn PhysicalExpr>,
+    schema: SchemaRef,
+) -> DFResult<Statistics> {
+    let num_rows = input_stats.num_rows;
+    let total_byte_size = input_stats.total_byte_size;
+    let input_analysis_ctx =
+        AnalysisContext::try_from_statistics(&schema, &input_stats.column_statistics)?;
+
+    let analysis_ctx = analyze(predicate, input_analysis_ctx, &schema)?;
+
+    // Estimate (inexact) selectivity of the predicate
+    let selectivity = analysis_ctx.selectivity.unwrap_or(1.0);
+    let num_rows = num_rows.with_estimated_selectivity(selectivity);
+    let total_byte_size = total_byte_size.with_estimated_selectivity(selectivity);
+
+    let column_statistics = analysis_ctx
+        .boundaries
+        .into_iter()
+        .enumerate()
+        .map(
+            |(
+                idx,
+                ExprBoundaries {
+                    interval,
+                    distinct_count,
+                    ..
+                },
+            )| {
+                let (lower, upper) = interval.into_bounds();
+                let (min_value, max_value) = if lower.eq(&upper) {
+                    (Precision::Exact(lower), Precision::Exact(upper))
+                } else {
+                    (Precision::Inexact(lower), Precision::Inexact(upper))
+                };
+                ColumnStatistics {
+                    null_count: input_stats.column_statistics[idx].null_count.to_inexact(),
+                    max_value,
+                    min_value,
+                    distinct_count: distinct_count.to_inexact(),
+                }
+            },
+        )
+        .collect();
+    Ok(Statistics {
+        num_rows,
+        total_byte_size,
+        column_statistics,
+    })
+}
diff --git a/crates/integrations/datafusion/src/table/mod.rs b/crates/integrations/datafusion/src/table/mod.rs
index 7e801a48a..8d7336ccf 100644
--- a/crates/integrations/datafusion/src/table/mod.rs
+++ b/crates/integrations/datafusion/src/table/mod.rs
@@ -152,9 +152,7 @@ impl TableProvider for IcebergTableProvider {
             self.table.clone(),
             self.snapshot_id,
             self.schema.clone(),
-            self.statistics
-                .clone()
-                .unwrap_or(Statistics::new_unknown(self.schema.as_ref())),
+            self.statistics.clone(),
             projection,
             filters,
         )))
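
For context, a minimal sketch of how the new plan-statistics path fits together end to end. It mirrors the integration test above; `inspect_plan_stats` is a hypothetical helper name, and the sketch assumes the REST catalog fixture's `default.test_positional_merge_on_read_double_deletes` table with its Int32 `number` column at projection index 1, as the test does:

use datafusion::catalog::TableProvider;
use datafusion::logical_expr::{col, lit};
use datafusion::prelude::SessionContext;
use iceberg::{Catalog, TableIdent};
use iceberg_datafusion::IcebergTableProvider;

// Hypothetical helper; `catalog` is any configured `impl Catalog`.
async fn inspect_plan_stats(catalog: &impl Catalog) -> iceberg::Result<()> {
    let table = catalog
        .load_table(&TableIdent::from_strs([
            "default",
            "test_positional_merge_on_read_double_deletes",
        ])?)
        .await?;

    // Table-level statistics are computed once and stored on the provider.
    let provider = IcebergTableProvider::try_new_from_table(table)
        .await?
        .with_computed_statistics()
        .await;

    // Planning a scan with a pushed-down filter now yields plan statistics
    // that `IcebergTableScan::new` has bounded by that filter.
    let ctx = SessionContext::new();
    let scan = provider
        .scan(
            &ctx.state(),
            Some(&vec![1]),              // project the `number` column
            &[col("number").gt(lit(4))], // pushed-down filter
            None,
        )
        .await
        .unwrap();

    println!("{:?}", scan.statistics().unwrap());
    Ok(())
}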
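
And a self-contained sketch of the arithmetic `apply_bounds` performs, assuming it is reachable from the crate root as the `use crate::{apply_bounds, ...}` import in scan.rs suggests; `bounded_stats_demo` is a hypothetical name, and the input numbers match the test fixture:

use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::stats::Precision;
use datafusion::common::{ColumnStatistics, ScalarValue, Statistics, ToDFSchema};
use datafusion::error::Result as DFResult;
use datafusion::logical_expr::{col, lit};
use datafusion::physical_expr::create_physical_expr;
use iceberg_datafusion::apply_bounds;

fn bounded_stats_demo() -> DFResult<Statistics> {
    let schema = Arc::new(Schema::new(vec![Field::new("number", DataType::Int32, true)]));

    // 12 rows with `number` ranging over [1, 12], as in the test table.
    let input = Statistics {
        num_rows: Precision::Inexact(12),
        total_byte_size: Precision::Absent,
        column_statistics: vec![ColumnStatistics {
            null_count: Precision::Inexact(0),
            max_value: Precision::Inexact(ScalarValue::Int32(Some(12))),
            min_value: Precision::Inexact(ScalarValue::Int32(Some(1))),
            distinct_count: Precision::Absent,
        }],
    };

    let df_schema = schema.clone().to_dfschema()?;
    let predicate =
        create_physical_expr(&col("number").gt(lit(4)), &df_schema, &Default::default())?;

    // `number > 4` narrows the interval [1, 12] to [5, 12] and scales the row
    // count by the estimated selectivity 8/12, giving num_rows ≈ 8.
    apply_bounds(input, &predicate, schema)
}

The Exact/Inexact split in the result mirrors `FilterExec`: a collapsed interval (lower equals upper) is the only case where the bound is reported as exact.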