From b7b88cd9a2cad9cda2b29f8f6343611185465902 Mon Sep 17 00:00:00 2001
From: Satyam Singh
Date: Mon, 7 Apr 2025 16:31:21 +0530
Subject: [PATCH 1/5] Rework sql implementation to add SQLTable trait and support for parameterized views

This PR primarily refactors the SQL implementation and additionally
extends the traits and their default implementations to support storing
function arguments for tables.

- Adds a SQLTable trait to be used within SQLTableSource. The SQLTable
  trait abstracts information about the remote table and allows
  implementors to hook into the final stages, where they can modify the
  logical plan and the final AST of the SQL query being federated via
  VirtualExecutionPlan.
- Adds RemoteTable, a default implementation of the SQLTable trait,
  capable of handling tables and parameterized views.
- Adds RemoteTableRef, an extension of the default TableReference
  capable of storing function args.
- Provides a default AST analyzer for rewriting a Statement for tables
  whose RemoteTableRef carries function args.
- Extends the SQLExecutor trait with a logical_optimizer method. This
  lets an executor hook into federation planning, rewriting the
  LogicalPlan and even the placement of the FederationPlanNode. This is
  useful for avoiding federation of nodes that exist only in DataFusion,
  e.g. UDFs, UDAFs, etc.
- Refactoring and tests related to usage of this feature
---
 .gitignore                                    |    3 +-
 datafusion-federation/src/sql/analyzer.rs     |  883 +++++++++++++
 datafusion-federation/src/sql/ast_analyzer.rs |  111 ++
 datafusion-federation/src/sql/executor.rs     |   15 +-
 datafusion-federation/src/sql/mod.rs          | 1138 +++++------------
 datafusion-federation/src/sql/schema.rs       |  152 ++-
 datafusion-federation/src/sql/table.rs        |  170 +++
 .../src/sql/table_reference.rs                |  280 ++++
 8 files changed, 1856 insertions(+), 896 deletions(-)
 create mode 100644 datafusion-federation/src/sql/analyzer.rs
 create mode 100644 datafusion-federation/src/sql/ast_analyzer.rs
 create mode 100644 datafusion-federation/src/sql/table.rs
 create mode 100644 datafusion-federation/src/sql/table_reference.rs

diff --git a/.gitignore b/.gitignore
index c92a89f..8971e72 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,5 @@
 /node_modules
 package-lock.json
 package.json
-.DS_Store
\ No newline at end of file
+.DS_Store
+.cargo
\ No newline at end of file
diff --git a/datafusion-federation/src/sql/analyzer.rs b/datafusion-federation/src/sql/analyzer.rs
new file mode 100644
index 0000000..1cfb0e8
--- /dev/null
+++ b/datafusion-federation/src/sql/analyzer.rs
@@ -0,0 +1,883 @@
+use std::{collections::HashMap, sync::Arc};
+
+use datafusion::{
+    common::Column,
+    logical_expr::{
+        expr::{
+            AggregateFunction, AggregateFunctionParams, Alias, Exists, InList, InSubquery,
+            PlannedReplaceSelectItem, ScalarFunction, Sort, Unnest, WildcardOptions,
+            WindowFunction, WindowFunctionParams,
+        },
+        Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Limit, LogicalPlan, Subquery,
+        TryCast,
+    },
+    sql::TableReference,
+};
+
+use crate::get_table_source;
+
+use super::SQLTableSource;
+
+type Result<T, E = datafusion::error::DataFusionError> = std::result::Result<T, E>;
+
+/// Rewrite a LogicalPlan's table scans and expressions to use the federated table name.
+#[derive(Debug)]
+pub struct RewriteTableScanAnalyzer;
+
+impl RewriteTableScanAnalyzer {
+    pub fn rewrite(plan: LogicalPlan) -> Result<LogicalPlan> {
+        let known_rewrites = &mut HashMap::new();
+        rewrite_table_scans(&plan, known_rewrites)
+    }
+}
+
+/// Rewrite table scans to use the original federated table name.
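+///
+/// A sketch of the effect, using the illustrative names from the tests below: a
+/// table registered locally as `foo.df_table` whose source points at a remote
+/// table named `remote_table` is rewritten so the unparsed SQL targets the
+/// remote name:
+///
+/// ```text
+/// before: TableScan: foo.df_table   => SELECT ... FROM foo.df_table
+/// after:  TableScan: remote_table   => SELECT ... FROM remote_table
+/// ```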
+fn rewrite_table_scans(
+    plan: &LogicalPlan,
+    known_rewrites: &mut HashMap<TableReference, TableReference>,
+) -> Result<LogicalPlan> {
+    if plan.inputs().is_empty() {
+        if let LogicalPlan::TableScan(table_scan) = plan {
+            let original_table_name = table_scan.table_name.clone();
+            let mut new_table_scan = table_scan.clone();
+
+            let Some(federated_source) = get_table_source(&table_scan.source)? else {
+                // Not a federated source
+                return Ok(plan.clone());
+            };
+
+            match federated_source.as_any().downcast_ref::<SQLTableSource>() {
+                Some(sql_table_source) => {
+                    let remote_table_name = sql_table_source.table_reference();
+                    known_rewrites.insert(original_table_name, remote_table_name.clone());
+
+                    // Rewrite the schema of this node to have the remote table as the qualifier.
+                    let new_schema = (*new_table_scan.projected_schema)
+                        .clone()
+                        .replace_qualifier(remote_table_name.clone());
+                    new_table_scan.projected_schema = Arc::new(new_schema);
+                    new_table_scan.table_name = remote_table_name;
+                }
+                None => {
+                    // Not a SQLTableSource (is this possible?)
+                    return Ok(plan.clone());
+                }
+            }
+
+            return Ok(LogicalPlan::TableScan(new_table_scan));
+        } else {
+            return Ok(plan.clone());
+        }
+    }
+
+    if let LogicalPlan::Limit(limit) = plan {
+        let rewritten_skip = limit
+            .skip
+            .as_ref()
+            .map(|skip| rewrite_table_scans_in_expr(*skip.clone(), known_rewrites).map(Box::new))
+            .transpose()?;
+
+        let rewritten_fetch = limit
+            .fetch
+            .as_ref()
+            .map(|fetch| rewrite_table_scans_in_expr(*fetch.clone(), known_rewrites).map(Box::new))
+            .transpose()?;
+
+        // explicitly set fetch and skip
+        let new_plan = LogicalPlan::Limit(Limit {
+            skip: rewritten_skip,
+            fetch: rewritten_fetch,
+            input: Arc::new(rewrite_table_scans(&limit.input, known_rewrites)?),
+        });
+
+        return Ok(new_plan);
+    }
+
+    let rewritten_inputs = plan
+        .inputs()
+        .into_iter()
+        .map(|plan| rewrite_table_scans(plan, known_rewrites))
+        .collect::<Result<Vec<_>>>()?;
+
+    let mut new_expressions = vec![];
+    for expression in plan.expressions() {
+        let new_expr = rewrite_table_scans_in_expr(expression.clone(), known_rewrites)?;
+        new_expressions.push(new_expr);
+    }
+
+    let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?;
+
+    Ok(new_plan)
+}
+
+// The function replaces occurrences of table_ref_str in col_name with the new name defined by rewrite.
+// The name to rewrite should NOT be a substring of another name.
+// Supports multiple occurrences of table_ref_str in col_name.
+pub fn rewrite_column_name_in_expr(
+    col_name: &str,
+    table_ref_str: &str,
+    rewrite: &str,
+    start_pos: usize,
+) -> Option<String> {
+    if start_pos >= col_name.len() {
+        return None;
+    }
+
+    // Find the first occurrence of table_ref_str starting from start_pos
+    let idx = col_name[start_pos..].find(table_ref_str)?;
+
+    // Calculate the absolute index of the occurrence in string as the index above is relative to start_pos
+    let idx = start_pos + idx;
+
+    if idx > 0 {
+        // Check if the previous character is alphabetic, numeric, underscore or period, in which case we
+        // should not rewrite as it is a part of another name.
+        if let Some(prev_char) = col_name.chars().nth(idx - 1) {
+            if prev_char.is_alphabetic()
+                || prev_char.is_numeric()
+                || prev_char == '_'
+                || prev_char == '.'
+            {
+                return rewrite_column_name_in_expr(
+                    col_name,
+                    table_ref_str,
+                    rewrite,
+                    idx + table_ref_str.len(),
+                );
+            }
+        }
+    }
+
+    // Check if the next character is alphabetic, numeric or underscore, in which case we
+    // should not rewrite as it is a part of another name.
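+    // e.g. when rewriting `tbl` -> `remote`, the `tbl` inside `tbl2` or `tbl_x` is
+    // part of a longer identifier and must be left untouched.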
+ if let Some(next_char) = col_name.chars().nth(idx + table_ref_str.len()) { + if next_char.is_alphabetic() || next_char.is_numeric() || next_char == '_' { + return rewrite_column_name_in_expr( + col_name, + table_ref_str, + rewrite, + idx + table_ref_str.len(), + ); + } + } + + // Found full match, replace table_ref_str occurrence with rewrite + let rewritten_name = format!( + "{}{}{}", + &col_name[..idx], + rewrite, + &col_name[idx + table_ref_str.len()..] + ); + // Check if the rewritten name contains more occurrence of table_ref_str, and rewrite them as well + // This is done by providing the updated start_pos for search + match rewrite_column_name_in_expr(&rewritten_name, table_ref_str, rewrite, idx + rewrite.len()) + { + Some(new_name) => Some(new_name), // more occurrences found + None => Some(rewritten_name), // no more occurrences/changes + } +} + +fn rewrite_table_scans_in_expr( + expr: Expr, + known_rewrites: &mut HashMap, +) -> Result { + match expr { + Expr::ScalarSubquery(subquery) => { + let new_subquery = rewrite_table_scans(&subquery.subquery, known_rewrites)?; + let outer_ref_columns = subquery + .outer_ref_columns + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(new_subquery), + outer_ref_columns, + })) + } + Expr::BinaryExpr(binary_expr) => { + let left = rewrite_table_scans_in_expr(*binary_expr.left, known_rewrites)?; + let right = rewrite_table_scans_in_expr(*binary_expr.right, known_rewrites)?; + Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + binary_expr.op, + Box::new(right), + ))) + } + Expr::Column(mut col) => { + if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { + Ok(Expr::Column(Column::new(Some(rewrite.clone()), &col.name))) + } else { + // This prevent over-eager rewrite and only pass the column into below rewritten + // rule like MAX(...) + if col.relation.is_some() { + return Ok(Expr::Column(col)); + } + + // Check if any of the rewrites match any substring in col.name, and replace that part of the string if so. 
+ // This will handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)" + let (new_name, was_rewritten) = known_rewrites.iter().fold( + (col.name.to_string(), false), + |(col_name, was_rewritten), (table_ref, rewrite)| { + match rewrite_column_name_in_expr( + &col_name, + &table_ref.to_string(), + &rewrite.to_string(), + 0, + ) { + Some(new_name) => (new_name, true), + None => (col_name, was_rewritten), + } + }, + ); + if was_rewritten { + Ok(Expr::Column(Column::new(col.relation.take(), new_name))) + } else { + Ok(Expr::Column(col)) + } + } + } + Expr::Alias(alias) => { + let expr = rewrite_table_scans_in_expr(*alias.expr, known_rewrites)?; + if let Some(relation) = &alias.relation { + if let Some(rewrite) = known_rewrites.get(relation) { + return Ok(Expr::Alias(Alias::new( + expr, + Some(rewrite.clone()), + alias.name, + ))); + } + } + Ok(Expr::Alias(Alias::new(expr, alias.relation, alias.name))) + } + Expr::Like(like) => { + let expr = rewrite_table_scans_in_expr(*like.expr, known_rewrites)?; + let pattern = rewrite_table_scans_in_expr(*like.pattern, known_rewrites)?; + Ok(Expr::Like(Like::new( + like.negated, + Box::new(expr), + Box::new(pattern), + like.escape_char, + like.case_insensitive, + ))) + } + Expr::SimilarTo(similar_to) => { + let expr = rewrite_table_scans_in_expr(*similar_to.expr, known_rewrites)?; + let pattern = rewrite_table_scans_in_expr(*similar_to.pattern, known_rewrites)?; + Ok(Expr::SimilarTo(Like::new( + similar_to.negated, + Box::new(expr), + Box::new(pattern), + similar_to.escape_char, + similar_to.case_insensitive, + ))) + } + Expr::Not(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::Not(Box::new(expr))) + } + Expr::IsNotNull(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotNull(Box::new(expr))) + } + Expr::IsNull(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNull(Box::new(expr))) + } + Expr::IsTrue(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsTrue(Box::new(expr))) + } + Expr::IsFalse(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsFalse(Box::new(expr))) + } + Expr::IsUnknown(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsUnknown(Box::new(expr))) + } + Expr::IsNotTrue(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotTrue(Box::new(expr))) + } + Expr::IsNotFalse(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotFalse(Box::new(expr))) + } + Expr::IsNotUnknown(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotUnknown(Box::new(expr))) + } + Expr::Negative(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::Negative(Box::new(expr))) + } + Expr::Between(between) => { + let expr = rewrite_table_scans_in_expr(*between.expr, known_rewrites)?; + let low = rewrite_table_scans_in_expr(*between.low, known_rewrites)?; + let high = rewrite_table_scans_in_expr(*between.high, known_rewrites)?; + Ok(Expr::Between(Between::new( + Box::new(expr), + between.negated, + Box::new(low), + Box::new(high), + ))) + } + Expr::Case(case) => { + let expr = case + .expr + .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) + .transpose()? + .map(Box::new); + let else_expr = case + .else_expr + .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) + .transpose()? 
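+                // Box the rewritten ELSE expression, if one was present.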
+ .map(Box::new); + let when_expr = case + .when_then_expr + .into_iter() + .map(|(when, then)| { + let when = rewrite_table_scans_in_expr(*when, known_rewrites); + let then = rewrite_table_scans_in_expr(*then, known_rewrites); + + match (when, then) { + (Ok(when), Ok(then)) => Ok((Box::new(when), Box::new(then))), + (Err(e), _) | (_, Err(e)) => Err(e), + } + }) + .collect::, Box)>>>()?; + Ok(Expr::Case(Case::new(expr, when_expr, else_expr))) + } + Expr::Cast(cast) => { + let expr = rewrite_table_scans_in_expr(*cast.expr, known_rewrites)?; + Ok(Expr::Cast(Cast::new(Box::new(expr), cast.data_type))) + } + Expr::TryCast(try_cast) => { + let expr = rewrite_table_scans_in_expr(*try_cast.expr, known_rewrites)?; + Ok(Expr::TryCast(TryCast::new( + Box::new(expr), + try_cast.data_type, + ))) + } + Expr::ScalarFunction(sf) => { + let args = sf + .args + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::ScalarFunction(ScalarFunction { + func: sf.func, + args, + })) + } + Expr::AggregateFunction(af) => { + let args = af + .params + .args + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let filter = af + .params + .filter + .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) + .transpose()? + .map(Box::new); + let order_by = af + .params + .order_by + .map(|e| { + e.into_iter() + .map(|sort| { + Ok(Sort { + expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?, + ..sort + }) + }) + .collect::>>() + }) + .transpose()?; + let params = AggregateFunctionParams { + args, + distinct: af.params.distinct, + filter, + order_by, + null_treatment: af.params.null_treatment, + }; + Ok(Expr::AggregateFunction(AggregateFunction { + func: af.func, + params, + })) + } + Expr::WindowFunction(wf) => { + let args = wf + .params + .args + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let partition_by = wf + .params + .partition_by + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let order_by = wf + .params + .order_by + .into_iter() + .map(|sort| { + Ok(Sort { + expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?, + ..sort + }) + }) + .collect::>>()?; + let params = WindowFunctionParams { + args, + partition_by, + order_by, + window_frame: wf.params.window_frame, + null_treatment: wf.params.null_treatment, + }; + Ok(Expr::WindowFunction(WindowFunction { + fun: wf.fun, + params, + })) + } + Expr::InList(il) => { + let expr = rewrite_table_scans_in_expr(*il.expr, known_rewrites)?; + let list = il + .list + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::InList(InList::new(Box::new(expr), list, il.negated))) + } + Expr::Exists(exists) => { + let subquery_plan = rewrite_table_scans(&exists.subquery.subquery, known_rewrites)?; + let outer_ref_columns = exists + .subquery + .outer_ref_columns + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let subquery = Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns, + }; + Ok(Expr::Exists(Exists::new(subquery, exists.negated))) + } + Expr::InSubquery(is) => { + let expr = rewrite_table_scans_in_expr(*is.expr, known_rewrites)?; + let subquery_plan = rewrite_table_scans(&is.subquery.subquery, known_rewrites)?; + let outer_ref_columns = is + .subquery + .outer_ref_columns + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + 
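+            // Reassemble the subquery from the rewritten plan and rewritten outer references.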
let subquery = Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns, + }; + Ok(Expr::InSubquery(InSubquery::new( + Box::new(expr), + subquery, + is.negated, + ))) + } + // TODO: remove the next line after `Expr::Wildcard` is removed in datafusion + #[expect(deprecated)] + Expr::Wildcard { qualifier, options } => { + let options = WildcardOptions { + replace: options + .replace + .map(|replace| -> Result { + Ok(PlannedReplaceSelectItem { + planned_expressions: replace + .planned_expressions + .into_iter() + .map(|expr| rewrite_table_scans_in_expr(expr, known_rewrites)) + .collect::>>()?, + ..replace + }) + }) + .transpose()?, + ..*options + }; + if let Some(rewrite) = qualifier.as_ref().and_then(|q| known_rewrites.get(q)) { + Ok(Expr::Wildcard { + qualifier: Some(rewrite.clone()), + options: Box::new(options), + }) + } else { + Ok(Expr::Wildcard { + qualifier, + options: Box::new(options), + }) + } + } + Expr::GroupingSet(gs) => match gs { + GroupingSet::Rollup(exprs) => { + let exprs = exprs + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::GroupingSet(GroupingSet::Rollup(exprs))) + } + GroupingSet::Cube(exprs) => { + let exprs = exprs + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::GroupingSet(GroupingSet::Cube(exprs))) + } + GroupingSet::GroupingSets(vec_exprs) => { + let vec_exprs = vec_exprs + .into_iter() + .map(|exprs| { + exprs + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>() + }) + .collect::>>>()?; + Ok(Expr::GroupingSet(GroupingSet::GroupingSets(vec_exprs))) + } + }, + Expr::OuterReferenceColumn(dt, col) => { + if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { + Ok(Expr::OuterReferenceColumn( + dt, + Column::new(Some(rewrite.clone()), &col.name), + )) + } else { + Ok(Expr::OuterReferenceColumn(dt, col)) + } + } + Expr::Unnest(unnest) => { + let expr = rewrite_table_scans_in_expr(*unnest.expr, known_rewrites)?; + Ok(Expr::Unnest(Unnest::new(expr))) + } + Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Placeholder(_) => Ok(expr), + } +} + +#[cfg(test)] +mod tests { + use crate::sql::table::SQLTable; + use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource}; + use crate::FederatedTableProviderAdaptor; + use async_trait::async_trait; + use datafusion::arrow::datatypes::{Schema, SchemaRef}; + use datafusion::execution::SendableRecordBatchStream; + use datafusion::sql::unparser::dialect::Dialect; + use datafusion::sql::unparser::plan_to_sql; + use datafusion::{ + arrow::datatypes::{DataType, Field}, + catalog::{MemorySchemaProvider, SchemaProvider}, + common::Column, + datasource::{DefaultTableSource, TableProvider}, + execution::context::SessionContext, + logical_expr::LogicalPlanBuilder, + prelude::Expr, + }; + + use super::*; + + struct TestExecutor; + + #[async_trait] + impl SQLExecutor for TestExecutor { + fn name(&self) -> &str { + "TestExecutor" + } + + fn compute_context(&self) -> Option { + None + } + + fn dialect(&self) -> Arc { + unimplemented!() + } + + fn execute(&self, _query: &str, _schema: SchemaRef) -> Result { + unimplemented!() + } + + async fn table_names(&self) -> Result> { + unimplemented!() + } + + async fn get_table_schema(&self, _table_name: &str) -> Result { + unimplemented!() + } + } + + #[derive(Debug)] + struct TestTable { + name: RemoteTableRef, + schema: SchemaRef, + } + + impl TestTable { + fn new(name: String, schema: 
SchemaRef) -> Self { + TestTable { + name: name.try_into().unwrap(), + schema, + } + } + } + + impl SQLTable for TestTable { + fn table_reference(&self) -> TableReference { + TableReference::from(&self.name) + } + + fn schema(&self) -> datafusion::arrow::datatypes::SchemaRef { + self.schema.clone() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + } + + fn get_test_table_provider() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Date32, false), + ])); + let table = Arc::new(TestTable::new("remote_table".to_string(), schema)); + let provider = Arc::new(SQLFederationProvider::new(Arc::new(TestExecutor))); + let table_source = Arc::new(SQLTableSource { provider, table }); + Arc::new(FederatedTableProviderAdaptor::new(table_source)) + } + + fn get_test_table_source() -> Arc { + Arc::new(DefaultTableSource::new(get_test_table_provider())) + } + + fn get_test_df_context() -> SessionContext { + let ctx = SessionContext::new(); + let catalog = ctx + .catalog("datafusion") + .expect("default catalog is datafusion"); + let foo_schema = Arc::new(MemorySchemaProvider::new()) as Arc; + catalog + .register_schema("foo", Arc::clone(&foo_schema)) + .expect("to register schema"); + foo_schema + .register_table("df_table".to_string(), get_test_table_provider()) + .expect("to register table"); + + let public_schema = catalog + .schema("public") + .expect("public schema should exist"); + public_schema + .register_table("app_table".to_string(), get_test_table_provider()) + .expect("to register table"); + + ctx + } + + #[test] + fn test_rewrite_table_scans_basic() -> Result<()> { + let plan = LogicalPlanBuilder::scan("foo.df_table", get_test_table_source(), None)? + .project(vec![ + Expr::Column(Column::from_qualified_name("foo.df_table.a")), + Expr::Column(Column::from_qualified_name("foo.df_table.b")), + Expr::Column(Column::from_qualified_name("foo.df_table.c")), + ])? 
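+            // Build the plan; the analyzer under test should rewrite the scan to `remote_table`.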
+ .build()?; + + let rewritten_plan = RewriteTableScanAnalyzer::rewrite(plan)?; + + println!("rewritten_plan: \n{:#?}", rewritten_plan); + let unparsed_sql = plan_to_sql(&rewritten_plan)?; + + println!("unparsed_sql: \n{unparsed_sql}"); + + assert_eq!( + format!("{unparsed_sql}"), + r#"SELECT remote_table.a, remote_table.b, remote_table.c FROM remote_table"# + ); + + Ok(()) + } + + fn init_tracing() { + let subscriber = tracing_subscriber::FmtSubscriber::builder() + .with_env_filter("debug") + .with_ansi(true) + .finish(); + let _ = tracing::subscriber::set_global_default(subscriber); + } + + #[tokio::test] + async fn test_rewrite_table_scans_agg() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let agg_tests = vec![ + ( + "SELECT MAX(a) FROM foo.df_table", + r#"SELECT max(remote_table.a) FROM remote_table"#, + ), + ( + "SELECT foo.df_table.a FROM foo.df_table", + r#"SELECT remote_table.a FROM remote_table"#, + ), + ( + "SELECT MIN(a) FROM foo.df_table", + r#"SELECT min(remote_table.a) FROM remote_table"#, + ), + ( + "SELECT AVG(a) FROM foo.df_table", + r#"SELECT avg(remote_table.a) FROM remote_table"#, + ), + ( + "SELECT SUM(a) FROM foo.df_table", + r#"SELECT sum(remote_table.a) FROM remote_table"#, + ), + ( + "SELECT COUNT(a) FROM foo.df_table", + r#"SELECT count(remote_table.a) FROM remote_table"#, + ), + ( + "SELECT COUNT(a) as cnt FROM foo.df_table", + r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#, + ), + ( + "SELECT COUNT(a) as cnt FROM foo.df_table", + r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#, + ), + ( + "SELECT app_table from (SELECT a as app_table FROM app_table) b", + r#"SELECT b.app_table FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, + ), + ( + "SELECT MAX(app_table) from (SELECT a as app_table FROM app_table) b", + r#"SELECT max(b.app_table) FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, + ), + // multiple occurrences of the same table in single aggregation expression + ( + "SELECT COUNT(CASE WHEN a > 0 THEN a ELSE 0 END) FROM app_table", + r#"SELECT count(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE 0 END) FROM remote_table"#, + ), + // different tables in single aggregation expression + ( + "SELECT COUNT(CASE WHEN appt.a > 0 THEN appt.a ELSE dft.a END) FROM app_table as appt, foo.df_table as dft", + "SELECT count(CASE WHEN (appt.a > 0) THEN appt.a ELSE dft.a END) FROM remote_table AS appt CROSS JOIN remote_table AS dft" + ), + ]; + + for test in agg_tests { + test_sql(&ctx, test.0, test.1).await?; + } + + Ok(()) + } + + #[tokio::test] + async fn test_rewrite_table_scans_alias() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let tests = vec![ + ( + "SELECT COUNT(app_table_a) FROM (SELECT a as app_table_a FROM app_table)", + r#"SELECT count(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, + ), + ( + "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table)", + r#"SELECT app_table_a FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, + ), + ( + "SELECT aapp_table FROM (SELECT a as aapp_table FROM app_table)", + r#"SELECT aapp_table FROM (SELECT remote_table.a AS aapp_table FROM remote_table)"#, + ), + ]; + + for test in tests { + test_sql(&ctx, test.0, test.1).await?; + } + + Ok(()) + } + + async fn test_sql(ctx: &SessionContext, sql_query: &str, expected_sql: &str) -> Result<()> { + let data_frame = ctx.sql(sql_query).await?; + + println!("before optimization: \n{:#?}", 
data_frame.logical_plan()); + + let rewritten_plan = RewriteTableScanAnalyzer::rewrite(data_frame.logical_plan().clone())?; + + println!("rewritten_plan: \n{:#?}", rewritten_plan); + + let unparsed_sql = plan_to_sql(&rewritten_plan)?; + + println!("unparsed_sql: \n{unparsed_sql}"); + + assert_eq!( + format!("{unparsed_sql}"), + expected_sql, + "SQL under test: {}", + sql_query + ); + + Ok(()) + } + + #[tokio::test] + async fn test_rewrite_table_scans_limit_offset() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let tests = vec![ + // Basic LIMIT + ( + "SELECT a FROM foo.df_table LIMIT 5", + r#"SELECT remote_table.a FROM remote_table LIMIT 5"#, + ), + // Basic OFFSET + ( + "SELECT a FROM foo.df_table OFFSET 5", + r#"SELECT remote_table.a FROM remote_table OFFSET 5"#, + ), + // OFFSET after LIMIT + ( + "SELECT a FROM foo.df_table LIMIT 10 OFFSET 5", + r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#, + ), + // LIMIT after OFFSET + ( + "SELECT a FROM foo.df_table OFFSET 5 LIMIT 10", + r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#, + ), + // Zero OFFSET + ( + "SELECT a FROM foo.df_table OFFSET 0", + r#"SELECT remote_table.a FROM remote_table OFFSET 0"#, + ), + // Zero LIMIT + ( + "SELECT a FROM foo.df_table LIMIT 0", + r#"SELECT remote_table.a FROM remote_table LIMIT 0"#, + ), + // Zero LIMIT and OFFSET + ( + "SELECT a FROM foo.df_table LIMIT 0 OFFSET 0", + r#"SELECT remote_table.a FROM remote_table LIMIT 0 OFFSET 0"#, + ), + ]; + + for test in tests { + test_sql(&ctx, test.0, test.1).await?; + } + + Ok(()) + } +} diff --git a/datafusion-federation/src/sql/ast_analyzer.rs b/datafusion-federation/src/sql/ast_analyzer.rs new file mode 100644 index 0000000..e7636e3 --- /dev/null +++ b/datafusion-federation/src/sql/ast_analyzer.rs @@ -0,0 +1,111 @@ +use std::ops::ControlFlow; + +use datafusion::sql::{ + sqlparser::ast::{ + FunctionArg, ObjectName, Statement, TableFactor, TableFunctionArgs, VisitMut, VisitorMut, + }, + TableReference, +}; + +use super::AstAnalyzer; + +pub fn replace_table_args_analyzer(mut visitor: TableArgReplace) -> AstAnalyzer { + let x = move |mut statement: Statement| { + VisitMut::visit(&mut statement, &mut visitor); + Ok(statement) + }; + Box::new(x) +} + +/// Used to construct a AstAnalyzer that can replace table arguments. +/// +/// ```rust +/// use datafusion::sql::sqlparser::ast::{FunctionArg, Expr, Value}; +/// use datafusion::sql::TableReference; +/// use datafusion_federation::sql::ast_analyzer::TableArgReplace; +/// +/// let mut analyzer = TableArgReplace::default().with( +/// TableReference::parse_str("table1"), +/// vec![FunctionArg::Unnamed( +/// Expr::Value( +/// Value::Number("1".to_string(), false), +/// ) +/// .into(), +/// )], +/// ); +/// let analyzer = analyzer.into_analyzer(); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct TableArgReplace { + pub tables: Vec<(TableReference, TableFunctionArgs)>, +} + +impl TableArgReplace { + /// Constructs a new `TableArgReplace` instance. + pub fn new(tables: Vec<(TableReference, Vec)>) -> Self { + Self { + tables: tables + .into_iter() + .map(|(table, args)| { + ( + table, + TableFunctionArgs { + args, + settings: None, + }, + ) + }) + .collect(), + } + } + + /// Adds a new table argument replacement. + pub fn with(mut self, table: TableReference, args: Vec) -> Self { + self.tables.push(( + table, + TableFunctionArgs { + args, + settings: None, + }, + )); + self + } + + /// Converts the `TableArgReplace` instance into an `AstAnalyzer`. 
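+    ///
+    /// The returned analyzer visits each table factor in the statement and, for any
+    /// table matching one of the configured references, substitutes the stored
+    /// `TableFunctionArgs`. A minimal sketch of applying it (assuming `statement`
+    /// is an already-parsed `Statement` and `arg` a prepared `FunctionArg`):
+    ///
+    /// ```rust,ignore
+    /// let mut analyzer = TableArgReplace::default()
+    ///     .with(TableReference::bare("table1"), vec![arg])
+    ///     .into_analyzer();
+    /// let rewritten = analyzer(statement)?;
+    /// ```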
+    pub fn into_analyzer(self) -> AstAnalyzer {
+        replace_table_args_analyzer(self)
+    }
+}
+
+impl VisitorMut for TableArgReplace {
+    type Break = ();
+    fn pre_visit_table_factor(
+        &mut self,
+        table_factor: &mut TableFactor,
+    ) -> ControlFlow<Self::Break> {
+        if let TableFactor::Table { name, args, .. } = table_factor {
+            let name_as_tableref = name_to_table_reference(name);
+            if let Some(arg) = self
+                .tables
+                .iter()
+                .find(|(t, _)| t.resolved_eq(&name_as_tableref))
+            {
+                *args = Some(arg.1.clone());
+            }
+        }
+        ControlFlow::Continue(())
+    }
+}
+
+fn name_to_table_reference(name: &ObjectName) -> TableReference {
+    let first = name.0.first().map(|n| n.value.to_string());
+    let second = name.0.get(1).map(|n| n.value.to_string());
+    let third = name.0.get(2).map(|n| n.value.to_string());
+
+    match (first, second, third) {
+        (Some(first), Some(second), Some(third)) => TableReference::full(first, second, third),
+        (Some(first), Some(second), None) => TableReference::partial(first, second),
+        (Some(first), None, None) => TableReference::bare(first),
+        _ => panic!("Invalid table name"),
+    }
+}
diff --git a/datafusion-federation/src/sql/executor.rs b/datafusion-federation/src/sql/executor.rs
index ca04989..e45c6f6 100644
--- a/datafusion-federation/src/sql/executor.rs
+++ b/datafusion-federation/src/sql/executor.rs
@@ -1,13 +1,17 @@
 use async_trait::async_trait;
 use core::fmt;
 use datafusion::{
-    arrow::datatypes::SchemaRef, error::Result, physical_plan::SendableRecordBatchStream,
-    sql::sqlparser::ast, sql::unparser::dialect::Dialect,
+    arrow::datatypes::SchemaRef,
+    error::Result,
+    logical_expr::LogicalPlan,
+    physical_plan::SendableRecordBatchStream,
+    sql::{sqlparser::ast, unparser::dialect::Dialect},
 };
 use std::sync::Arc;
 
 pub type SQLExecutorRef = Arc<dyn SQLExecutor>;
-pub type AstAnalyzer = Box<dyn Fn(ast::Statement) -> Result<ast::Statement>>;
+pub type AstAnalyzer = Box<dyn FnMut(ast::Statement) -> Result<ast::Statement>>;
+pub type LogicalOptimizer = Box<dyn FnMut(LogicalPlan) -> Result<LogicalPlan>>;
 
 #[async_trait]
 pub trait SQLExecutor: Sync + Send {
@@ -26,6 +30,11 @@ pub trait SQLExecutor: Sync + Send {
     /// The specific SQL dialect (currently supports 'sqlite', 'postgres', 'flight')
     fn dialect(&self) -> Arc<dyn Dialect>;
 
+    /// Returns the analyzer rule specific for this engine to modify the logical plan before execution
+    fn logical_optimizer(&self) -> Option<LogicalOptimizer> {
+        None
+    }
+
     /// Returns an AST analyzer specific for this engine to modify the AST before execution
     fn ast_analyzer(&self) -> Option<AstAnalyzer> {
         None
diff --git a/datafusion-federation/src/sql/mod.rs b/datafusion-federation/src/sql/mod.rs
index d6847a3..7d74f19 100644
--- a/datafusion-federation/src/sql/mod.rs
+++ b/datafusion-federation/src/sql/mod.rs
@@ -1,23 +1,20 @@
+mod analyzer;
+pub mod ast_analyzer;
 mod executor;
 mod schema;
+mod table;
+mod table_reference;
 
-use std::{any::Any, collections::HashMap, fmt, sync::Arc, vec};
+use std::{any::Any, fmt, sync::Arc, vec};
 
+use analyzer::RewriteTableScanAnalyzer;
 use async_trait::async_trait;
 use datafusion::{
     arrow::datatypes::{Schema, SchemaRef},
-    common::{tree_node::Transformed, Column},
-    error::Result,
+    common::tree_node::{Transformed, TreeNode},
+    error::{DataFusionError, Result},
     execution::{context::SessionState, TaskContext},
-    logical_expr::{
-        expr::{
-            AggregateFunction, AggregateFunctionParams, Alias, Exists, InList, InSubquery,
-            PlannedReplaceSelectItem, ScalarFunction, Sort, Unnest, WildcardOptions,
-            WindowFunction, WindowFunctionParams,
-        },
-        Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, Limit, LogicalPlan,
-        Subquery, TryCast,
-    },
+    logical_expr::{Extension, LogicalPlan},
optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule}, physical_expr::EquivalenceProperties, physical_plan::{ @@ -25,23 +22,18 @@ use datafusion::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, SendableRecordBatchStream, }, - sql::{ - sqlparser::ast::Statement, - unparser::{plan_to_sql, Unparser}, - TableReference, - }, + sql::{sqlparser::ast::Statement, unparser::Unparser}, }; -pub use executor::{AstAnalyzer, SQLExecutor, SQLExecutorRef}; -pub use schema::{MultiSchemaProvider, SQLSchemaProvider, SQLTableSource}; +pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef}; +pub use schema::{MultiSchemaProvider, SQLSchemaProvider}; +pub use table::{RemoteTable, SQLTableSource}; +pub use table_reference::RemoteTableRef; use crate::{ get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider, }; -// #[macro_use] -// extern crate derive_builder; - // SQLFederationProvider provides federation to SQL DMBSs. #[derive(Debug)] pub struct SQLFederationProvider { @@ -76,7 +68,7 @@ impl FederationProvider for SQLFederationProvider { #[derive(Debug)] struct SQLFederationOptimizerRule { - planner: Arc, + planner: Arc, } impl SQLFederationOptimizerRule { @@ -104,12 +96,18 @@ impl OptimizerRule for SQLFederationOptimizerRule { return Ok(Transformed::no(plan)); } } - // Simply accept the entire plan for now + let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone()); let ext_node = Extension { node: Arc::new(fed_plan), }; - Ok(Transformed::yes(LogicalPlan::Extension(ext_node))) + + let mut plan = LogicalPlan::Extension(ext_node); + if let Some(mut rewriter) = self.planner.executor.logical_optimizer() { + plan = rewriter(plan)?; + } + + Ok(Transformed::yes(plan)) } /// A human readable name for this analyzer rule @@ -123,539 +121,7 @@ impl OptimizerRule for SQLFederationOptimizerRule { } } -/// Rewrite table scans to use the original federated table name. -fn rewrite_table_scans( - plan: &LogicalPlan, - known_rewrites: &mut HashMap, -) -> Result { - if plan.inputs().is_empty() { - if let LogicalPlan::TableScan(table_scan) = plan { - let original_table_name = table_scan.table_name.clone(); - let mut new_table_scan = table_scan.clone(); - - let Some(federated_source) = get_table_source(&table_scan.source)? else { - // Not a federated source - return Ok(plan.clone()); - }; - - match federated_source.as_any().downcast_ref::() { - Some(sql_table_source) => { - let remote_table_name = TableReference::from(sql_table_source.table_name()); - known_rewrites.insert(original_table_name, remote_table_name.clone()); - - // Rewrite the schema of this node to have the remote table as the qualifier. - let new_schema = (*new_table_scan.projected_schema) - .clone() - .replace_qualifier(remote_table_name.clone()); - new_table_scan.projected_schema = Arc::new(new_schema); - new_table_scan.table_name = remote_table_name; - } - None => { - // Not a SQLTableSource (is this possible?) 
- return Ok(plan.clone()); - } - } - - return Ok(LogicalPlan::TableScan(new_table_scan)); - } else { - return Ok(plan.clone()); - } - } - - let rewritten_inputs = plan - .inputs() - .into_iter() - .map(|i| rewrite_table_scans(i, known_rewrites)) - .collect::>>()?; - - if let LogicalPlan::Limit(limit) = plan { - let rewritten_skip = limit - .skip - .as_ref() - .map(|skip| rewrite_table_scans_in_expr(*skip.clone(), known_rewrites).map(Box::new)) - .transpose()?; - - let rewritten_fetch = limit - .fetch - .as_ref() - .map(|fetch| rewrite_table_scans_in_expr(*fetch.clone(), known_rewrites).map(Box::new)) - .transpose()?; - - // explicitly set fetch and skip - let new_plan = LogicalPlan::Limit(Limit { - skip: rewritten_skip, - fetch: rewritten_fetch, - input: Arc::new(rewritten_inputs[0].clone()), - }); - - return Ok(new_plan); - } - - let mut new_expressions = vec![]; - for expression in plan.expressions() { - let new_expr = rewrite_table_scans_in_expr(expression.clone(), known_rewrites)?; - new_expressions.push(new_expr); - } - - let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?; - - Ok(new_plan) -} - -// The function replaces occurrences of table_ref_str in col_name with the new name defined by rewrite. -// The name to rewrite should NOT be a substring of another name. -// Supports multiple occurrences of table_ref_str in col_name. -fn rewrite_column_name_in_expr( - col_name: &str, - table_ref_str: &str, - rewrite: &str, - start_pos: usize, -) -> Option { - if start_pos >= col_name.len() { - return None; - } - - // Find the first occurrence of table_ref_str starting from start_pos - let idx = col_name[start_pos..].find(table_ref_str)?; - - // Calculate the absolute index of the occurrence in string as the index above is relative to start_pos - let idx = start_pos + idx; - - if idx > 0 { - // Check if the previous character is alphabetic, numeric, underscore or period, in which case we - // should not rewrite as it is a part of another name. - if let Some(prev_char) = col_name.chars().nth(idx - 1) { - if prev_char.is_alphabetic() - || prev_char.is_numeric() - || prev_char == '_' - || prev_char == '.' - { - return rewrite_column_name_in_expr( - col_name, - table_ref_str, - rewrite, - idx + table_ref_str.len(), - ); - } - } - } - - // Check if the next character is alphabetic, numeric or underscore, in which case we - // should not rewrite as it is a part of another name. - if let Some(next_char) = col_name.chars().nth(idx + table_ref_str.len()) { - if next_char.is_alphabetic() || next_char.is_numeric() || next_char == '_' { - return rewrite_column_name_in_expr( - col_name, - table_ref_str, - rewrite, - idx + table_ref_str.len(), - ); - } - } - - // Found full match, replace table_ref_str occurrence with rewrite - let rewritten_name = format!( - "{}{}{}", - &col_name[..idx], - rewrite, - &col_name[idx + table_ref_str.len()..] 
- ); - // Check if the rewritten name contains more occurrence of table_ref_str, and rewrite them as well - // This is done by providing the updated start_pos for search - match rewrite_column_name_in_expr(&rewritten_name, table_ref_str, rewrite, idx + rewrite.len()) - { - Some(new_name) => Some(new_name), // more occurrences found - None => Some(rewritten_name), // no more occurrences/changes - } -} - -fn rewrite_table_scans_in_expr( - expr: Expr, - known_rewrites: &mut HashMap, -) -> Result { - match expr { - Expr::ScalarSubquery(subquery) => { - let new_subquery = rewrite_table_scans(&subquery.subquery, known_rewrites)?; - let outer_ref_columns = subquery - .outer_ref_columns - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - Ok(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(new_subquery), - outer_ref_columns, - })) - } - Expr::BinaryExpr(binary_expr) => { - let left = rewrite_table_scans_in_expr(*binary_expr.left, known_rewrites)?; - let right = rewrite_table_scans_in_expr(*binary_expr.right, known_rewrites)?; - Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left), - binary_expr.op, - Box::new(right), - ))) - } - Expr::Column(mut col) => { - if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { - Ok(Expr::Column(Column::new(Some(rewrite.clone()), &col.name))) - } else { - // This prevent over-eager rewrite and only pass the column into below rewritten - // rule like MAX(...) - if col.relation.is_some() { - return Ok(Expr::Column(col)); - } - - // Check if any of the rewrites match any substring in col.name, and replace that part of the string if so. - // This will handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)" - let (new_name, was_rewritten) = known_rewrites.iter().fold( - (col.name.to_string(), false), - |(col_name, was_rewritten), (table_ref, rewrite)| { - match rewrite_column_name_in_expr( - &col_name, - &table_ref.to_string(), - &rewrite.to_string(), - 0, - ) { - Some(new_name) => (new_name, true), - None => (col_name, was_rewritten), - } - }, - ); - if was_rewritten { - Ok(Expr::Column(Column::new(col.relation.take(), new_name))) - } else { - Ok(Expr::Column(col)) - } - } - } - Expr::Alias(alias) => { - let expr = rewrite_table_scans_in_expr(*alias.expr, known_rewrites)?; - if let Some(relation) = &alias.relation { - if let Some(rewrite) = known_rewrites.get(relation) { - return Ok(Expr::Alias(Alias::new( - expr, - Some(rewrite.clone()), - alias.name, - ))); - } - } - Ok(Expr::Alias(Alias::new(expr, alias.relation, alias.name))) - } - Expr::Like(like) => { - let expr = rewrite_table_scans_in_expr(*like.expr, known_rewrites)?; - let pattern = rewrite_table_scans_in_expr(*like.pattern, known_rewrites)?; - Ok(Expr::Like(Like::new( - like.negated, - Box::new(expr), - Box::new(pattern), - like.escape_char, - like.case_insensitive, - ))) - } - Expr::SimilarTo(similar_to) => { - let expr = rewrite_table_scans_in_expr(*similar_to.expr, known_rewrites)?; - let pattern = rewrite_table_scans_in_expr(*similar_to.pattern, known_rewrites)?; - Ok(Expr::SimilarTo(Like::new( - similar_to.negated, - Box::new(expr), - Box::new(pattern), - similar_to.escape_char, - similar_to.case_insensitive, - ))) - } - Expr::Not(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::Not(Box::new(expr))) - } - Expr::IsNotNull(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsNotNull(Box::new(expr))) - } - Expr::IsNull(e) => { - let expr = 
rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsNull(Box::new(expr))) - } - Expr::IsTrue(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsTrue(Box::new(expr))) - } - Expr::IsFalse(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsFalse(Box::new(expr))) - } - Expr::IsUnknown(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsUnknown(Box::new(expr))) - } - Expr::IsNotTrue(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsNotTrue(Box::new(expr))) - } - Expr::IsNotFalse(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsNotFalse(Box::new(expr))) - } - Expr::IsNotUnknown(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsNotUnknown(Box::new(expr))) - } - Expr::Negative(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::Negative(Box::new(expr))) - } - Expr::Between(between) => { - let expr = rewrite_table_scans_in_expr(*between.expr, known_rewrites)?; - let low = rewrite_table_scans_in_expr(*between.low, known_rewrites)?; - let high = rewrite_table_scans_in_expr(*between.high, known_rewrites)?; - Ok(Expr::Between(Between::new( - Box::new(expr), - between.negated, - Box::new(low), - Box::new(high), - ))) - } - Expr::Case(case) => { - let expr = case - .expr - .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) - .transpose()? - .map(Box::new); - let else_expr = case - .else_expr - .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) - .transpose()? - .map(Box::new); - let when_expr = case - .when_then_expr - .into_iter() - .map(|(when, then)| { - let when = rewrite_table_scans_in_expr(*when, known_rewrites); - let then = rewrite_table_scans_in_expr(*then, known_rewrites); - - match (when, then) { - (Ok(when), Ok(then)) => Ok((Box::new(when), Box::new(then))), - (Err(e), _) | (_, Err(e)) => Err(e), - } - }) - .collect::, Box)>>>()?; - Ok(Expr::Case(Case::new(expr, when_expr, else_expr))) - } - Expr::Cast(cast) => { - let expr = rewrite_table_scans_in_expr(*cast.expr, known_rewrites)?; - Ok(Expr::Cast(Cast::new(Box::new(expr), cast.data_type))) - } - Expr::TryCast(try_cast) => { - let expr = rewrite_table_scans_in_expr(*try_cast.expr, known_rewrites)?; - Ok(Expr::TryCast(TryCast::new( - Box::new(expr), - try_cast.data_type, - ))) - } - Expr::ScalarFunction(sf) => { - let args = sf - .args - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - Ok(Expr::ScalarFunction(ScalarFunction { - func: sf.func, - args, - })) - } - Expr::AggregateFunction(af) => { - let args = af - .params - .args - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - let filter = af - .params - .filter - .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) - .transpose()? 
- .map(Box::new); - let order_by = af - .params - .order_by - .map(|e| { - e.into_iter() - .map(|sort| { - Ok(Sort { - expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?, - ..sort - }) - }) - .collect::>>() - }) - .transpose()?; - let params = AggregateFunctionParams { - args, - distinct: af.params.distinct, - filter, - order_by, - null_treatment: af.params.null_treatment, - }; - Ok(Expr::AggregateFunction(AggregateFunction { - func: af.func, - params, - })) - } - Expr::WindowFunction(wf) => { - let args = wf - .params - .args - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - let partition_by = wf - .params - .partition_by - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - let order_by = wf - .params - .order_by - .into_iter() - .map(|sort| { - Ok(Sort { - expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?, - ..sort - }) - }) - .collect::>>()?; - let params = WindowFunctionParams { - args, - partition_by, - order_by, - window_frame: wf.params.window_frame, - null_treatment: wf.params.null_treatment, - }; - Ok(Expr::WindowFunction(WindowFunction { - fun: wf.fun, - params, - })) - } - Expr::InList(il) => { - let expr = rewrite_table_scans_in_expr(*il.expr, known_rewrites)?; - let list = il - .list - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - Ok(Expr::InList(InList::new(Box::new(expr), list, il.negated))) - } - Expr::Exists(exists) => { - let subquery_plan = rewrite_table_scans(&exists.subquery.subquery, known_rewrites)?; - let outer_ref_columns = exists - .subquery - .outer_ref_columns - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - let subquery = Subquery { - subquery: Arc::new(subquery_plan), - outer_ref_columns, - }; - Ok(Expr::Exists(Exists::new(subquery, exists.negated))) - } - Expr::InSubquery(is) => { - let expr = rewrite_table_scans_in_expr(*is.expr, known_rewrites)?; - let subquery_plan = rewrite_table_scans(&is.subquery.subquery, known_rewrites)?; - let outer_ref_columns = is - .subquery - .outer_ref_columns - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - let subquery = Subquery { - subquery: Arc::new(subquery_plan), - outer_ref_columns, - }; - Ok(Expr::InSubquery(InSubquery::new( - Box::new(expr), - subquery, - is.negated, - ))) - } - // TODO: remove the next line after `Expr::Wildcard` is removed in datafusion - #[expect(deprecated)] - Expr::Wildcard { qualifier, options } => { - let options = WildcardOptions { - replace: options - .replace - .map(|replace| -> Result { - Ok(PlannedReplaceSelectItem { - planned_expressions: replace - .planned_expressions - .into_iter() - .map(|expr| rewrite_table_scans_in_expr(expr, known_rewrites)) - .collect::>>()?, - ..replace - }) - }) - .transpose()?, - ..*options - }; - if let Some(rewrite) = qualifier.as_ref().and_then(|q| known_rewrites.get(q)) { - Ok(Expr::Wildcard { - qualifier: Some(rewrite.clone()), - options: Box::new(options), - }) - } else { - Ok(Expr::Wildcard { - qualifier, - options: Box::new(options), - }) - } - } - Expr::GroupingSet(gs) => match gs { - GroupingSet::Rollup(exprs) => { - let exprs = exprs - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - Ok(Expr::GroupingSet(GroupingSet::Rollup(exprs))) - } - GroupingSet::Cube(exprs) => { - let exprs = exprs - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - 
.collect::>>()?; - Ok(Expr::GroupingSet(GroupingSet::Cube(exprs))) - } - GroupingSet::GroupingSets(vec_exprs) => { - let vec_exprs = vec_exprs - .into_iter() - .map(|exprs| { - exprs - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>() - }) - .collect::>>>()?; - Ok(Expr::GroupingSet(GroupingSet::GroupingSets(vec_exprs))) - } - }, - Expr::OuterReferenceColumn(dt, col) => { - if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { - Ok(Expr::OuterReferenceColumn( - dt, - Column::new(Some(rewrite.clone()), &col.name), - )) - } else { - Ok(Expr::OuterReferenceColumn(dt, col)) - } - } - Expr::Unnest(unnest) => { - let expr = rewrite_table_scans_in_expr(*unnest.expr, known_rewrites)?; - Ok(Expr::Unnest(Unnest::new(expr))) - } - Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Placeholder(_) => Ok(expr), - } -} - +#[derive(Debug)] struct SQLFederationPlanner { executor: Arc, } @@ -711,41 +177,134 @@ impl VirtualExecutionPlan { Arc::new(Schema::from(df_schema)) } - fn sql(&self) -> Result { - // Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table. - let mut known_rewrites = HashMap::new(); - let plan = &rewrite_table_scans(&self.plan, &mut known_rewrites)?; - let mut ast = self.plan_to_sql(plan)?; + fn final_sql(&self) -> Result { + let plan = self.plan.clone(); + let plan = RewriteTableScanAnalyzer::rewrite(plan)?; + let (logical_optimizers, ast_analyzers) = gather_analyzers(&plan)?; + let plan = apply_logical_optimizers(plan, logical_optimizers)?; + let ast = self.plan_to_statement(&plan)?; + let ast = self.rewrite_with_executor_ast_analyzer(ast)?; + let ast = apply_ast_analyzers(ast, ast_analyzers)?; + Ok(ast.to_string()) + } - if let Some(analyzer) = self.executor.ast_analyzer() { - ast = analyzer(ast)?; + fn rewrite_with_executor_ast_analyzer( + &self, + ast: Statement, + ) -> Result { + if let Some(mut analyzer) = self.executor.ast_analyzer() { + Ok(analyzer(ast)?) 
+ } else { + Ok(ast) } - - Ok(format!("{ast}")) } - fn plan_to_sql(&self, plan: &LogicalPlan) -> Result { + fn plan_to_statement(&self, plan: &LogicalPlan) -> Result { Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan) } } +fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec, Vec)> { + let mut logical_optimizers = vec![]; + let mut ast_analyzers = vec![]; + + plan.apply(|node| { + if let LogicalPlan::TableScan(table) = node { + let provider = get_table_source(&table.source) + .expect("caller is virtual exec so this is valid") + .expect("caller is virtual exec so this is valid"); + if let Some(source) = provider.as_any().downcast_ref::() { + if let Some(analyzer) = source.table.logical_optimizer() { + logical_optimizers.push(analyzer); + } + if let Some(analyzer) = source.table.ast_analyzer() { + ast_analyzers.push(analyzer); + } + } + } + Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue) + })?; + + Ok((logical_optimizers, ast_analyzers)) +} + +fn apply_logical_optimizers( + mut plan: LogicalPlan, + analyzers: Vec, +) -> Result { + for mut analyzer in analyzers { + let old_schema = plan.schema().clone(); + plan = analyzer(plan)?; + let new_schema = plan.schema(); + if &old_schema != new_schema { + return Err(DataFusionError::Execution(format!( + "Schema altered during logical analysis, expected: {}, found: {}", + old_schema, new_schema + ))); + } + } + Ok(plan) +} + +fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec) -> Result { + for mut analyzer in analyzers { + statement = analyzer(statement)?; + } + Ok(statement) +} + impl DisplayAs for VirtualExecutionPlan { fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result { write!(f, "VirtualExecutionPlan")?; - let Ok(ast) = plan_to_sql(&self.plan) else { - return Ok(()); - }; write!(f, " name={}", self.executor.name())?; if let Some(ctx) = self.executor.compute_context() { write!(f, " compute_context={ctx}")?; }; + let mut plan = self.plan.clone(); + if let Ok(statement) = self.plan_to_statement(&plan) { + write!(f, " initial_sql={statement}")?; + } - write!(f, " sql={ast}")?; - if let Ok(query) = self.sql() { - write!(f, " rewritten_sql={query}")?; + let (logical_optimizers, ast_analyzers) = match gather_analyzers(&plan) { + Ok(analyzers) => analyzers, + Err(_) => return Ok(()), }; - write!(f, " sql={ast}") + let old_plan = plan.clone(); + + plan = match apply_logical_optimizers(plan, logical_optimizers) { + Ok(plan) => plan, + _ => return Ok(()), + }; + + let statement = match self.plan_to_statement(&plan) { + Ok(statement) => statement, + _ => return Ok(()), + }; + + if plan != old_plan { + write!(f, " rewritten_logical_sql={statement}")?; + } + + let old_statement = statement.clone(); + let statement = match self.rewrite_with_executor_ast_analyzer(statement) { + Ok(statement) => statement, + _ => return Ok(()), + }; + if old_statement != statement { + write!(f, " rewritten_executor_sql={statement}")?; + } + + let old_statement = statement.clone(); + let statement = match apply_ast_analyzers(statement, ast_analyzers) { + Ok(statement) => statement, + _ => return Ok(()), + }; + if old_statement != statement { + write!(f, " rewritten_ast_analyzer={statement}")?; + } + + Ok(()) } } @@ -778,8 +337,7 @@ impl ExecutionPlan for VirtualExecutionPlan { _partition: usize, _context: Arc, ) -> Result { - let query = self.plan_to_sql(&self.plan)?.to_string(); - self.executor.execute(query.as_str(), self.schema()) + self.executor.execute(&self.final_sql()?, self.schema()) } fn 
properties(&self) -> &PlanProperties { @@ -789,304 +347,260 @@ impl ExecutionPlan for VirtualExecutionPlan { #[cfg(test)] mod tests { + + use std::collections::HashSet; + use std::sync::Arc; + + use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource}; use crate::FederatedTableProviderAdaptor; + use async_trait::async_trait; + use datafusion::arrow::datatypes::{Schema, SchemaRef}; + use datafusion::common::tree_node::TreeNodeRecursion; + use datafusion::execution::SendableRecordBatchStream; + use datafusion::sql::unparser::dialect::Dialect; + use datafusion::sql::unparser::{self}; use datafusion::{ arrow::datatypes::{DataType, Field}, - catalog::{MemorySchemaProvider, SchemaProvider}, - common::Column, - datasource::{DefaultTableSource, TableProvider}, - error::DataFusionError, + datasource::TableProvider, execution::context::SessionContext, - logical_expr::LogicalPlanBuilder, - sql::{unparser::dialect::DefaultDialect, unparser::dialect::Dialect}, }; + use super::table::RemoteTable; use super::*; - struct TestSQLExecutor {} + #[derive(Debug, Clone)] + struct TestExecutor { + compute_context: String, + } #[async_trait] - impl SQLExecutor for TestSQLExecutor { + impl SQLExecutor for TestExecutor { fn name(&self) -> &str { - "test_sql_table_source" + "TestExecutor" } fn compute_context(&self) -> Option { - None + Some(self.compute_context.clone()) } fn dialect(&self) -> Arc { - Arc::new(DefaultDialect {}) + Arc::new(unparser::dialect::DefaultDialect {}) } fn execute(&self, _query: &str, _schema: SchemaRef) -> Result { - Err(DataFusionError::NotImplemented( - "execute not implemented".to_string(), - )) + unimplemented!() } async fn table_names(&self) -> Result> { - Err(DataFusionError::NotImplemented( - "table inference not implemented".to_string(), - )) + unimplemented!() } async fn get_table_schema(&self, _table_name: &str) -> Result { - Err(DataFusionError::NotImplemented( - "table inference not implemented".to_string(), - )) + unimplemented!() } } - fn get_test_table_provider() -> Arc { - let sql_federation_provider = - Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {}))); - + fn get_test_table_provider(name: String, executor: TestExecutor) -> Arc { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, false), Field::new("b", DataType::Utf8, false), Field::new("c", DataType::Date32, false), ])); - let table_source = Arc::new( - SQLTableSource::new_with_schema( - sql_federation_provider, - "remote_table".to_string(), - schema, - ) - .expect("to have a valid SQLTableSource"), - ); + let table_ref = RemoteTableRef::try_from(name).unwrap(); + let table = Arc::new(RemoteTable::new(table_ref, schema)); + let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor))); + let table_source = Arc::new(SQLTableSource { provider, table }); Arc::new(FederatedTableProviderAdaptor::new(table_source)) } - fn get_test_table_source() -> Arc { - Arc::new(DefaultTableSource::new(get_test_table_provider())) - } - - fn get_test_df_context() -> SessionContext { - let ctx = SessionContext::new(); - let catalog = ctx - .catalog("datafusion") - .expect("default catalog is datafusion"); - let foo_schema = Arc::new(MemorySchemaProvider::new()) as Arc; - catalog - .register_schema("foo", Arc::clone(&foo_schema)) - .expect("to register schema"); - foo_schema - .register_table("df_table".to_string(), get_test_table_provider()) - .expect("to register table"); - - let public_schema = catalog - .schema("public") - .expect("public schema should exist"); - 
public_schema - .register_table("app_table".to_string(), get_test_table_provider()) - .expect("to register table"); - - ctx - } - - #[test] - fn test_rewrite_table_scans_basic() -> Result<()> { - let default_table_source = get_test_table_source(); - let plan = - LogicalPlanBuilder::scan("foo.df_table", default_table_source, None)?.project(vec![ - Expr::Column(Column::from_qualified_name("foo.df_table.a")), - Expr::Column(Column::from_qualified_name("foo.df_table.b")), - Expr::Column(Column::from_qualified_name("foo.df_table.c")), - ])?; + #[tokio::test] + async fn basic_sql_federation_test() -> Result<(), DataFusionError> { + let test_executor_a = TestExecutor { + compute_context: "a".into(), + }; - let mut known_rewrites = HashMap::new(); - let rewritten_plan = rewrite_table_scans(&plan.build()?, &mut known_rewrites)?; + let test_executor_b = TestExecutor { + compute_context: "b".into(), + }; - println!("rewritten_plan: \n{:#?}", rewritten_plan); + let table_a1_ref = "table_a1".to_string(); + let table_a1 = get_test_table_provider(table_a1_ref.clone(), test_executor_a.clone()); + + let table_a2_ref = "table_a2".to_string(); + let table_a2 = get_test_table_provider(table_a2_ref.clone(), test_executor_a); + + let table_b1_ref = "table_b1(1)".to_string(); + let table_b1_df_ref = "table_local_b1".to_string(); + + let table_b1 = get_test_table_provider(table_b1_ref.clone(), test_executor_b); + + // Create a new SessionState with the optimizer rule we created above + let state = crate::default_session_state(); + let ctx = SessionContext::new_with_state(state); + + ctx.register_table(table_a1_ref.clone(), table_a1).unwrap(); + ctx.register_table(table_a2_ref.clone(), table_a2).unwrap(); + ctx.register_table(table_b1_df_ref.clone(), table_b1) + .unwrap(); + + let query = r#" + SELECT * FROM table_a1 + UNION ALL + SELECT * FROM table_a2 + UNION ALL + SELECT * FROM table_local_b1; + "#; + + let df = ctx.sql(query).await?; + + let logical_plan = df.into_optimized_plan()?; + + let mut table_a1_federated = false; + let mut table_a2_federated = false; + let mut table_b1_federated = false; + + let _ = logical_plan.apply(|node| { + if let LogicalPlan::Extension(node) = node { + if let Some(node) = node.node.as_any().downcast_ref::() { + let _ = node.plan().apply(|node| { + if let LogicalPlan::TableScan(table) = node { + if table.table_name.table() == table_a1_ref { + table_a1_federated = true; + } + if table.table_name.table() == table_a2_ref { + table_a2_federated = true; + } + // assuming table name is rewritten via analyzer + if table.table_name.table() == table_b1_df_ref { + table_b1_federated = true; + } + } + Ok(TreeNodeRecursion::Continue) + }); + } + } + Ok(TreeNodeRecursion::Continue) + }); - let unparsed_sql = plan_to_sql(&rewritten_plan)?; + assert!(table_a1_federated); + assert!(table_a2_federated); + assert!(table_b1_federated); - println!("unparsed_sql: \n{unparsed_sql}"); + let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?; - assert_eq!( - format!("{unparsed_sql}"), - r#"SELECT remote_table.a, remote_table.b, remote_table.c FROM remote_table"# - ); + let mut final_queries = vec![]; - Ok(()) - } + let _ = physical_plan.apply(|node| { + if node.name() == "sql_federation_exec" { + let node = node + .as_any() + .downcast_ref::() + .unwrap(); - fn init_tracing() { - let subscriber = tracing_subscriber::FmtSubscriber::builder() - .with_env_filter("debug") - .with_ansi(true) - .finish(); - let _ = tracing::subscriber::set_global_default(subscriber); - } + 
final_queries.push(node.final_sql()?); + } + Ok(TreeNodeRecursion::Continue) + }); - #[tokio::test] - async fn test_rewrite_table_scans_agg() -> Result<()> { - init_tracing(); - let ctx = get_test_df_context(); - - let agg_tests = vec![ - ( - "SELECT MAX(a) FROM foo.df_table", - r#"SELECT max(remote_table.a) FROM remote_table"#, - ), - ( - "SELECT foo.df_table.a FROM foo.df_table", - r#"SELECT remote_table.a FROM remote_table"#, - ), - ( - "SELECT MIN(a) FROM foo.df_table", - r#"SELECT min(remote_table.a) FROM remote_table"#, - ), - ( - "SELECT AVG(a) FROM foo.df_table", - r#"SELECT avg(remote_table.a) FROM remote_table"#, - ), - ( - "SELECT SUM(a) FROM foo.df_table", - r#"SELECT sum(remote_table.a) FROM remote_table"#, - ), - ( - "SELECT COUNT(a) FROM foo.df_table", - r#"SELECT count(remote_table.a) FROM remote_table"#, - ), - ( - "SELECT COUNT(a) as cnt FROM foo.df_table", - r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#, - ), - ( - "SELECT COUNT(a) as cnt FROM foo.df_table", - r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#, - ), - ( - "SELECT app_table from (SELECT a as app_table FROM app_table) b", - r#"SELECT b.app_table FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, - ), - ( - "SELECT MAX(app_table) from (SELECT a as app_table FROM app_table) b", - r#"SELECT max(b.app_table) FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, - ), - // multiple occurrences of the same table in single aggregation expression - ( - "SELECT COUNT(CASE WHEN a > 0 THEN a ELSE 0 END) FROM app_table", - r#"SELECT count(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE 0 END) FROM remote_table"#, - ), - // different tables in single aggregation expression - ( - "SELECT COUNT(CASE WHEN appt.a > 0 THEN appt.a ELSE dft.a END) FROM app_table as appt, foo.df_table as dft", - "SELECT count(CASE WHEN (appt.a > 0) THEN appt.a ELSE dft.a END) FROM remote_table AS appt CROSS JOIN remote_table AS dft" - ), + let expected = vec![ + "SELECT table_a1.a, table_a1.b, table_a1.c FROM table_a1", + "SELECT table_a2.a, table_a2.b, table_a2.c FROM table_a2", + "SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1)", ]; - for test in agg_tests { - test_sql(&ctx, test.0, test.1).await?; - } + assert_eq!( + HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())), + HashSet::from_iter(expected) + ); Ok(()) } #[tokio::test] - async fn test_rewrite_table_scans_alias() -> Result<()> { - init_tracing(); - let ctx = get_test_df_context(); - - let tests = vec![ - ( - "SELECT COUNT(app_table_a) FROM (SELECT a as app_table_a FROM app_table)", - r#"SELECT count(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, - ), - ( - "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table)", - r#"SELECT app_table_a FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, - ), - ( - "SELECT aapp_table FROM (SELECT a as aapp_table FROM app_table)", - r#"SELECT aapp_table FROM (SELECT remote_table.a AS aapp_table FROM remote_table)"#, - ), - ]; - - for test in tests { - test_sql(&ctx, test.0, test.1).await?; - } + async fn multi_reference_sql_federation_test() -> Result<(), DataFusionError> { + let test_executor_a = TestExecutor { + compute_context: "test".into(), + }; - Ok(()) - } + let lowercase_table_ref = "default.table".to_string(); + let lowercase_local_table_ref = "dftable".to_string(); + let lowercase_table = + get_test_table_provider(lowercase_table_ref.clone(), test_executor_a.clone()); + + let capitalized_table_ref = 
"default.Table(1)".to_string(); + let capitalized_local_table_ref = "dfview".to_string(); + let capitalized_table = + get_test_table_provider(capitalized_table_ref.clone(), test_executor_a); + + // Create a new SessionState with the optimizer rule we created above + let state = crate::default_session_state(); + let ctx = SessionContext::new_with_state(state); + + ctx.register_table(lowercase_local_table_ref.clone(), lowercase_table) + .unwrap(); + ctx.register_table(capitalized_local_table_ref.clone(), capitalized_table) + .unwrap(); + + let query = r#" + SELECT * FROM dftable + UNION ALL + SELECT * FROM dfview; + "#; + + let df = ctx.sql(query).await?; + + let logical_plan = df.into_optimized_plan()?; + + let mut lowercase_table = false; + let mut capitalized_table = false; + + let _ = logical_plan.apply(|node| { + if let LogicalPlan::Extension(node) = node { + if let Some(node) = node.node.as_any().downcast_ref::() { + let _ = node.plan().apply(|node| { + if let LogicalPlan::TableScan(table) = node { + if table.table_name.table() == lowercase_local_table_ref { + lowercase_table = true; + } + if table.table_name.table() == capitalized_local_table_ref { + capitalized_table = true; + } + } + Ok(TreeNodeRecursion::Continue) + }); + } + } + Ok(TreeNodeRecursion::Continue) + }); - async fn test_sql( - ctx: &SessionContext, - sql_query: &str, - expected_sql: &str, - ) -> Result<(), datafusion::error::DataFusionError> { - let data_frame = ctx.sql(sql_query).await?; + assert!(lowercase_table); + assert!(capitalized_table); - println!("before optimization: \n{:#?}", data_frame.logical_plan()); + let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?; - let mut known_rewrites = HashMap::new(); - let rewritten_plan = rewrite_table_scans(data_frame.logical_plan(), &mut known_rewrites)?; + let mut final_queries = vec![]; - println!("rewritten_plan: \n{:#?}", rewritten_plan); + let _ = physical_plan.apply(|node| { + if node.name() == "sql_federation_exec" { + let node = node + .as_any() + .downcast_ref::() + .unwrap(); - let unparsed_sql = plan_to_sql(&rewritten_plan)?; + final_queries.push(node.final_sql()?); + } + Ok(TreeNodeRecursion::Continue) + }); - println!("unparsed_sql: \n{unparsed_sql}"); + let expected = vec![ + r#"SELECT "table".a, "table".b, "table".c FROM "default"."table" UNION ALL SELECT "Table".a, "Table".b, "Table".c FROM "default"."Table"(1)"#, + ]; assert_eq!( - format!("{unparsed_sql}"), - expected_sql, - "SQL under test: {}", - sql_query + HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())), + HashSet::from_iter(expected) ); Ok(()) } - - #[tokio::test] - async fn test_rewrite_table_scans_limit_offset() -> Result<()> { - init_tracing(); - let ctx = get_test_df_context(); - - let tests = vec![ - // Basic LIMIT - ( - "SELECT a FROM foo.df_table LIMIT 5", - r#"SELECT remote_table.a FROM remote_table LIMIT 5"#, - ), - // Basic OFFSET - ( - "SELECT a FROM foo.df_table OFFSET 5", - r#"SELECT remote_table.a FROM remote_table OFFSET 5"#, - ), - // OFFSET after LIMIT - ( - "SELECT a FROM foo.df_table LIMIT 10 OFFSET 5", - r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#, - ), - // LIMIT after OFFSET - ( - "SELECT a FROM foo.df_table OFFSET 5 LIMIT 10", - r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#, - ), - // Zero OFFSET - ( - "SELECT a FROM foo.df_table OFFSET 0", - r#"SELECT remote_table.a FROM remote_table OFFSET 0"#, - ), - // Zero LIMIT - ( - "SELECT a FROM foo.df_table LIMIT 0", - r#"SELECT remote_table.a FROM 
remote_table LIMIT 0"#, - ), - // Zero LIMIT and OFFSET - ( - "SELECT a FROM foo.df_table LIMIT 0 OFFSET 0", - r#"SELECT remote_table.a FROM remote_table LIMIT 0 OFFSET 0"#, - ), - ]; - - for test in tests { - test_sql(&ctx, test.0, test.1).await?; - } - - Ok(()) - } } diff --git a/datafusion-federation/src/sql/schema.rs b/datafusion-federation/src/sql/schema.rs index 1961226..a459d72 100644 --- a/datafusion-federation/src/sql/schema.rs +++ b/datafusion-federation/src/sql/schema.rs @@ -1,45 +1,86 @@ use std::{any::Any, sync::Arc}; use async_trait::async_trait; -use datafusion::logical_expr::{TableSource, TableType}; use datafusion::{ - arrow::datatypes::SchemaRef, catalog::SchemaProvider, datasource::TableProvider, error::Result, + catalog::SchemaProvider, + datasource::TableProvider, + error::{DataFusionError, Result}, + sql::TableReference, }; use futures::future::join_all; -use crate::{ - sql::SQLFederationProvider, FederatedTableProviderAdaptor, FederatedTableSource, - FederationProvider, -}; +use super::{table::SQLTable, RemoteTableRef, SQLTableSource}; +use crate::{sql::SQLFederationProvider, FederatedTableProviderAdaptor}; +/// An in-memory schema provider for SQL tables. #[derive(Debug)] pub struct SQLSchemaProvider { - // provider: Arc, tables: Vec>, } impl SQLSchemaProvider { + /// Creates a new SQLSchemaProvider from a [`SQLFederationProvider`]. + /// Initializes the schema provider by fetching table names and schema from the federation provider's executor, pub async fn new(provider: Arc) -> Result { - let tables = Arc::clone(&provider).executor.table_names().await?; + let executor = Arc::clone(&provider.executor); + let tables = executor + .table_names() + .await? + .iter() + .map(RemoteTableRef::try_from) + .collect::>>()?; + + let tasks = tables + .into_iter() + .map(|table_ref| { + let provider = Arc::clone(&provider); + async move { SQLTableSource::new(provider, table_ref).await } + }) + .collect::>(); + + let tables = join_all(tasks) + .await + .into_iter() + .map(|res| res.map(Arc::new)) + .collect::>>()?; - Self::new_with_tables(provider, tables).await + Ok(Self { tables }) } - pub async fn new_with_tables( + /// Creates a new SQLSchemaProvider from a SQLFederationProvider and a list of table references. + /// Fetches the schema for each table using the executor's implementation. + pub async fn new_with_tables( provider: Arc, - tables: Vec, - ) -> Result { + tables: Vec, + ) -> Result + where + T: TryInto, + { + let tables = tables + .into_iter() + .map(|t| t.try_into()) + .collect::>>()?; let futures: Vec<_> = tables .into_iter() .map(|t| SQLTableSource::new(Arc::clone(&provider), t)) .collect(); let results: Result> = join_all(futures).await.into_iter().collect(); - let sources = results?.into_iter().map(Arc::new).collect(); - Ok(Self::new_with_table_sources(sources)) + let tables = results?.into_iter().map(Arc::new).collect(); + Ok(Self { tables }) } - pub fn new_with_table_sources(tables: Vec>) -> Self { - Self { tables } + /// Creates a new SQLSchemaProvider from a SQLFederationProvider and a list of custom table instances. 
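+    ///
+    /// A minimal sketch (assuming `executor` is an `Arc<dyn SQLExecutor>` and `MyTable` is a
+    /// user-defined implementation of [`SQLTable`]; both names are illustrative):
+    /// ```rust,ignore
+    /// let provider = Arc::new(SQLFederationProvider::new(executor));
+    /// let tables: Vec<Arc<dyn SQLTable>> = vec![Arc::new(MyTable::default())];
+    /// let schema_provider = SQLSchemaProvider::new_with_custom_tables(provider, tables);
+    /// ```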
+ pub fn new_with_custom_tables( + provider: Arc, + tables: Vec>, + ) -> Self { + Self { + tables: tables + .into_iter() + .map(|table| SQLTableSource::new_with_table(provider.clone(), table)) + .map(Arc::new) + .collect(), + } } } @@ -50,27 +91,30 @@ impl SchemaProvider for SQLSchemaProvider { } fn table_names(&self) -> Vec { - self.tables.iter().map(|s| s.table_name.clone()).collect() + self.tables + .iter() + .map(|source| source.table_reference().to_quoted_string()) + .collect() } async fn table(&self, name: &str) -> Result>> { - if let Some(source) = self - .tables - .iter() - .find(|s| s.table_name.eq_ignore_ascii_case(name)) - { - let adaptor = FederatedTableProviderAdaptor::new( - Arc::clone(source) as Arc - ); + if let Some(source) = self.tables.iter().find(|s| { + s.table_reference() + .to_quoted_string() + .eq_ignore_ascii_case(name) + }) { + let adaptor = FederatedTableProviderAdaptor::new(source.clone()); return Ok(Some(Arc::new(adaptor))); } Ok(None) } fn table_exist(&self, name: &str) -> bool { - self.tables - .iter() - .any(|s| s.table_name.eq_ignore_ascii_case(name)) + self.tables.iter().any(|source| { + source + .table_reference() + .resolved_eq(&TableReference::from(name)) + }) } } @@ -108,55 +152,3 @@ impl SchemaProvider for MultiSchemaProvider { self.children.iter().any(|p| p.table_exist(name)) } } - -#[derive(Debug)] -pub struct SQLTableSource { - provider: Arc, - table_name: String, - schema: SchemaRef, -} - -impl SQLTableSource { - // creates a SQLTableSource and infers the table schema - pub async fn new(provider: Arc, table_name: String) -> Result { - let schema = Arc::clone(&provider) - .executor - .get_table_schema(table_name.as_str()) - .await?; - Self::new_with_schema(provider, table_name, schema) - } - - pub fn new_with_schema( - provider: Arc, - table_name: String, - schema: SchemaRef, - ) -> Result { - Ok(Self { - provider, - table_name, - schema, - }) - } - - pub fn table_name(&self) -> &str { - self.table_name.as_str() - } -} - -impl FederatedTableSource for SQLTableSource { - fn federation_provider(&self) -> Arc { - Arc::clone(&self.provider) as Arc - } -} - -impl TableSource for SQLTableSource { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { - Arc::clone(&self.schema) - } - fn table_type(&self) -> TableType { - TableType::Temporary - } -} diff --git a/datafusion-federation/src/sql/table.rs b/datafusion-federation/src/sql/table.rs new file mode 100644 index 0000000..1626bad --- /dev/null +++ b/datafusion-federation/src/sql/table.rs @@ -0,0 +1,170 @@ +use crate::sql::SQLFederationProvider; +use crate::FederatedTableSource; +use crate::FederationProvider; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::error::Result; +use datafusion::logical_expr::TableSource; +use datafusion::logical_expr::TableType; +use datafusion::sql::TableReference; +use std::any::Any; +use std::sync::Arc; + +use super::ast_analyzer; +use super::executor::LogicalOptimizer; +use super::AstAnalyzer; +use super::RemoteTableRef; + +/// Trait to represent a SQL remote table inside [`SQLTableSource`]. +/// A remote table provides information such as schema, table reference, and +/// provides hooks for rewriting the logical plan and AST before execution. +/// This crate provides [`RemoteTable`] as a default ready-to-use type. +pub trait SQLTable: std::fmt::Debug + Send + Sync { + /// Returns a reference as a trait object. 
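+    /// (Callers can recover a concrete type, e.g. `table.as_any().downcast_ref::<RemoteTable>()`.)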
+    fn as_any(&self) -> &dyn Any;
+    /// Provides the [`TableReference`](`datafusion::sql::TableReference`) used to identify the table in SQL queries.
+    /// This TableReference is used for registering the table with the [`SQLSchemaProvider`](`super::SQLSchemaProvider`).
+    /// If the table provider is registered in the DataFusion context under a different name,
+    /// the logical plan will be rewritten to use this table reference during execution.
+    /// Therefore, any AST analyzer should match against this table reference.
+    fn table_reference(&self) -> TableReference;
+    /// Schema of the remote table.
+    fn schema(&self) -> SchemaRef;
+    /// Returns a logical optimizer specific to this table; it will be used to modify the logical plan before execution.
+    fn logical_optimizer(&self) -> Option<LogicalOptimizer> {
+        None
+    }
+    /// Returns an AST analyzer specific to this table; it will be used to modify the AST before execution.
+    fn ast_analyzer(&self) -> Option<AstAnalyzer> {
+        None
+    }
+}
+
+/// Represents a remote table with a reference and schema.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct RemoteTable {
+    remote_table_ref: RemoteTableRef,
+    schema: SchemaRef,
+}
+
+impl RemoteTable {
+    /// Creates a new `RemoteTable` instance.
+    ///
+    /// Examples (`schema` is a [`SchemaRef`] defined by the caller):
+    /// ```rust,ignore
+    /// use datafusion::sql::TableReference;
+    ///
+    /// RemoteTable::new("myschema.table".try_into()?, schema);
+    /// RemoteTable::new(r#"myschema."Table""#.try_into()?, schema);
+    /// RemoteTable::new(TableReference::partial("myschema", "table").into(), schema);
+    /// RemoteTable::new("myschema.view('obj')".try_into()?, schema);
+    /// RemoteTable::new("myschema.view(name => 'obj')".try_into()?, schema);
+    /// RemoteTable::new("myschema.view(name = 'obj')".try_into()?, schema);
+    /// ```
+    pub fn new(table_ref: RemoteTableRef, schema: SchemaRef) -> Self {
+        Self {
+            remote_table_ref: table_ref,
+            schema,
+        }
+    }
+
+    /// Returns the table reference of this remote table.
+    /// Only the object name is returned; functional params, if any, are ignored.
+    pub fn table_reference(&self) -> &TableReference {
+        self.remote_table_ref.table_ref()
+    }
+
+    pub fn schema(&self) -> &SchemaRef {
+        &self.schema
+    }
+}
+
+impl SQLTable for RemoteTable {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn table_reference(&self) -> TableReference {
+        Self::table_reference(self).clone()
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    fn logical_optimizer(&self) -> Option<LogicalOptimizer> {
+        None
+    }
+
+    /// Returns an AST analyzer that rewrites the table identifier to include the stored functional args, if any
+    fn ast_analyzer(&self) -> Option<AstAnalyzer> {
+        if let Some(args) = self.remote_table_ref.args() {
+            Some(
+                ast_analyzer::TableArgReplace::default()
+                    .with(self.remote_table_ref.table_ref().clone(), args.to_vec())
+                    .into_analyzer(),
+            )
+        } else {
+            None
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct SQLTableSource {
+    pub provider: Arc<SQLFederationProvider>,
+    pub table: Arc<dyn SQLTable>,
+}
+
+impl SQLTableSource {
+    /// Creates a SQLTableSource and infers the table schema via the executor
+    pub async fn new(
+        provider: Arc<SQLFederationProvider>,
+        table_ref: RemoteTableRef,
+    ) -> Result<Self> {
+        let table_name = table_ref.to_quoted_string();
+        let schema = provider.executor.get_table_schema(&table_name).await?;
+        Ok(Self::new_with_schema(provider, table_ref, schema))
+    }
+
+    /// Creates a SQLTableSource with a table reference and schema
+    pub fn new_with_schema(
+        provider: Arc<SQLFederationProvider>,
+        table_ref: RemoteTableRef,
+        schema: SchemaRef,
+    ) -> Self {
+        Self {
+            provider,
+            table: Arc::new(RemoteTable::new(table_ref, schema)),
+        }
+    }
+
+    /// Creates a new source with a custom [`SQLTable`] instance.
+    pub fn new_with_table(provider: Arc<SQLFederationProvider>, table: Arc<dyn SQLTable>) -> Self {
+        Self { provider, table }
+    }
+
+    /// Returns the table reference associated with the stored remote table
+    pub fn table_reference(&self) -> TableReference {
+        self.table.table_reference()
+    }
+}
+
+impl TableSource for SQLTableSource {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.table.schema()
+    }
+
+    fn table_type(&self) -> TableType {
+        TableType::Temporary
+    }
+}
+
+impl FederatedTableSource for SQLTableSource {
+    fn federation_provider(&self) -> Arc<dyn FederationProvider> {
+        Arc::clone(&self.provider) as Arc<dyn FederationProvider>
+    }
+}
diff --git a/datafusion-federation/src/sql/table_reference.rs b/datafusion-federation/src/sql/table_reference.rs
new file mode 100644
index 0000000..bb2cc2f
--- /dev/null
+++ b/datafusion-federation/src/sql/table_reference.rs
@@ -0,0 +1,280 @@
+use std::sync::Arc;
+
+use datafusion::{
+    error::DataFusionError,
+    sql::{
+        sqlparser::{
+            self,
+            ast::FunctionArg,
+            dialect::{Dialect, GenericDialect},
+            tokenizer::Token,
+        },
+        TableReference,
+    },
+};
+
+/// A multipart identifier for a remote table, view, or parameterized view.
+///
+/// A RemoteTableRef can be created by parsing a string representing a table object with
+/// optional function arguments:
+/// ```rust,ignore
+/// RemoteTableRef::try_from("myschema.table");
+/// RemoteTableRef::try_from(r#"myschema."Table""#);
+/// RemoteTableRef::try_from("myschema.view('obj')");
+///
+/// RemoteTableRef::parse_with_dialect("myschema.view(name = 'obj')", &PostgreSqlDialect {});
+/// ```
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct RemoteTableRef {
+    pub table_ref: TableReference,
+    pub args: Option<Arc<[FunctionArg]>>,
+}
+
+impl RemoteTableRef {
+    /// Gets the quoted-string representation of the table being referenced; this is the same as
+    /// calling `to_quoted_string` on the inner table reference.
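+    ///
+    /// For illustration (a parameterized view parsed as in the examples above):
+    /// ```rust,ignore
+    /// let table_ref = RemoteTableRef::try_from("myschema.view('obj')")?;
+    /// // Functional args are not part of the quoted string; only the object name is.
+    /// assert_eq!(table_ref.to_quoted_string(), "myschema.view");
+    /// ```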
+    pub fn to_quoted_string(&self) -> String {
+        self.table_ref.to_quoted_string()
+    }
+
+    /// Creates a new reference using the general-purpose dialect. Prefer [`Self::parse_with_dialect`] if the dialect is known beforehand
+    pub fn parse_with_default_dialect(s: &str) -> Result<Self, DataFusionError> {
+        Self::parse_with_dialect(s, &GenericDialect {})
+    }
+
+    /// Creates a new reference using a specific dialect instance.
+    pub fn parse_with_dialect(s: &str, dialect: &dyn Dialect) -> Result<Self, DataFusionError> {
+        let mut parser = sqlparser::parser::Parser::new(dialect).try_with_sql(s)?;
+        let name = parser.parse_object_name(true)?;
+        let args = if parser.consume_token(&Token::LParen) {
+            parser.parse_optional_args()?
+        } else {
+            vec![]
+        };
+
+        let table_ref = match (name.0.first(), name.0.get(1), name.0.get(2)) {
+            (Some(catalog), Some(schema), Some(table)) => TableReference::full(
+                catalog.value.clone(),
+                schema.value.clone(),
+                table.value.clone(),
+            ),
+            (Some(schema), Some(table), None) => {
+                TableReference::partial(schema.value.clone(), table.value.clone())
+            }
+            (Some(table), None, None) => TableReference::bare(table.value.clone()),
+            _ => {
+                return Err(DataFusionError::NotImplemented(
+                    "Unable to parse string into TableReference".to_string(),
+                ))
+            }
+        };
+
+        if !args.is_empty() {
+            Ok(RemoteTableRef {
+                table_ref,
+                args: Some(args.into()),
+            })
+        } else {
+            Ok(RemoteTableRef {
+                table_ref,
+                args: None,
+            })
+        }
+    }
+
+    pub fn table_ref(&self) -> &TableReference {
+        &self.table_ref
+    }
+
+    pub fn args(&self) -> Option<&[FunctionArg]> {
+        self.args.as_deref()
+    }
+}
+
+impl From<TableReference> for RemoteTableRef {
+    fn from(table_ref: TableReference) -> Self {
+        RemoteTableRef {
+            table_ref,
+            args: None,
+        }
+    }
+}
+
+impl From<RemoteTableRef> for TableReference {
+    fn from(remote_table_ref: RemoteTableRef) -> Self {
+        remote_table_ref.table_ref
+    }
+}
+
+impl From<&RemoteTableRef> for TableReference {
+    fn from(remote_table_ref: &RemoteTableRef) -> Self {
+        remote_table_ref.table_ref.clone()
+    }
+}
+
+impl From<(TableReference, Vec<FunctionArg>)> for RemoteTableRef {
+    fn from((table_ref, args): (TableReference, Vec<FunctionArg>)) -> Self {
+        RemoteTableRef {
+            table_ref,
+            args: Some(args.into()),
+        }
+    }
+}
+
+impl TryFrom<&str> for RemoteTableRef {
+    type Error = DataFusionError;
+    fn try_from(s: &str) -> Result<Self, Self::Error> {
+        Self::parse_with_default_dialect(s)
+    }
+}
+
+impl TryFrom<String> for RemoteTableRef {
+    type Error = DataFusionError;
+    fn try_from(s: String) -> Result<Self, Self::Error> {
+        Self::parse_with_default_dialect(&s)
+    }
+}
+
+impl TryFrom<&String> for RemoteTableRef {
+    type Error = DataFusionError;
+    fn try_from(s: &String) -> Result<Self, Self::Error> {
+        Self::parse_with_default_dialect(s)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use sqlparser::{
+        ast::{self, Expr, FunctionArgOperator, Ident, Value},
+        dialect,
+    };
+
+    #[test]
+    fn bare_table_reference() {
+        let table_ref = RemoteTableRef::parse_with_default_dialect("table").unwrap();
+        let expected = RemoteTableRef::from(TableReference::bare("table"));
+        assert_eq!(table_ref, expected);
+
+        let table_ref = RemoteTableRef::parse_with_default_dialect("Table").unwrap();
+        let expected = RemoteTableRef::from(TableReference::bare("Table"));
+        assert_eq!(table_ref, expected);
+    }
+
+    #[test]
+    fn bare_table_reference_with_args() {
+        let table_ref = RemoteTableRef::parse_with_default_dialect("table(1, 2)").unwrap();
+        let expected = RemoteTableRef::from((
+            TableReference::bare("table"),
+            vec![
+                FunctionArg::Unnamed(Expr::Value(Value::Number("1".to_string(), false)).into()),
+                FunctionArg::Unnamed(Expr::Value(Value::Number("2".to_string(), false)).into()),
+            ],
)); + assert_eq!(table_ref, expected); + + let table_ref = RemoteTableRef::parse_with_default_dialect("Table(1, 2)").unwrap(); + let expected = RemoteTableRef::from(( + TableReference::bare("Table"), + vec![ + FunctionArg::Unnamed(Expr::Value(Value::Number("1".to_string(), false)).into()), + FunctionArg::Unnamed(Expr::Value(Value::Number("2".to_string(), false)).into()), + ], + )); + assert_eq!(table_ref, expected); + } + + #[test] + fn bare_table_reference_with_args_and_whitespace() { + let table_ref = RemoteTableRef::parse_with_default_dialect("table (1, 2)").unwrap(); + let expected = RemoteTableRef::from(( + TableReference::bare("table"), + vec![ + FunctionArg::Unnamed(Expr::Value(Value::Number("1".to_string(), false)).into()), + FunctionArg::Unnamed(Expr::Value(Value::Number("2".to_string(), false)).into()), + ], + )); + assert_eq!(table_ref, expected); + + let table_ref = RemoteTableRef::parse_with_default_dialect("Table (1, 2)").unwrap(); + let expected = RemoteTableRef::from(( + TableReference::bare("Table"), + vec![ + FunctionArg::Unnamed(Expr::Value(Value::Number("1".to_string(), false)).into()), + FunctionArg::Unnamed(Expr::Value(Value::Number("2".to_string(), false)).into()), + ], + )); + assert_eq!(table_ref, expected); + } + + #[test] + fn multi_table_reference_with_no_args() { + let table_ref = RemoteTableRef::parse_with_default_dialect("schema.table").unwrap(); + let expected = RemoteTableRef::from(TableReference::partial("schema", "table")); + assert_eq!(table_ref, expected); + + let table_ref = RemoteTableRef::parse_with_default_dialect("schema.Table").unwrap(); + let expected = RemoteTableRef::from(TableReference::partial("schema", "Table")); + assert_eq!(table_ref, expected); + } + + #[test] + fn multi_table_reference_with_args() { + let table_ref = RemoteTableRef::parse_with_default_dialect("schema.table(1, 2)").unwrap(); + let expected = RemoteTableRef::from(( + TableReference::partial("schema", "table"), + vec![ + FunctionArg::Unnamed(Expr::Value(Value::Number("1".to_string(), false)).into()), + FunctionArg::Unnamed(Expr::Value(Value::Number("2".to_string(), false)).into()), + ], + )); + assert_eq!(table_ref, expected); + + let table_ref = RemoteTableRef::parse_with_default_dialect("schema.Table(1, 2)").unwrap(); + let expected = RemoteTableRef::from(( + TableReference::partial("schema", "Table"), + vec![ + FunctionArg::Unnamed(Expr::Value(Value::Number("1".to_string(), false)).into()), + FunctionArg::Unnamed(Expr::Value(Value::Number("2".to_string(), false)).into()), + ], + )); + assert_eq!(table_ref, expected); + } + + #[test] + fn multi_table_reference_with_args_and_whitespace() { + let table_ref = RemoteTableRef::parse_with_default_dialect("schema.table (1, 2)").unwrap(); + let expected = RemoteTableRef::from(( + TableReference::partial("schema", "table"), + vec![ + FunctionArg::Unnamed(Expr::Value(Value::Number("1".to_string(), false)).into()), + FunctionArg::Unnamed(Expr::Value(Value::Number("2".to_string(), false)).into()), + ], + )); + assert_eq!(table_ref, expected); + } + + #[test] + fn bare_reference_with_named_args() { + let table_ref = RemoteTableRef::parse_with_dialect( + "Table (user_id => 1, age => 2)", + &dialect::PostgreSqlDialect {}, + ) + .unwrap(); + let expected = RemoteTableRef::from(( + TableReference::bare("Table"), + vec![ + FunctionArg::ExprNamed { + name: ast::Expr::Identifier(Ident::new("user_id")), + arg: Expr::Value(Value::Number("1".to_string(), false)).into(), + operator: FunctionArgOperator::RightArrow, + }, + FunctionArg::ExprNamed 
{ + name: ast::Expr::Identifier(Ident::new("age")), + arg: Expr::Value(Value::Number("2".to_string(), false)).into(), + operator: FunctionArgOperator::RightArrow, + }, + ], + )); + assert_eq!(table_ref, expected); + } +} From 2b1ca629b1f4030db50716457007f5d1414b4334 Mon Sep 17 00:00:00 2001 From: Satyam Singh Date: Fri, 11 Apr 2025 11:42:38 +0530 Subject: [PATCH 2/5] Update schema.rs --- datafusion-federation/src/sql/schema.rs | 73 +++++++++++-------------- 1 file changed, 31 insertions(+), 42 deletions(-) diff --git a/datafusion-federation/src/sql/schema.rs b/datafusion-federation/src/sql/schema.rs index a459d72..155b392 100644 --- a/datafusion-federation/src/sql/schema.rs +++ b/datafusion-federation/src/sql/schema.rs @@ -1,12 +1,7 @@ use std::{any::Any, sync::Arc}; use async_trait::async_trait; -use datafusion::{ - catalog::SchemaProvider, - datasource::TableProvider, - error::{DataFusionError, Result}, - sql::TableReference, -}; +use datafusion::{catalog::SchemaProvider, datasource::TableProvider, error::Result}; use futures::future::join_all; use super::{table::SQLTable, RemoteTableRef, SQLTableSource}; @@ -22,44 +17,27 @@ impl SQLSchemaProvider { /// Creates a new SQLSchemaProvider from a [`SQLFederationProvider`]. /// Initializes the schema provider by fetching table names and schema from the federation provider's executor, pub async fn new(provider: Arc) -> Result { - let executor = Arc::clone(&provider.executor); - let tables = executor + let tables = Arc::clone(&provider.executor) .table_names() .await? .iter() .map(RemoteTableRef::try_from) .collect::>>()?; - let tasks = tables - .into_iter() - .map(|table_ref| { - let provider = Arc::clone(&provider); - async move { SQLTableSource::new(provider, table_ref).await } - }) - .collect::>(); - - let tables = join_all(tasks) - .await - .into_iter() - .map(|res| res.map(Arc::new)) - .collect::>>()?; - - Ok(Self { tables }) + Self::new_with_table_references(provider, tables).await } /// Creates a new SQLSchemaProvider from a SQLFederationProvider and a list of table references. /// Fetches the schema for each table using the executor's implementation. 
- pub async fn new_with_tables( + pub async fn new_with_tables>( provider: Arc, - tables: Vec, - ) -> Result - where - T: TryInto, - { + tables: impl IntoIterator, + ) -> Result { let tables = tables .into_iter() - .map(|t| t.try_into()) - .collect::>>()?; + .map(|x| RemoteTableRef::try_from(x.as_ref())) + .collect::>>()?; + let futures: Vec<_> = tables .into_iter() .map(|t| SQLTableSource::new(Arc::clone(&provider), t)) @@ -82,6 +60,19 @@ impl SQLSchemaProvider { .collect(), } } + + pub async fn new_with_table_references( + provider: Arc, + tables: Vec, + ) -> Result { + let futures: Vec<_> = tables + .into_iter() + .map(|t| SQLTableSource::new(Arc::clone(&provider), t)) + .collect(); + let results: Result> = join_all(futures).await.into_iter().collect(); + let tables = results?.into_iter().map(Arc::new).collect(); + Ok(Self { tables }) + } } #[async_trait] @@ -93,16 +84,16 @@ impl SchemaProvider for SQLSchemaProvider { fn table_names(&self) -> Vec { self.tables .iter() - .map(|source| source.table_reference().to_quoted_string()) + .map(|source| source.table_reference().to_string()) .collect() } async fn table(&self, name: &str) -> Result>> { - if let Some(source) = self.tables.iter().find(|s| { - s.table_reference() - .to_quoted_string() - .eq_ignore_ascii_case(name) - }) { + if let Some(source) = self + .tables + .iter() + .find(|s| s.table_reference().to_string().eq(name)) + { let adaptor = FederatedTableProviderAdaptor::new(source.clone()); return Ok(Some(Arc::new(adaptor))); } @@ -110,11 +101,9 @@ impl SchemaProvider for SQLSchemaProvider { } fn table_exist(&self, name: &str) -> bool { - self.tables.iter().any(|source| { - source - .table_reference() - .resolved_eq(&TableReference::from(name)) - }) + self.tables + .iter() + .any(|source| source.table_reference().to_string().eq(name)) } } From dd4881326ea405eb78ed150d0ce7772c3a804477 Mon Sep 17 00:00:00 2001 From: Satyam Singh Date: Fri, 11 Apr 2025 19:58:53 +0530 Subject: [PATCH 3/5] Add test for multipart remote table --- datafusion-federation/src/sql/analyzer.rs | 90 +++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/datafusion-federation/src/sql/analyzer.rs b/datafusion-federation/src/sql/analyzer.rs index 1cfb0e8..9b8601a 100644 --- a/datafusion-federation/src/sql/analyzer.rs +++ b/datafusion-federation/src/sql/analyzer.rs @@ -880,4 +880,94 @@ mod tests { Ok(()) } + + fn get_multipart_test_table_provider() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Date32, false), + ])); + let table = Arc::new(TestTable::new("default.remote_table".to_string(), schema)); + let provider = Arc::new(SQLFederationProvider::new(Arc::new(TestExecutor))); + let table_source = Arc::new(SQLTableSource { provider, table }); + Arc::new(FederatedTableProviderAdaptor::new(table_source)) + } + + fn get_multipart_test_df_context() -> SessionContext { + let ctx = SessionContext::new(); + let catalog = ctx + .catalog("datafusion") + .expect("default catalog is datafusion"); + let foo_schema = Arc::new(MemorySchemaProvider::new()) as Arc; + catalog + .register_schema("foo", Arc::clone(&foo_schema)) + .expect("to register schema"); + foo_schema + .register_table("df_table".to_string(), get_multipart_test_table_provider()) + .expect("to register table"); + + let public_schema = catalog + .schema("public") + .expect("public schema should exist"); + public_schema + .register_table("app_table".to_string(), 
get_multipart_test_table_provider())
+            .expect("to register table");
+
+        ctx
+    }
+
+    #[tokio::test]
+    async fn test_rewrite_multipart_table() -> Result<()> {
+        init_tracing();
+        let ctx = get_multipart_test_df_context();
+
+        let tests = vec![
+            (
+                "SELECT MAX(a) FROM foo.df_table",
+                r#"SELECT max(remote_table.a) FROM "default".remote_table"#,
+            ),
+            (
+                "SELECT foo.df_table.a FROM foo.df_table",
+                r#"SELECT remote_table.a FROM "default".remote_table"#,
+            ),
+            (
+                "SELECT MIN(a) FROM foo.df_table",
+                r#"SELECT min(remote_table.a) FROM "default".remote_table"#,
+            ),
+            (
+                "SELECT AVG(a) FROM foo.df_table",
+                r#"SELECT avg(remote_table.a) FROM "default".remote_table"#,
+            ),
+            (
+                "SELECT COUNT(a) as cnt FROM foo.df_table",
+                r#"SELECT count(remote_table.a) AS cnt FROM "default".remote_table"#,
+            ),
+            (
+                "SELECT app_table from (SELECT a as app_table FROM app_table) b",
+                r#"SELECT b.app_table FROM (SELECT remote_table.a AS app_table FROM "default".remote_table) AS b"#,
+            ),
+            (
+                "SELECT MAX(app_table) from (SELECT a as app_table FROM app_table) b",
+                r#"SELECT max(b.app_table) FROM (SELECT remote_table.a AS app_table FROM "default".remote_table) AS b"#,
+            ),
+            (
+                "SELECT COUNT(app_table_a) FROM (SELECT a as app_table_a FROM app_table)",
+                r#"SELECT count(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM "default".remote_table)"#,
+            ),
+            (
+                "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table)",
+                r#"SELECT app_table_a FROM (SELECT remote_table.a AS app_table_a FROM "default".remote_table)"#,
+            ),
+            (
+                "SELECT aapp_table FROM (SELECT a as aapp_table FROM app_table)",
+                r#"SELECT aapp_table FROM (SELECT remote_table.a AS aapp_table FROM "default".remote_table)"#,
+            ),
+        ];
+
+        for test in tests {
+            test_sql(&ctx, test.0, test.1).await?;
+        }
+
+        Ok(())
+    }
 }

From 204eeb6dda54c3b99190f19314bdf3fbe7166168 Mon Sep 17 00:00:00 2001
From: Satyam Singh
Date: Fri, 11 Apr 2025 20:15:01 +0530
Subject: [PATCH 4/5] Add preserve existing alias test

---
 datafusion-federation/src/sql/analyzer.rs | 27 +++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/datafusion-federation/src/sql/analyzer.rs b/datafusion-federation/src/sql/analyzer.rs
index 9b8601a..1d8498d 100644
--- a/datafusion-federation/src/sql/analyzer.rs
+++ b/datafusion-federation/src/sql/analyzer.rs
@@ -808,6 +808,33 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_rewrite_table_scans_preserve_existing_alias() -> Result<()> {
+        init_tracing();
+        let ctx = get_test_df_context();
+
+        let tests = vec![
+            (
+                "SELECT b.a AS app_table_a FROM app_table AS b",
+                r#"SELECT b.a AS app_table_a FROM remote_table AS b"#,
+            ),
+            (
+                "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table AS b)",
+                r#"SELECT app_table_a FROM (SELECT b.a AS app_table_a FROM remote_table AS b)"#,
+            ),
+            (
+                "SELECT COUNT(b.a) FROM app_table AS b",
+                r#"SELECT count(b.a) FROM remote_table AS b"#,
+            ),
+        ];
+
+        for test in tests {
+            test_sql(&ctx, test.0, test.1).await?;
+        }
+
+        Ok(())
+    }
+
     async fn test_sql(ctx: &SessionContext, sql_query: &str, expected_sql: &str) -> Result<()> {
         let data_frame = ctx.sql(sql_query).await?;

From 19b6b3d1ba18c22308b5e099e5e9286040e5b518 Mon Sep 17 00:00:00 2001
From: Satyam Singh
Date: Fri, 11 Apr 2025 20:17:33 +0530
Subject: [PATCH 5/5] Preserve existing alias

Adding an alias is needed for some older database versions.

---
 datafusion-federation/src/sql/ast_analyzer.rs | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/datafusion-federation/src/sql/ast_analyzer.rs b/datafusion-federation/src/sql/ast_analyzer.rs
index e7636e3..5242553 100644
--- a/datafusion-federation/src/sql/ast_analyzer.rs
+++ b/datafusion-federation/src/sql/ast_analyzer.rs
@@ -2,7 +2,8 @@ use std::ops::ControlFlow;
 
 use datafusion::sql::{
     sqlparser::ast::{
-        FunctionArg, ObjectName, Statement, TableFactor, TableFunctionArgs, VisitMut, VisitorMut,
+        FunctionArg, Ident, ObjectName, Statement, TableAlias, TableFactor, TableFunctionArgs,
+        VisitMut, VisitorMut,
     },
     TableReference,
 };
@@ -83,14 +84,23 @@ impl VisitorMut for TableArgReplace {
         &mut self,
         table_factor: &mut TableFactor,
     ) -> ControlFlow<Self::Break> {
-        if let TableFactor::Table { name, args, .. } = table_factor {
+        if let TableFactor::Table {
+            name, args, alias, ..
+        } = table_factor
+        {
             let name_as_tableref = name_to_table_reference(name);
-            if let Some(arg) = self
+            if let Some((table, arg)) = self
                 .tables
                 .iter()
                 .find(|(t, _)| t.resolved_eq(&name_as_tableref))
             {
-                *args = Some(arg.1.clone());
+                *args = Some(arg.clone());
+                if alias.is_none() {
+                    *alias = Some(TableAlias {
+                        name: Ident::new(table.table()),
+                        columns: vec![],
+                    })
+                }
             }
         }
         ControlFlow::Continue(())
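
Taken together, the series lets a remote table or parameterized view be registered and federated end to end. A minimal sketch, assuming a user-defined `MyExecutor` implementing `SQLExecutor` and that `default_session_state` is exported at the crate root as used in the tests above (the executor type, schema names, and view argument are illustrative, not part of this patch):

```rust
use std::sync::Arc;

use datafusion::execution::context::SessionContext;
use datafusion_federation::sql::{RemoteTableRef, SQLFederationProvider, SQLSchemaProvider};

async fn register_remote_tables(executor: Arc<MyExecutor>) -> datafusion::error::Result<()> {
    // Wrap the executor in a federation provider.
    let provider = Arc::new(SQLFederationProvider::new(executor));

    // Plain tables and parameterized views both parse into RemoteTableRefs.
    let tables = vec![
        RemoteTableRef::try_from("myschema.orders")?,
        RemoteTableRef::try_from("myschema.sales_view('2024')")?,
    ];

    // The schema provider fetches each table's schema through the executor.
    let schema_provider = SQLSchemaProvider::new_with_table_references(provider, tables).await?;

    // Register it in a session whose state has the federation optimizer installed.
    let ctx = SessionContext::new_with_state(datafusion_federation::default_session_state());
    ctx.catalog("datafusion")
        .expect("default catalog")
        .register_schema("remote", Arc::new(schema_provider))?;

    Ok(())
}
```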