|
| 1 | +//! Embedded SQL approach for executing UDFs within SQL queries. |
| 2 | +
|
| 3 | +use std::collections::HashMap; |
| 4 | + |
| 5 | +use datafusion::execution::TaskContext; |
| 6 | +use datafusion_common::{DataFusionError, Result as DataFusionResult}; |
| 7 | +use datafusion_sql::parser::{DFParserBuilder, Statement}; |
| 8 | +use sqlparser::ast::{CreateFunctionBody, Expr, Statement as SqlStatement, Value}; |
| 9 | +use sqlparser::dialect::dialect_from_str; |
| 10 | +use tokio::runtime::Handle; |
| 11 | + |
| 12 | +use crate::{WasmComponentPrecompiled, WasmPermissions, WasmScalarUdf}; |
| 13 | + |
| 14 | +/// A [ParsedQuery] contains the extracted UDFs and SQL query string |
| 15 | +#[derive(Debug)] |
| 16 | +pub struct ParsedQuery { |
| 17 | + /// Extracted UDFs from the query |
| 18 | + pub udfs: Vec<WasmScalarUdf>, |
| 19 | + /// SQL query string with UDF definitions removed |
| 20 | + pub sql: String, |
| 21 | +} |
| 22 | + |
| 23 | +/// Handles the registration and invocation of UDF queries in DataFusion with a |
| 24 | +/// pre-compiled WASM component. |
| 25 | +pub struct UdfQueryParser<'a> { |
| 26 | + /// Pre-compiled WASM component. |
| 27 | + /// Necessary to create UDFs. |
| 28 | + components: HashMap<String, &'a WasmComponentPrecompiled>, |
| 29 | +} |
| 30 | + |
| 31 | +impl std::fmt::Debug for UdfQueryParser<'_> { |
| 32 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 33 | + f.debug_struct("UdfQueryParser") |
| 34 | + .field("session_ctx", &"SessionContext { ... }") |
| 35 | + .field("components", &self.components) |
| 36 | + .finish() |
| 37 | + } |
| 38 | +} |
| 39 | + |
| 40 | +impl<'a> UdfQueryParser<'a> { |
| 41 | + /// Registers the UDF query in DataFusion. |
| 42 | + pub fn new(components: HashMap<String, &'a WasmComponentPrecompiled>) -> Self { |
| 43 | + Self { components } |
| 44 | + } |
| 45 | + |
| 46 | + /// Parses a SQL query that defines & uses UDFs into a [ParsedQuery]. |
| 47 | + pub async fn parse( |
| 48 | + &self, |
| 49 | + udf_query: &str, |
| 50 | + permissions: &WasmPermissions, |
| 51 | + io_rt: Handle, |
| 52 | + task_ctx: &TaskContext, |
| 53 | + ) -> DataFusionResult<ParsedQuery> { |
| 54 | + let (code, sql) = self.parse_inner(udf_query, task_ctx)?; |
| 55 | + |
| 56 | + let mut udfs = vec![]; |
| 57 | + for (lang, blocks) in code { |
| 58 | + let component = self.components.get(&lang).ok_or_else(|| { |
| 59 | + DataFusionError::Plan(format!( |
| 60 | + "no WASM component registered for language: {:?}", |
| 61 | + lang |
| 62 | + )) |
| 63 | + })?; |
| 64 | + |
| 65 | + for code in blocks { |
| 66 | + udfs.extend(WasmScalarUdf::new(component, permissions, io_rt.clone(), code).await?); |
| 67 | + } |
| 68 | + } |
| 69 | + |
| 70 | + Ok(ParsedQuery { udfs, sql }) |
| 71 | + } |
| 72 | + |
| 73 | + /// Parse the combined query to extract the chosen UDF language, UDF |
| 74 | + /// definitions, and SQL statements. |
| 75 | + fn parse_inner( |
| 76 | + &self, |
| 77 | + query: &str, |
| 78 | + task_ctx: &TaskContext, |
| 79 | + ) -> DataFusionResult<(HashMap<String, Vec<String>>, String)> { |
| 80 | + let options = task_ctx.session_config().options(); |
| 81 | + |
| 82 | + let dialect = dialect_from_str(options.sql_parser.dialect.clone()).expect("valid dialect"); |
| 83 | + let recursion_limit = options.sql_parser.recursion_limit; |
| 84 | + |
| 85 | + let statements = DFParserBuilder::new(query) |
| 86 | + .with_dialect(dialect.as_ref()) |
| 87 | + .with_recursion_limit(recursion_limit) |
| 88 | + .build()? |
| 89 | + .parse_statements()?; |
| 90 | + |
| 91 | + let mut sql = String::new(); |
| 92 | + let mut udf_blocks: HashMap<String, Vec<String>> = HashMap::new(); |
| 93 | + for s in statements { |
| 94 | + let Statement::Statement(stmt) = s else { |
| 95 | + continue; |
| 96 | + }; |
| 97 | + |
| 98 | + match parse_udf(*stmt)? { |
| 99 | + Parsed::Udf { code, language } => { |
| 100 | + if let Some(existing) = udf_blocks.get_mut(&language) { |
| 101 | + existing.push(code); |
| 102 | + } else { |
| 103 | + udf_blocks.insert(language.clone(), vec![code]); |
| 104 | + } |
| 105 | + } |
| 106 | + Parsed::Other(statement) => { |
| 107 | + sql.push_str(&statement); |
| 108 | + sql.push_str(";\n"); |
| 109 | + } |
| 110 | + } |
| 111 | + } |
| 112 | + |
| 113 | + if sql.is_empty() { |
| 114 | + return Err(DataFusionError::Plan("no SQL query found".to_string())); |
| 115 | + } |
| 116 | + |
| 117 | + Ok((udf_blocks, sql)) |
| 118 | + } |
| 119 | +} |
| 120 | + |
/// Classification of a single SQL statement, as produced by `parse_udf`.
enum Parsed {
    /// The statement was a `CREATE FUNCTION` carrying a UDF definition.
    Udf {
        /// Source code taken from the function body.
        code: String,
        /// Language named in the function definition.
        language: String,
    },
    /// Any statement that is not a function definition, rendered back to
    /// its textual form.
    Other(String),
}
| 133 | + |
| 134 | +/// Parse a single SQL statement to extract a UDF |
| 135 | +fn parse_udf(stmt: SqlStatement) -> DataFusionResult<Parsed> { |
| 136 | + match stmt { |
| 137 | + SqlStatement::CreateFunction(cf) => { |
| 138 | + let function_body = cf.function_body.as_ref(); |
| 139 | + |
| 140 | + let language = if let Some(lang) = cf.language.as_ref() { |
| 141 | + lang.to_string() |
| 142 | + } else { |
| 143 | + return Err(DataFusionError::Plan( |
| 144 | + "function language is required for UDFs".to_string(), |
| 145 | + )); |
| 146 | + }; |
| 147 | + |
| 148 | + let code = match function_body { |
| 149 | + Some(body) => extract_function_body(body), |
| 150 | + None => Err(DataFusionError::Plan( |
| 151 | + "function body is required for UDFs".to_string(), |
| 152 | + )), |
| 153 | + }?; |
| 154 | + |
| 155 | + Ok(Parsed::Udf { |
| 156 | + code: code.to_string(), |
| 157 | + language, |
| 158 | + }) |
| 159 | + } |
| 160 | + _ => Ok(Parsed::Other(stmt.to_string())), |
| 161 | + } |
| 162 | +} |
| 163 | + |
| 164 | +/// Extracts the code from the function body, adding it to `code`. |
| 165 | +fn extract_function_body(body: &CreateFunctionBody) -> DataFusionResult<&str> { |
| 166 | + match body { |
| 167 | + CreateFunctionBody::AsAfterOptions(e) | CreateFunctionBody::AsBeforeOptions(e) => { |
| 168 | + expression_into_str(e) |
| 169 | + } |
| 170 | + CreateFunctionBody::Return(_) => Err(DataFusionError::Plan( |
| 171 | + "`RETURN` function body not supported for UDFs".to_string(), |
| 172 | + )), |
| 173 | + } |
| 174 | +} |
| 175 | + |
| 176 | +/// Attempt to convert an `Expr` into a `str` |
| 177 | +fn expression_into_str(expr: &Expr) -> DataFusionResult<&str> { |
| 178 | + match expr { |
| 179 | + Expr::Value(v) => match &v.value { |
| 180 | + Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => Ok(s), |
| 181 | + _ => Err(DataFusionError::Plan("expected string value".to_string())), |
| 182 | + }, |
| 183 | + _ => Err(DataFusionError::Plan( |
| 184 | + "expected value expression".to_string(), |
| 185 | + )), |
| 186 | + } |
| 187 | +} |
0 commit comments