Skip to content

Commit 3b94956

Browse files
committed
fix: run HTTP in I/O runtime
Closes #140.
1 parent 22f8766 commit 3b94956

File tree

9 files changed

+528
-13
lines changed

9 files changed

+528
-13
lines changed

host/src/lib.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use ::http::HeaderName;
88
use arrow::datatypes::DataType;
99
use datafusion_common::{DataFusionError, Result as DataFusionResult};
1010
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature};
11-
use tokio::sync::Mutex;
11+
use tokio::{runtime::Handle, sync::Mutex};
1212
use wasmtime::{
1313
Engine, Store,
1414
component::{Component, ResourceAny},
@@ -72,6 +72,9 @@ struct WasmStateImpl {
7272

7373
/// HTTP request validator.
7474
http_validator: Arc<dyn HttpRequestValidator>,
75+
76+
/// Handle to tokio I/O runtime.
77+
io_rt: Handle,
7578
}
7679

7780
impl std::fmt::Debug for WasmStateImpl {
@@ -83,13 +86,15 @@ impl std::fmt::Debug for WasmStateImpl {
8386
wasi_http_ctx: _,
8487
resource_table,
8588
http_validator,
89+
io_rt,
8690
} = self;
8791
f.debug_struct("WasmStateImpl")
8892
.field("vfs_state", vfs_state)
8993
.field("stderr", stderr)
9094
.field("wasi_ctx", &"<WASI_CTX>")
9195
.field("resource_table", resource_table)
9296
.field("http_validator", http_validator)
97+
.field("io_rt", io_rt)
9398
.finish()
9499
}
95100
}
@@ -117,6 +122,8 @@ impl WasiHttpView for WasmStateImpl {
117122
mut request: hyper::Request<HyperOutgoingBody>,
118123
config: OutgoingRequestConfig,
119124
) -> HttpResult<HostFutureIncomingResponse> {
125+
let _guard = self.io_rt.enter();
126+
120127
// Python `requests` sends this so we allow it but later drop it from the actual request.
121128
request.headers_mut().remove(hyper::header::CONNECTION);
122129

@@ -298,6 +305,7 @@ impl WasmScalarUdf {
298305
pub async fn new(
299306
component: &WasmComponentPrecompiled,
300307
permissions: &WasmPermissions,
308+
io_rt: Handle,
301309
source: String,
302310
) -> DataFusionResult<Vec<Self>> {
303311
let WasmComponentPrecompiled { engine, component } = component;
@@ -314,6 +322,7 @@ impl WasmScalarUdf {
314322
wasi_http_ctx: WasiHttpCtx::new(),
315323
resource_table: ResourceTable::new(),
316324
http_validator: Arc::clone(&permissions.http),
325+
io_rt,
317326
};
318327
let (bindings, mut store) = link(engine, component, state)
319328
.await

host/src/udf_query.rs

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
//! Embedded SQL approach for executing UDFs within SQL queries.
2+
3+
use std::collections::HashMap;
4+
5+
use datafusion::execution::TaskContext;
6+
use datafusion_common::{DataFusionError, Result as DataFusionResult};
7+
use datafusion_sql::parser::{DFParserBuilder, Statement};
8+
use sqlparser::ast::{CreateFunctionBody, Expr, Statement as SqlStatement, Value};
9+
use sqlparser::dialect::dialect_from_str;
10+
use tokio::runtime::Handle;
11+
12+
use crate::{WasmComponentPrecompiled, WasmPermissions, WasmScalarUdf};
13+
14+
/// A [ParsedQuery] contains the extracted UDFs and SQL query string
15+
#[derive(Debug)]
16+
pub struct ParsedQuery {
17+
/// Extracted UDFs from the query
18+
pub udfs: Vec<WasmScalarUdf>,
19+
/// SQL query string with UDF definitions removed
20+
pub sql: String,
21+
}
22+
23+
/// Parses SQL text that embeds UDF definitions and instantiates each UDF from the
24+
/// pre-compiled WASM component registered for its language.
25+
pub struct UdfQueryParser<'a> {
26+
/// Pre-compiled WASM components, keyed by UDF language.
27+
/// Necessary to create UDFs.
28+
components: HashMap<String, &'a WasmComponentPrecompiled>,
29+
}
30+
31+
impl std::fmt::Debug for UdfQueryParser<'_> {
32+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33+
f.debug_struct("UdfQueryParser")
34+
.field("session_ctx", &"SessionContext { ... }")
35+
.field("components", &self.components)
36+
.finish()
37+
}
38+
}
39+
40+
impl<'a> UdfQueryParser<'a> {
41+
/// Registers the UDF query in DataFusion.
42+
pub fn new(components: HashMap<String, &'a WasmComponentPrecompiled>) -> Self {
43+
Self { components }
44+
}
45+
46+
/// Parses a SQL query that defines & uses UDFs into a [ParsedQuery].
47+
pub async fn parse(
48+
&self,
49+
udf_query: &str,
50+
permissions: &WasmPermissions,
51+
io_rt: Handle,
52+
task_ctx: &TaskContext,
53+
) -> DataFusionResult<ParsedQuery> {
54+
let (code, sql) = self.parse_inner(udf_query, task_ctx)?;
55+
56+
let mut udfs = vec![];
57+
for (lang, blocks) in code {
58+
let component = self.components.get(&lang).ok_or_else(|| {
59+
DataFusionError::Plan(format!(
60+
"no WASM component registered for language: {:?}",
61+
lang
62+
))
63+
})?;
64+
65+
for code in blocks {
66+
udfs.extend(WasmScalarUdf::new(component, permissions, io_rt.clone(), code).await?);
67+
}
68+
}
69+
70+
Ok(ParsedQuery { udfs, sql })
71+
}
72+
73+
/// Parse the combined query to extract the chosen UDF language, UDF
74+
/// definitions, and SQL statements.
75+
fn parse_inner(
76+
&self,
77+
query: &str,
78+
task_ctx: &TaskContext,
79+
) -> DataFusionResult<(HashMap<String, Vec<String>>, String)> {
80+
let options = task_ctx.session_config().options();
81+
82+
let dialect = dialect_from_str(options.sql_parser.dialect.clone()).expect("valid dialect");
83+
let recursion_limit = options.sql_parser.recursion_limit;
84+
85+
let statements = DFParserBuilder::new(query)
86+
.with_dialect(dialect.as_ref())
87+
.with_recursion_limit(recursion_limit)
88+
.build()?
89+
.parse_statements()?;
90+
91+
let mut sql = String::new();
92+
let mut udf_blocks: HashMap<String, Vec<String>> = HashMap::new();
93+
for s in statements {
94+
let Statement::Statement(stmt) = s else {
95+
continue;
96+
};
97+
98+
match parse_udf(*stmt)? {
99+
Parsed::Udf { code, language } => {
100+
if let Some(existing) = udf_blocks.get_mut(&language) {
101+
existing.push(code);
102+
} else {
103+
udf_blocks.insert(language.clone(), vec![code]);
104+
}
105+
}
106+
Parsed::Other(statement) => {
107+
sql.push_str(&statement);
108+
sql.push_str(";\n");
109+
}
110+
}
111+
}
112+
113+
if sql.is_empty() {
114+
return Err(DataFusionError::Plan("no SQL query found".to_string()));
115+
}
116+
117+
Ok((udf_blocks, sql))
118+
}
119+
}
120+
121+
/// Represents a parsed SQL statement
122+
enum Parsed {
123+
/// A UDF definition
124+
Udf {
125+
/// UDF code
126+
code: String,
127+
/// UDF language
128+
language: String,
129+
},
130+
/// Any other SQL statement
131+
Other(String),
132+
}
133+
134+
/// Parse a single SQL statement to extract a UDF
135+
fn parse_udf(stmt: SqlStatement) -> DataFusionResult<Parsed> {
136+
match stmt {
137+
SqlStatement::CreateFunction(cf) => {
138+
let function_body = cf.function_body.as_ref();
139+
140+
let language = if let Some(lang) = cf.language.as_ref() {
141+
lang.to_string()
142+
} else {
143+
return Err(DataFusionError::Plan(
144+
"function language is required for UDFs".to_string(),
145+
));
146+
};
147+
148+
let code = match function_body {
149+
Some(body) => extract_function_body(body),
150+
None => Err(DataFusionError::Plan(
151+
"function body is required for UDFs".to_string(),
152+
)),
153+
}?;
154+
155+
Ok(Parsed::Udf {
156+
code: code.to_string(),
157+
language,
158+
})
159+
}
160+
_ => Ok(Parsed::Other(stmt.to_string())),
161+
}
162+
}
163+
164+
/// Extracts the code from the function body, adding it to `code`.
165+
fn extract_function_body(body: &CreateFunctionBody) -> DataFusionResult<&str> {
166+
match body {
167+
CreateFunctionBody::AsAfterOptions(e) | CreateFunctionBody::AsBeforeOptions(e) => {
168+
expression_into_str(e)
169+
}
170+
CreateFunctionBody::Return(_) => Err(DataFusionError::Plan(
171+
"`RETURN` function body not supported for UDFs".to_string(),
172+
)),
173+
}
174+
}
175+
176+
/// Attempt to convert an `Expr` into a `str`
177+
fn expression_into_str(expr: &Expr) -> DataFusionResult<&str> {
178+
match expr {
179+
Expr::Value(v) => match &v.value {
180+
Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => Ok(s),
181+
_ => Err(DataFusionError::Plan("expected string value".to_string())),
182+
},
183+
_ => Err(DataFusionError::Plan(
184+
"expected value expression".to_string(),
185+
)),
186+
}
187+
}

host/tests/integration_tests/python/runtime/fs.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use arrow::{
66
};
77
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
88
use datafusion_udf_wasm_host::{WasmPermissions, WasmScalarUdf, vfs::VfsLimits};
9+
use tokio::runtime::Handle;
910

1011
use crate::integration_tests::{
1112
python::test_utils::{python_component, python_scalar_udf},
@@ -239,6 +240,7 @@ async fn test_limit_inodes() {
239240
inodes: 42,
240241
..Default::default()
241242
}),
243+
Handle::current(),
242244
"".to_owned(),
243245
)
244246
.await
@@ -260,6 +262,7 @@ async fn test_limit_bytes() {
260262
bytes: 1337,
261263
..Default::default()
262264
}),
265+
Handle::current(),
263266
"".to_owned(),
264267
)
265268
.await

host/tests/integration_tests/python/runtime/http.rs

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::sync::Arc;
1+
use std::{sync::Arc, time::Duration};
22

33
use arrow::{
44
array::{Array, StringArray, StringBuilder},
@@ -10,6 +10,7 @@ use datafusion_udf_wasm_host::{
1010
WasmPermissions, WasmScalarUdf,
1111
http::{AllowCertainHttpRequests, HttpRequestValidator, Matcher},
1212
};
13+
use tokio::runtime::Handle;
1314
use wasmtime_wasi_http::types::DEFAULT_FORBIDDEN_HEADERS;
1415
use wiremock::{Mock, MockServer, ResponseTemplate, matchers};
1516

@@ -568,6 +569,7 @@ where
568569
WasmScalarUdf::new(
569570
python_component().await,
570571
&WasmPermissions::new().with_http(permissions),
572+
Handle::current(),
571573
code.to_owned(),
572574
)
573575
.await
@@ -582,3 +584,79 @@ where
582584
assert_eq!(udfs.len(), 1);
583585
udfs.into_iter().next().unwrap()
584586
}
587+
588+
#[test]
589+
fn test_io_runtime() {
590+
const CODE: &str = r#"
591+
import urllib3
592+
593+
def perform_request(url: str) -> str:
594+
resp = urllib3.request("GET", url)
595+
return resp.data.decode("utf-8")
596+
"#;
597+
598+
let rt_tmp = tokio::runtime::Builder::new_current_thread()
599+
.build()
600+
.unwrap();
601+
let rt_cpu = tokio::runtime::Builder::new_multi_thread()
602+
.worker_threads(1)
603+
// It would be nice if all the timeout-related timers would also run within the I/O runtime, but
604+
// that requires some larger intervention (either upstream or with a custom WASI HTTP implementation).
605+
// Hence, we don't do that yet.
606+
.enable_time()
607+
.build()
608+
.unwrap();
609+
let rt_io = tokio::runtime::Builder::new_multi_thread()
610+
.worker_threads(1)
611+
.enable_all()
612+
.build()
613+
.unwrap();
614+
615+
let server = rt_io.block_on(async {
616+
let server = MockServer::start().await;
617+
Mock::given(matchers::any())
618+
.respond_with(ResponseTemplate::new(200).set_body_string("hello world!"))
619+
.expect(1)
620+
.mount(&server)
621+
.await;
622+
server
623+
});
624+
625+
// deliberately use a runtime that we are going to throw away later to prevent tricks like `Handle::current`
626+
let udf = rt_tmp.block_on(async {
627+
let mut permissions = AllowCertainHttpRequests::new();
628+
permissions.allow(Matcher {
629+
method: http::Method::GET,
630+
host: server.address().ip().to_string().into(),
631+
port: server.address().port(),
632+
});
633+
634+
let udfs = WasmScalarUdf::new(
635+
python_component().await,
636+
&WasmPermissions::new().with_http(permissions),
637+
rt_io.handle().clone(),
638+
CODE.to_owned(),
639+
)
640+
.await
641+
.unwrap();
642+
assert_eq!(udfs.len(), 1);
643+
udfs.into_iter().next().unwrap()
644+
});
645+
rt_tmp.shutdown_timeout(Duration::from_secs(1));
646+
647+
let array = rt_cpu.block_on(async {
648+
udf.invoke_with_args(ScalarFunctionArgs {
649+
args: vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(server.uri())))],
650+
arg_fields: vec![Arc::new(Field::new("uri", DataType::Utf8, true))],
651+
number_rows: 1,
652+
return_field: Arc::new(Field::new("r", DataType::Utf8, true)),
653+
})
654+
.unwrap()
655+
.unwrap_array()
656+
});
657+
658+
assert_eq!(
659+
array.as_ref(),
660+
&StringArray::from_iter([Some("hello world!".to_owned()),]) as &dyn Array,
661+
);
662+
}

host/tests/integration_tests/python/test_utils.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use datafusion_common::DataFusionError;
22
use datafusion_udf_wasm_host::{WasmComponentPrecompiled, WasmScalarUdf};
3-
use tokio::sync::OnceCell;
3+
use tokio::{runtime::Handle, sync::OnceCell};
44

55
/// Static precompiled Python WASM component for tests
66
static COMPONENT: OnceCell<WasmComponentPrecompiled> = OnceCell::const_new();
@@ -20,7 +20,13 @@ pub(crate) async fn python_component() -> &'static WasmComponentPrecompiled {
2020
pub(crate) async fn python_scalar_udfs(code: &str) -> Result<Vec<WasmScalarUdf>, DataFusionError> {
2121
let component = python_component().await;
2222

23-
WasmScalarUdf::new(component, &Default::default(), code.to_owned()).await
23+
WasmScalarUdf::new(
24+
component,
25+
&Default::default(),
26+
Handle::current(),
27+
code.to_owned(),
28+
)
29+
.await
2430
}
2531

2632
/// Compiles the provided Python UDF code into a single WasmScalarUdf instance.

0 commit comments

Comments
 (0)