Apply bounds to the provided input statistics when a filter is present in the plan
gruuya committed Jan 27, 2025
1 parent 36c8789 commit 3aebce8
Showing 4 changed files with 145 additions and 24 deletions.
58 changes: 41 additions & 17 deletions crates/integration_tests/tests/datafusion.rs
@@ -24,6 +24,7 @@ use datafusion::assert_batches_eq;
 use datafusion::catalog::TableProvider;
 use datafusion::common::stats::Precision;
 use datafusion::common::{ColumnStatistics, ScalarValue, Statistics};
+use datafusion::logical_expr::{col, lit};
 use datafusion::prelude::SessionContext;
 use iceberg::{Catalog, Result, TableIdent};
 use iceberg_datafusion::IcebergTableProvider;
@@ -38,16 +39,11 @@ async fn test_basic_queries() -> Result<()> {

     let table = catalog
         .load_table(&TableIdent::from_strs(["default", "types_test"]).unwrap())
-        .await
-        .unwrap();
+        .await?;

     let ctx = SessionContext::new();

-    let table_provider = Arc::new(
-        IcebergTableProvider::try_new_from_table(table)
-            .await
-            .unwrap(),
-    );
+    let table_provider = Arc::new(IcebergTableProvider::try_new_from_table(table).await?);

     let schema = table_provider.schema();

@@ -146,22 +142,23 @@ async fn test_statistics() -> Result<()> {

     let catalog = fixture.rest_catalog;

+    // Test table statistics
     let table = catalog
-        .load_table(
-            &TableIdent::from_strs(["default", "test_positional_merge_on_read_double_deletes"])
-                .unwrap(),
-        )
-        .await
-        .unwrap();
+        .load_table(&TableIdent::from_strs([
+            "default",
+            "test_positional_merge_on_read_double_deletes",
+        ])?)
+        .await?;

-    let stats = IcebergTableProvider::try_new_from_table(table)
+    let table_provider = IcebergTableProvider::try_new_from_table(table)
         .await?
         .with_computed_statistics()
-        .await
-        .statistics();
+        .await;
+
+    let table_stats = table_provider.statistics();

     assert_eq!(
-        stats,
+        table_stats,
         Some(Statistics {
             num_rows: Precision::Inexact(12),
             total_byte_size: Precision::Absent,
@@ -188,5 +185,32 @@
         })
     );

+    // Test plan statistics with filtering
+    let ctx = SessionContext::new();
+    let scan = table_provider
+        .scan(
+            &ctx.state(),
+            Some(&vec![1]),
+            &[col("number").gt(lit(4))],
+            None,
+        )
+        .await
+        .unwrap();
+
+    let plan_stats = scan.statistics().unwrap();
+
+    // The estimates for the number of rows and the column's min value change
+    // in response to the filtering
+    assert_eq!(plan_stats, Statistics {
+        num_rows: Precision::Inexact(8),
+        total_byte_size: Precision::Absent,
+        column_statistics: vec![ColumnStatistics {
+            null_count: Precision::Inexact(0),
+            max_value: Precision::Inexact(ScalarValue::Int32(Some(12))),
+            min_value: Precision::Inexact(ScalarValue::Int32(Some(5))),
+            distinct_count: Precision::Absent,
+        },],
+    });
+
     Ok(())
 }
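The expected values in the new assertion follow from DataFusion's interval analysis: clipping the column bounds [1, 12] with `number > 4` leaves [5, 12], and the row estimate scales by the resulting selectivity. A minimal sketch of that arithmetic, assuming the interval-cardinality selectivity estimate (8 of the 12 possible values survive):

```rust
use datafusion::common::stats::Precision;

fn main() {
    // Row estimate taken from the table statistics above: 12 rows, inexact.
    let num_rows = Precision::Inexact(12_usize);

    // Assumed selectivity for `number > 4` over integer bounds [1, 12]:
    // the clipped interval [5, 12] covers 8 of the 12 possible values.
    let selectivity = 8.0 / 12.0;

    // Scaling always demotes the result to Inexact, since the selectivity is
    // itself an estimate; this matches `Precision::Inexact(8)` in the test.
    assert_eq!(
        num_rows.with_estimated_selectivity(selectivity),
        Precision::Inexact(8)
    );
}
```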
48 changes: 44 additions & 4 deletions crates/integrations/datafusion/src/physical_plan/scan.rs
@@ -22,20 +22,22 @@ use std::vec;

 use datafusion::arrow::array::RecordBatch;
 use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
-use datafusion::common::Statistics;
+use datafusion::common::{Statistics, ToDFSchema};
 use datafusion::error::Result as DFResult;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
-use datafusion::physical_expr::EquivalenceProperties;
+use datafusion::logical_expr::utils::conjunction;
+use datafusion::physical_expr::{create_physical_expr, EquivalenceProperties};
 use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::{DisplayAs, ExecutionPlan, Partitioning, PlanProperties};
 use datafusion::prelude::Expr;
 use futures::{Stream, TryStreamExt};
 use iceberg::expr::Predicate;
 use iceberg::table::Table;
+use log::warn;

 use super::expr_to_predicate::convert_filters_to_predicate;
-use crate::to_datafusion_error;
+use crate::{apply_bounds, to_datafusion_error};

 /// Manages the scanning process of an Iceberg [`Table`], encapsulating the
 /// necessary details and computed properties required for execution planning.
@@ -63,14 +65,35 @@ impl IcebergTableScan {
         table: Table,
         snapshot_id: Option<i64>,
         schema: ArrowSchemaRef,
-        statistics: Statistics,
+        statistics: Option<Statistics>,
         projection: Option<&Vec<usize>>,
         filters: &[Expr],
     ) -> Self {
         let output_schema = match projection {
             None => schema.clone(),
             Some(projection) => Arc::new(schema.project(projection).unwrap()),
         };
+
+        let statistics = statistics
+            .as_ref()
+            .map(|stats| {
+                let stats = match projection {
+                    None => stats.clone(),
+                    Some(projection) => stats.clone().project(Some(projection)),
+                };
+                Self::bound_statistics(stats.clone(), filters, output_schema.clone())
+            })
+            .transpose()
+            .inspect_err(|err| {
+                warn!(
+                    "Failed to bound input statistics, defaulting to none: {:?}",
+                    err
+                )
+            })
+            .ok()
+            .flatten()
+            .unwrap_or(Statistics::new_unknown(output_schema.as_ref()));
+
         let plan_properties = Self::compute_properties(output_schema.clone());
         let projection = get_column_names(schema.clone(), projection);
         let predicates = convert_filters_to_predicate(filters);
@@ -97,6 +120,23 @@ impl IcebergTableScan {
             Boundedness::Bounded,
         )
     }
+
+    /// Estimate the effective bounded statistics corresponding to the provided filter expressions
+    fn bound_statistics(
+        input_stats: Statistics,
+        filters: &[Expr],
+        schema: ArrowSchemaRef,
+    ) -> DFResult<Statistics> {
+        Ok(if let Some(filters) = conjunction(filters.to_vec()) {
+            let schema = schema.clone();
+            let df_schema = schema.clone().to_dfschema()?;
+            let predicate = create_physical_expr(&filters, &df_schema, &Default::default())?;
+
+            apply_bounds(input_stats, &predicate, schema)?
+        } else {
+            input_stats
+        })
+    }
 }

 impl ExecutionPlan for IcebergTableScan {
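For reference, a standalone sketch of the lowering `bound_statistics` performs: `conjunction` folds the pushed-down filters into a single `AND` expression (returning `None` for an empty list), and `create_physical_expr` plans it against the scan schema. The schema and filters here are illustrative, not taken from the commit:

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::ToDFSchema;
use datafusion::error::Result;
use datafusion::logical_expr::utils::conjunction;
use datafusion::physical_expr::create_physical_expr;
use datafusion::prelude::{col, lit};

fn main() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "number",
        DataType::Int32,
        true,
    )]));

    // Two pushed-down filters; `conjunction` folds them into
    // `number > 4 AND number < 10`.
    let filters = vec![col("number").gt(lit(4)), col("number").lt(lit(10))];

    if let Some(predicate) = conjunction(filters) {
        let df_schema = schema.clone().to_dfschema()?;
        // `&Default::default()` supplies empty `ExecutionProps`, mirroring
        // the call in `bound_statistics`.
        let physical = create_physical_expr(&predicate, &df_schema, &Default::default())?;
        println!("{physical}");
    }
    Ok(())
}
```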
59 changes: 59 additions & 0 deletions crates/integrations/datafusion/src/statistics.rs
@@ -16,9 +16,13 @@
 // under the License.

 use std::collections::HashMap;
+use std::sync::Arc;

+use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::common::stats::Precision;
 use datafusion::common::{ColumnStatistics, Statistics};
+use datafusion::error::Result as DFResult;
+use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries, PhysicalExpr};
 use iceberg::spec::{DataContentType, ManifestStatus};
 use iceberg::table::Table;
 use iceberg::Result;
@@ -113,3 +117,58 @@ pub async fn compute_statistics(table: &Table, snapshot_id: Option<i64>) -> Resu
         column_statistics: col_stats,
     })
 }
+
+// Apply bounds to the provided input statistics.
+//
+// Adapted from `FilterExec::statistics_helper` in DataFusion.
+pub fn apply_bounds(
+    input_stats: Statistics,
+    predicate: &Arc<dyn PhysicalExpr>,
+    schema: SchemaRef,
+) -> DFResult<Statistics> {
+    let num_rows = input_stats.num_rows;
+    let total_byte_size = input_stats.total_byte_size;
+    let input_analysis_ctx =
+        AnalysisContext::try_from_statistics(&schema, &input_stats.column_statistics)?;
+
+    let analysis_ctx = analyze(predicate, input_analysis_ctx, &schema)?;
+
+    // Estimate (inexact) selectivity of predicate
+    let selectivity = analysis_ctx.selectivity.unwrap_or(1.0);
+    let num_rows = num_rows.with_estimated_selectivity(selectivity);
+    let total_byte_size = total_byte_size.with_estimated_selectivity(selectivity);
+
+    let column_statistics = analysis_ctx
+        .boundaries
+        .into_iter()
+        .enumerate()
+        .map(
+            |(
+                idx,
+                ExprBoundaries {
+                    interval,
+                    distinct_count,
+                    ..
+                },
+            )| {
+                let (lower, upper) = interval.into_bounds();
+                let (min_value, max_value) = if lower.eq(&upper) {
+                    (Precision::Exact(lower), Precision::Exact(upper))
+                } else {
+                    (Precision::Inexact(lower), Precision::Inexact(upper))
+                };
+                ColumnStatistics {
+                    null_count: input_stats.column_statistics[idx].null_count.to_inexact(),
+                    max_value,
+                    min_value,
+                    distinct_count: distinct_count.to_inexact(),
+                }
+            },
+        )
+        .collect();
+    Ok(Statistics {
+        num_rows,
+        total_byte_size,
+        column_statistics,
+    })
+}
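A hypothetical end-to-end use of `apply_bounds`, mirroring the integration test above. The import path assumes the crate re-exports the function (scan.rs pulls it in as `crate::apply_bounds`), and the input statistics are stand-ins for what `compute_statistics` would produce:

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::stats::Precision;
use datafusion::common::{ColumnStatistics, ScalarValue, Statistics, ToDFSchema};
use datafusion::error::Result;
use datafusion::physical_expr::create_physical_expr;
use datafusion::prelude::{col, lit};
// Assumed public re-export; inside the crate this is `crate::apply_bounds`.
use iceberg_datafusion::apply_bounds;

fn main() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "number",
        DataType::Int32,
        true,
    )]));

    // Stand-in input statistics, shaped like the test table's `number` column.
    let input_stats = Statistics {
        num_rows: Precision::Inexact(12),
        total_byte_size: Precision::Absent,
        column_statistics: vec![ColumnStatistics {
            null_count: Precision::Inexact(0),
            min_value: Precision::Inexact(ScalarValue::Int32(Some(1))),
            max_value: Precision::Inexact(ScalarValue::Int32(Some(12))),
            distinct_count: Precision::Absent,
        }],
    };

    // `number > 4` lowered to a physical expression over the scan schema.
    let df_schema = schema.clone().to_dfschema()?;
    let predicate =
        create_physical_expr(&col("number").gt(lit(4)), &df_schema, &Default::default())?;

    // The row estimate shrinks with the selectivity and the column's lower
    // bound is lifted to 5, matching the integration test's assertions.
    let bounded = apply_bounds(input_stats, &predicate, schema)?;
    assert_eq!(bounded.num_rows, Precision::Inexact(8));
    Ok(())
}
```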
4 changes: 1 addition & 3 deletions crates/integrations/datafusion/src/table/mod.rs
@@ -152,9 +152,7 @@ impl TableProvider for IcebergTableProvider {
             self.table.clone(),
             self.snapshot_id,
             self.schema.clone(),
-            self.statistics
-                .clone()
-                .unwrap_or(Statistics::new_unknown(self.schema.as_ref())),
+            self.statistics.clone(),
             projection,
             filters,
         )))
