Skip to content

Commit 645e9d7

Browse files
committed
Quick cleanup
1 parent c66a5b5 commit 645e9d7

File tree

2 files changed

+76
-45
lines changed

2 files changed

+76
-45
lines changed

datafusion-flight-sql-server/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ once_cell = "1.21"
2727
prost = "0.13"
2828
tonic.workspace = true
2929
async-trait.workspace = true
30+
tonic-async-interceptor = "0.12.0"
3031

3132
[dev-dependencies]
3233
tokio.workspace = true

datafusion-flight-sql-server/examples/bearer_auth_flight_sql.rs

Lines changed: 75 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,20 @@
2424
//! Observe the server console output for messages from the interceptor and session provider,
2525
//! and the client output for the success/failure of each attempt.
2626
27-
use std::sync::Arc;
2827
use std::time::Duration;
2928

29+
use arrow_flight::flight_service_server::FlightServiceServer;
3030
use arrow_flight::sql::client::FlightSqlServiceClient;
3131
use arrow_flight::sql::CommandGetTables;
3232
use async_trait::async_trait;
3333
use datafusion::error::{DataFusionError, Result};
34+
use datafusion::execution::context::SessionState; // SessionState is not in prelude
3435
use datafusion::prelude::*; // Covers SessionContext, CsvReadOptions, etc.
35-
use datafusion::execution::context::{SessionConfig, SessionState}; // SessionState is not in prelude
36-
use datafusion::execution::runtime_env::RuntimeEnv;
3736
use datafusion_flight_sql_server::service::FlightSqlService;
3837
use datafusion_flight_sql_server::session::SessionStateProvider;
3938
use tokio::time::sleep;
4039
use tonic::transport::{Channel, Endpoint, Server};
41-
use tonic::{metadata::MetadataValue, Request, Status};
40+
use tonic::{Request, Status};
4241

4342
// UserData struct remains the same
4443
#[derive(Clone, Debug)]
@@ -76,24 +75,27 @@ async fn bearer_auth_interceptor(mut req: Request<()>) -> Result<Request<()>, St
7675
}
7776

7877
// Updated MySessionStateProvider
79-
#[derive(Debug, Clone)] // Removed Default
78+
// #[derive(Debug, Clone)] // Removed Default
8079
pub struct MySessionStateProvider {
8180
base_context: SessionContext,
8281
}
8382

8483
impl MySessionStateProvider {
85-
async fn try_new() -> Result<Self> { // datafusion::error::Result
84+
async fn try_new() -> Result<Self> {
85+
// datafusion::error::Result
8686
let ctx = SessionContext::new();
8787
// Construct path to test.csv relative to CARGO_MANIFEST_DIR of the datafusion-flight-sql-server crate
8888
let csv_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/test.csv");
89-
ctx.register_csv("test", csv_path, CsvReadOptions::new()).await?;
89+
ctx.register_csv("test", csv_path, CsvReadOptions::new())
90+
.await?;
9091
Ok(Self { base_context: ctx })
9192
}
9293
}
9394

9495
#[async_trait]
9596
impl SessionStateProvider for MySessionStateProvider {
96-
async fn new_context(&self, request: &Request<()>) -> Result<SessionState, Status> { // tonic::Result for Status
97+
async fn new_context(&self, request: &Request<()>) -> Result<SessionState, Status> {
98+
// tonic::Result for Status
9799
if let Some(user_data) = request.extensions().get::<UserData>() {
98100
println!(
99101
"Session context for user_id: {}. Cloning base context.",
@@ -115,46 +117,48 @@ impl SessionStateProvider for MySessionStateProvider {
115117
async fn new_client_with_auth(
116118
dsn: String,
117119
token: Option<String>,
118-
) -> Result<FlightSqlServiceClient<Channel>> { // datafusion::error::Result
120+
) -> Result<FlightSqlServiceClient<Channel>> {
121+
// datafusion::error::Result
119122
let endpoint = Endpoint::from_shared(dsn.clone())
120123
.map_err(|e| DataFusionError::External(format!("Invalid DSN {}: {}", dsn, e).into()))?
121124
.connect_timeout(std::time::Duration::from_secs(10));
122125

123-
let channel = endpoint.connect().await
124-
.map_err(|e| DataFusionError::External(format!("Failed to connect to {}: {}", dsn, e).into()))?;
125-
126-
let service_client = FlightSqlServiceClient::with_interceptor(
127-
channel,
128-
move |mut req: Request<()>| {
129-
if let Some(token_str) = token.clone() {
130-
let bearer_token = format!("Bearer {}", token_str);
131-
match MetadataValue::try_from(&bearer_token) {
132-
Ok(metadata_val) => req.metadata_mut().insert("authorization", metadata_val),
133-
Err(_) => return Err(Status::invalid_argument("Invalid token format for metadata")),
134-
};
135-
}
136-
Ok(req)
137-
},
138-
);
139-
Ok(service_client)
140-
}
126+
let channel = endpoint.connect().await.map_err(|e| {
127+
DataFusionError::External(format!("Failed to connect to {}: {}", dsn, e).into())
128+
})?;
141129

142-
fn status_to_df_error(err: tonic::Status) -> DataFusionError {
143-
DataFusionError::External(format!("Tonic status error: {}", err).into())
130+
let mut service_client = FlightSqlServiceClient::new(channel);
131+
if let Some(token_str) = token.clone() {
132+
service_client.set_header("authorization", format!("Bearer {}", token_str));
133+
}
134+
Ok(service_client)
144135
}
145136

146137
#[tokio::main]
147-
async fn main() -> Result<()> { // datafusion::error::Result<()>
138+
async fn main() -> Result<()> {
139+
// datafusion::error::Result<()>
148140
// Server Setup
149141
let dsn: String = "0.0.0.0:50051".to_string();
150-
let state_provider = Arc::new(MySessionStateProvider::try_new().await?);
151-
let base_service = FlightSqlService::new_with_state_provider(state_provider).into_service();
152-
let wrapped_service = tonic::service::interceptor_fn(base_service, bearer_auth_interceptor);
153-
let addr: std::net::SocketAddr = dsn.parse().map_err(|e| DataFusionError::External(format!("Invalid address format {}: {}", dsn, e).into()))?;
142+
let state_provider = Box::new(MySessionStateProvider::try_new().await?);
143+
let base_service = FlightSqlService::new_with_provider(state_provider);
144+
let svc: FlightServiceServer<FlightSqlService> = FlightServiceServer::new(base_service);
145+
let addr: std::net::SocketAddr = dsn.parse().map_err(|e| {
146+
DataFusionError::External(format!("Invalid address format {}: {}", dsn, e).into())
147+
})?;
154148

155149
tokio::spawn(async move {
156-
println!("Bearer Authentication Flight SQL server listening on {}", addr);
157-
if let Err(e) = Server::builder().add_service(wrapped_service).serve(addr).await {
150+
println!(
151+
"Bearer Authentication Flight SQL server listening on {}",
152+
addr
153+
);
154+
if let Err(e) = Server::builder()
155+
.layer(tonic_async_interceptor::async_interceptor(
156+
bearer_auth_interceptor,
157+
))
158+
.add_service(svc)
159+
.serve(addr)
160+
.await
161+
{
158162
eprintln!("Server error: {}", e);
159163
}
160164
});
@@ -169,10 +173,18 @@ async fn main() -> Result<()> { // datafusion::error::Result<()>
169173
println!("\nAttempting GetTables with valid token (token1)...");
170174
match new_client_with_auth(client_dsn.clone(), Some("token1".to_string())).await {
171175
Ok(mut client) => {
172-
let request = CommandGetTables { catalog: None, db_schema_filter_pattern: None, table_name_filter_pattern: None, table_types: vec![], include_schema: false };
176+
let request = CommandGetTables {
177+
catalog: None,
178+
db_schema_filter_pattern: None,
179+
table_name_filter_pattern: None,
180+
table_types: vec![],
181+
include_schema: false,
182+
};
173183
match client.get_tables(request).await {
174-
Ok(response) => println!("GetTables with token1 SUCCEEDED. Response: {:?}", response.into_inner()),
175-
Err(e) => eprintln!("GetTables with token1 FAILED: {}", status_to_df_error(e)),
184+
Ok(response) => {
185+
println!("GetTables with token1 SUCCEEDED. Response: {:?}", response)
186+
}
187+
Err(e) => eprintln!("GetTables with token1 FAILED: {}", e),
176188
}
177189
}
178190
Err(e) => eprintln!("Failed to create client with token1: {}", e),
@@ -182,10 +194,19 @@ async fn main() -> Result<()> { // datafusion::error::Result<()>
182194
println!("\nAttempting GetTables with invalid token (invalidtoken)...");
183195
match new_client_with_auth(client_dsn.clone(), Some("invalidtoken".to_string())).await {
184196
Ok(mut client) => {
185-
let request = CommandGetTables { catalog: None, db_schema_filter_pattern: None, table_name_filter_pattern: None, table_types: vec![], include_schema: false };
197+
let request = CommandGetTables {
198+
catalog: None,
199+
db_schema_filter_pattern: None,
200+
table_name_filter_pattern: None,
201+
table_types: vec![],
202+
include_schema: false,
203+
};
186204
match client.get_tables(request).await {
187-
Ok(response) => println!("GetTables with invalidtoken SUCCEEDED (unexpected). Response: {:?}", response.into_inner()),
188-
Err(e) => eprintln!("GetTables with invalidtoken FAILED (as expected): {}", status_to_df_error(e)),
205+
Ok(response) => println!(
206+
"GetTables with invalidtoken SUCCEEDED (unexpected). Response: {:?}",
207+
response
208+
),
209+
Err(e) => eprintln!("GetTables with invalidtoken FAILED (as expected): {:?}", e),
189210
}
190211
}
191212
Err(e) => eprintln!("Failed to create client with invalidtoken: {}", e),
@@ -195,10 +216,19 @@ async fn main() -> Result<()> { // datafusion::error::Result<()>
195216
println!("\nAttempting GetTables with no token...");
196217
match new_client_with_auth(client_dsn.clone(), None).await {
197218
Ok(mut client) => {
198-
let request = CommandGetTables { catalog: None, db_schema_filter_pattern: None, table_name_filter_pattern: None, table_types: vec![], include_schema: false };
219+
let request = CommandGetTables {
220+
catalog: None,
221+
db_schema_filter_pattern: None,
222+
table_name_filter_pattern: None,
223+
table_types: vec![],
224+
include_schema: false,
225+
};
199226
match client.get_tables(request).await {
200-
Ok(response) => println!("GetTables with no token SUCCEEDED (unexpected). Response: {:?}", response.into_inner()),
201-
Err(e) => eprintln!("GetTables with no token FAILED (as expected): {}", status_to_df_error(e)),
227+
Ok(response) => println!(
228+
"GetTables with no token SUCCEEDED (unexpected). Response: {:?}",
229+
response
230+
),
231+
Err(e) => eprintln!("GetTables with no token FAILED (as expected): {:?}", e),
202232
}
203233
}
204234
Err(e) => eprintln!("Failed to create client with no token: {}", e),

0 commit comments

Comments
 (0)