@@ -6,7 +6,8 @@ use std::io::{Read, Write};
6
6
use std:: ops:: { Deref , DerefMut } ;
7
7
use std:: str;
8
8
use std:: sync:: mpsc;
9
- use std:: sync:: Arc ;
9
+
10
+ #[ cfg( feature = "rsasl" ) ]
10
11
use rsasl:: prelude:: { Mechname , SASLClient , SASLConfig , Session as SASLSession , State as SASLState } ;
11
12
12
13
use super :: authenticator:: Authenticator ;
@@ -367,7 +368,7 @@ impl<T: Read + Write> Client<T> {
367
368
/// match client.login("user", "pass") {
368
369
/// Ok(s) => {
369
370
/// // you are successfully authenticated!
370
- /// },
371
+ /// }
371
372
/// Err((e, orig_client)) => {
372
373
/// eprintln!("error logging in: {}", e);
373
374
/// // prompt user and try again with orig_client here
@@ -425,7 +426,7 @@ impl<T: Read + Write> Client<T> {
425
426
/// match client.authenticate("XOAUTH2", &auth) {
426
427
/// Ok(session) => {
427
428
/// // you are successfully authenticated!
428
- /// },
429
+ /// }
429
430
/// Err((e, orig_client)) => {
430
431
/// eprintln!("error authenticating: {}", e);
431
432
/// // prompt user and try again with orig_client here
@@ -434,9 +435,82 @@ impl<T: Read + Write> Client<T> {
434
435
/// };
435
436
/// }
436
437
/// ```
437
- pub fn authenticate (
438
+ pub fn authenticate < A : Authenticator > (
438
439
mut self ,
439
- config : Arc < SASLConfig > ,
440
+ auth_type : impl AsRef < str > ,
441
+ authenticator : & A ,
442
+ ) -> :: std:: result:: Result < Session < T > , ( Error , Client < T > ) > {
443
+ ok_or_unauth_client_err ! (
444
+ self . run_command( & format!( "AUTHENTICATE {}" , auth_type. as_ref( ) ) ) ,
445
+ self
446
+ ) ;
447
+ self . do_auth_handshake ( authenticator)
448
+ }
449
+
450
+ /// This func does the handshake process once the authenticate command is made.
451
+ fn do_auth_handshake < A : Authenticator > (
452
+ mut self ,
453
+ authenticator : & A ,
454
+ ) -> :: std:: result:: Result < Session < T > , ( Error , Client < T > ) > {
455
+ // TODO Clean up this code
456
+ loop {
457
+ let mut line = Vec :: new ( ) ;
458
+
459
+ // explicit match blocks neccessary to convert error to tuple and not bind self too
460
+ // early (see also comment on `login`)
461
+ ok_or_unauth_client_err ! ( self . readline( & mut line) , self ) ;
462
+
463
+ // ignore server comments
464
+ if line. starts_with ( b"* " ) {
465
+ continue ;
466
+ }
467
+
468
+ // Some servers will only send `+\r\n`.
469
+ if line. starts_with ( b"+ " ) || & line == b"+\r \n " {
470
+ let challenge = if & line == b"+\r \n " {
471
+ Vec :: new ( )
472
+ } else {
473
+ let line_str = ok_or_unauth_client_err ! (
474
+ match str :: from_utf8( line. as_slice( ) ) {
475
+ Ok ( line_str) => Ok ( line_str) ,
476
+ Err ( e) => Err ( Error :: Parse ( ParseError :: DataNotUtf8 ( line, e) ) ) ,
477
+ } ,
478
+ self
479
+ ) ;
480
+ let data =
481
+ ok_or_unauth_client_err ! ( parse_authenticate_response( line_str) , self ) ;
482
+ ok_or_unauth_client_err ! (
483
+ base64:: decode( data) . map_err( |e| Error :: Parse ( ParseError :: Authentication (
484
+ data. to_string( ) ,
485
+ Some ( e)
486
+ ) ) ) ,
487
+ self
488
+ )
489
+ } ;
490
+
491
+ let raw_response = & authenticator. process ( & challenge) ;
492
+ let auth_response = base64:: encode ( raw_response) ;
493
+ ok_or_unauth_client_err ! (
494
+ self . write_line( auth_response. into_bytes( ) . as_slice( ) ) ,
495
+ self
496
+ ) ;
497
+ } else {
498
+ ok_or_unauth_client_err ! ( self . read_response_onto( & mut line) , self ) ;
499
+ return Ok ( Session :: new ( self . conn ) ) ;
500
+ }
501
+ }
502
+ }
503
+ }
504
+
505
+ #[ cfg( feature = "rsasl" ) ]
506
+ impl < T : Read + Write > Client < T > {
507
+
508
+ /// Authenticate with the server using the given custom SASLConfig to handle the server's
509
+ /// challenge.
510
+ ///
511
+ pub fn sasl_auth (
512
+ mut self ,
513
+ config : :: std:: sync:: Arc < SASLConfig > ,
440
514
mechanism : & Mechname ,
441
515
) -> :: std:: result:: Result < Session < T > , ( Error , Client < T > ) > {
442
516
let client = SASLClient :: new ( config) ;
@@ -450,11 +524,11 @@ impl<T: Read + Write> Client<T> {
450
524
self . run_command( & format!( "AUTHENTICATE {}" , mechanism. as_str( ) ) ) ,
451
525
self
452
526
) ;
453
- self . do_auth_handshake ( session)
527
+ self . do_sasl_handshake ( session)
454
528
}
455
529
456
- /// This func does the handshake process once the authenticate command is made.
457
- fn do_auth_handshake (
530
+ /// This func does the SASL handshake process once the authenticate command is made.
531
+ fn do_sasl_handshake (
458
532
mut self ,
459
533
mut authenticator : SASLSession ,
460
534
) -> :: std:: result:: Result < Session < T > , ( Error , Client < T > ) > {
0 commit comments