24
24
//! Observe the server console output for messages from the interceptor and session provider,
25
25
//! and the client output for the success/failure of each attempt.
26
26
27
- use std:: sync:: Arc ;
28
27
use std:: time:: Duration ;
29
28
29
+ use arrow_flight:: flight_service_server:: FlightServiceServer ;
30
30
use arrow_flight:: sql:: client:: FlightSqlServiceClient ;
31
31
use arrow_flight:: sql:: CommandGetTables ;
32
32
use async_trait:: async_trait;
33
33
use datafusion:: error:: { DataFusionError , Result } ;
34
+ use datafusion:: execution:: context:: SessionState ; // SessionState is not in prelude
34
35
use datafusion:: prelude:: * ; // Covers SessionContext, CsvReadOptions, etc.
35
- use datafusion:: execution:: context:: { SessionConfig , SessionState } ; // SessionState is not in prelude
36
- use datafusion:: execution:: runtime_env:: RuntimeEnv ;
37
36
use datafusion_flight_sql_server:: service:: FlightSqlService ;
38
37
use datafusion_flight_sql_server:: session:: SessionStateProvider ;
39
38
use tokio:: time:: sleep;
40
39
use tonic:: transport:: { Channel , Endpoint , Server } ;
41
- use tonic:: { metadata :: MetadataValue , Request , Status } ;
40
+ use tonic:: { Request , Status } ;
42
41
43
42
// UserData struct remains the same
44
43
#[ derive( Clone , Debug ) ]
@@ -76,24 +75,27 @@ async fn bearer_auth_interceptor(mut req: Request<()>) -> Result<Request<()>, St
76
75
}
77
76
78
77
// Updated MySessionStateProvider
79
- #[ derive( Debug , Clone ) ] // Removed Default
78
+ // #[derive(Debug, Clone)] // Removed Default
80
79
pub struct MySessionStateProvider {
81
80
base_context : SessionContext ,
82
81
}
83
82
84
83
impl MySessionStateProvider {
85
- async fn try_new ( ) -> Result < Self > { // datafusion::error::Result
84
+ async fn try_new ( ) -> Result < Self > {
85
+ // datafusion::error::Result
86
86
let ctx = SessionContext :: new ( ) ;
87
87
// Construct path to test.csv relative to CARGO_MANIFEST_DIR of the datafusion-flight-sql-server crate
88
88
let csv_path = concat ! ( env!( "CARGO_MANIFEST_DIR" ) , "/examples/test.csv" ) ;
89
- ctx. register_csv ( "test" , csv_path, CsvReadOptions :: new ( ) ) . await ?;
89
+ ctx. register_csv ( "test" , csv_path, CsvReadOptions :: new ( ) )
90
+ . await ?;
90
91
Ok ( Self { base_context : ctx } )
91
92
}
92
93
}
93
94
94
95
#[ async_trait]
95
96
impl SessionStateProvider for MySessionStateProvider {
96
- async fn new_context ( & self , request : & Request < ( ) > ) -> Result < SessionState , Status > { // tonic::Result for Status
97
+ async fn new_context ( & self , request : & Request < ( ) > ) -> Result < SessionState , Status > {
98
+ // tonic::Result for Status
97
99
if let Some ( user_data) = request. extensions ( ) . get :: < UserData > ( ) {
98
100
println ! (
99
101
"Session context for user_id: {}. Cloning base context." ,
@@ -115,46 +117,48 @@ impl SessionStateProvider for MySessionStateProvider {
115
117
async fn new_client_with_auth (
116
118
dsn : String ,
117
119
token : Option < String > ,
118
- ) -> Result < FlightSqlServiceClient < Channel > > { // datafusion::error::Result
120
+ ) -> Result < FlightSqlServiceClient < Channel > > {
121
+ // datafusion::error::Result
119
122
let endpoint = Endpoint :: from_shared ( dsn. clone ( ) )
120
123
. map_err ( |e| DataFusionError :: External ( format ! ( "Invalid DSN {}: {}" , dsn, e) . into ( ) ) ) ?
121
124
. connect_timeout ( std:: time:: Duration :: from_secs ( 10 ) ) ;
122
125
123
- let channel = endpoint. connect ( ) . await
124
- . map_err ( |e| DataFusionError :: External ( format ! ( "Failed to connect to {}: {}" , dsn, e) . into ( ) ) ) ?;
125
-
126
- let service_client = FlightSqlServiceClient :: with_interceptor (
127
- channel,
128
- move |mut req : Request < ( ) > | {
129
- if let Some ( token_str) = token. clone ( ) {
130
- let bearer_token = format ! ( "Bearer {}" , token_str) ;
131
- match MetadataValue :: try_from ( & bearer_token) {
132
- Ok ( metadata_val) => req. metadata_mut ( ) . insert ( "authorization" , metadata_val) ,
133
- Err ( _) => return Err ( Status :: invalid_argument ( "Invalid token format for metadata" ) ) ,
134
- } ;
135
- }
136
- Ok ( req)
137
- } ,
138
- ) ;
139
- Ok ( service_client)
140
- }
126
+ let channel = endpoint. connect ( ) . await . map_err ( |e| {
127
+ DataFusionError :: External ( format ! ( "Failed to connect to {}: {}" , dsn, e) . into ( ) )
128
+ } ) ?;
141
129
142
- fn status_to_df_error ( err : tonic:: Status ) -> DataFusionError {
143
- DataFusionError :: External ( format ! ( "Tonic status error: {}" , err) . into ( ) )
130
+ let mut service_client = FlightSqlServiceClient :: new ( channel) ;
131
+ if let Some ( token_str) = token. clone ( ) {
132
+ service_client. set_header ( "authorization" , format ! ( "Bearer {}" , token_str) ) ;
133
+ }
134
+ Ok ( service_client)
144
135
}
145
136
146
137
#[ tokio:: main]
147
- async fn main ( ) -> Result < ( ) > { // datafusion::error::Result<()>
138
+ async fn main ( ) -> Result < ( ) > {
139
+ // datafusion::error::Result<()>
148
140
// Server Setup
149
141
let dsn: String = "0.0.0.0:50051" . to_string ( ) ;
150
- let state_provider = Arc :: new ( MySessionStateProvider :: try_new ( ) . await ?) ;
151
- let base_service = FlightSqlService :: new_with_state_provider ( state_provider) . into_service ( ) ;
152
- let wrapped_service = tonic:: service:: interceptor_fn ( base_service, bearer_auth_interceptor) ;
153
- let addr: std:: net:: SocketAddr = dsn. parse ( ) . map_err ( |e| DataFusionError :: External ( format ! ( "Invalid address format {}: {}" , dsn, e) . into ( ) ) ) ?;
142
+ let state_provider = Box :: new ( MySessionStateProvider :: try_new ( ) . await ?) ;
143
+ let base_service = FlightSqlService :: new_with_provider ( state_provider) ;
144
+ let svc: FlightServiceServer < FlightSqlService > = FlightServiceServer :: new ( base_service) ;
145
+ let addr: std:: net:: SocketAddr = dsn. parse ( ) . map_err ( |e| {
146
+ DataFusionError :: External ( format ! ( "Invalid address format {}: {}" , dsn, e) . into ( ) )
147
+ } ) ?;
154
148
155
149
tokio:: spawn ( async move {
156
- println ! ( "Bearer Authentication Flight SQL server listening on {}" , addr) ;
157
- if let Err ( e) = Server :: builder ( ) . add_service ( wrapped_service) . serve ( addr) . await {
150
+ println ! (
151
+ "Bearer Authentication Flight SQL server listening on {}" ,
152
+ addr
153
+ ) ;
154
+ if let Err ( e) = Server :: builder ( )
155
+ . layer ( tonic_async_interceptor:: async_interceptor (
156
+ bearer_auth_interceptor,
157
+ ) )
158
+ . add_service ( svc)
159
+ . serve ( addr)
160
+ . await
161
+ {
158
162
eprintln ! ( "Server error: {}" , e) ;
159
163
}
160
164
} ) ;
@@ -169,10 +173,18 @@ async fn main() -> Result<()> { // datafusion::error::Result<()>
169
173
println ! ( "\n Attempting GetTables with valid token (token1)..." ) ;
170
174
match new_client_with_auth ( client_dsn. clone ( ) , Some ( "token1" . to_string ( ) ) ) . await {
171
175
Ok ( mut client) => {
172
- let request = CommandGetTables { catalog : None , db_schema_filter_pattern : None , table_name_filter_pattern : None , table_types : vec ! [ ] , include_schema : false } ;
176
+ let request = CommandGetTables {
177
+ catalog : None ,
178
+ db_schema_filter_pattern : None ,
179
+ table_name_filter_pattern : None ,
180
+ table_types : vec ! [ ] ,
181
+ include_schema : false ,
182
+ } ;
173
183
match client. get_tables ( request) . await {
174
- Ok ( response) => println ! ( "GetTables with token1 SUCCEEDED. Response: {:?}" , response. into_inner( ) ) ,
175
- Err ( e) => eprintln ! ( "GetTables with token1 FAILED: {}" , status_to_df_error( e) ) ,
184
+ Ok ( response) => {
185
+ println ! ( "GetTables with token1 SUCCEEDED. Response: {:?}" , response)
186
+ }
187
+ Err ( e) => eprintln ! ( "GetTables with token1 FAILED: {}" , e) ,
176
188
}
177
189
}
178
190
Err ( e) => eprintln ! ( "Failed to create client with token1: {}" , e) ,
@@ -182,10 +194,19 @@ async fn main() -> Result<()> { // datafusion::error::Result<()>
182
194
println ! ( "\n Attempting GetTables with invalid token (invalidtoken)..." ) ;
183
195
match new_client_with_auth ( client_dsn. clone ( ) , Some ( "invalidtoken" . to_string ( ) ) ) . await {
184
196
Ok ( mut client) => {
185
- let request = CommandGetTables { catalog : None , db_schema_filter_pattern : None , table_name_filter_pattern : None , table_types : vec ! [ ] , include_schema : false } ;
197
+ let request = CommandGetTables {
198
+ catalog : None ,
199
+ db_schema_filter_pattern : None ,
200
+ table_name_filter_pattern : None ,
201
+ table_types : vec ! [ ] ,
202
+ include_schema : false ,
203
+ } ;
186
204
match client. get_tables ( request) . await {
187
- Ok ( response) => println ! ( "GetTables with invalidtoken SUCCEEDED (unexpected). Response: {:?}" , response. into_inner( ) ) ,
188
- Err ( e) => eprintln ! ( "GetTables with invalidtoken FAILED (as expected): {}" , status_to_df_error( e) ) ,
205
+ Ok ( response) => println ! (
206
+ "GetTables with invalidtoken SUCCEEDED (unexpected). Response: {:?}" ,
207
+ response
208
+ ) ,
209
+ Err ( e) => eprintln ! ( "GetTables with invalidtoken FAILED (as expected): {:?}" , e) ,
189
210
}
190
211
}
191
212
Err ( e) => eprintln ! ( "Failed to create client with invalidtoken: {}" , e) ,
@@ -195,10 +216,19 @@ async fn main() -> Result<()> { // datafusion::error::Result<()>
195
216
println ! ( "\n Attempting GetTables with no token..." ) ;
196
217
match new_client_with_auth ( client_dsn. clone ( ) , None ) . await {
197
218
Ok ( mut client) => {
198
- let request = CommandGetTables { catalog : None , db_schema_filter_pattern : None , table_name_filter_pattern : None , table_types : vec ! [ ] , include_schema : false } ;
219
+ let request = CommandGetTables {
220
+ catalog : None ,
221
+ db_schema_filter_pattern : None ,
222
+ table_name_filter_pattern : None ,
223
+ table_types : vec ! [ ] ,
224
+ include_schema : false ,
225
+ } ;
199
226
match client. get_tables ( request) . await {
200
- Ok ( response) => println ! ( "GetTables with no token SUCCEEDED (unexpected). Response: {:?}" , response. into_inner( ) ) ,
201
- Err ( e) => eprintln ! ( "GetTables with no token FAILED (as expected): {}" , status_to_df_error( e) ) ,
227
+ Ok ( response) => println ! (
228
+ "GetTables with no token SUCCEEDED (unexpected). Response: {:?}" ,
229
+ response
230
+ ) ,
231
+ Err ( e) => eprintln ! ( "GetTables with no token FAILED (as expected): {:?}" , e) ,
202
232
}
203
233
}
204
234
Err ( e) => eprintln ! ( "Failed to create client with no token: {}" , e) ,
0 commit comments