@@ -357,12 +357,25 @@ impl WebRtcDialerState {
357357 & mut self ,
358358 payload : Vec < u8 > ,
359359 ) -> Result < HandshakeResult , crate :: error:: NegotiationError > {
360- let Message :: Protocols ( protocols) =
361- Message :: decode ( payload. into ( ) ) . map_err ( |_| ParseError :: InvalidData ) ?
362- else {
363- return Err ( crate :: error:: NegotiationError :: MultistreamSelectError (
364- NegotiationError :: Failed ,
365- ) ) ;
360+
361+ let protocols = match Message :: decode ( payload. into ( ) ) {
362+ Ok ( Message :: Header ( HeaderLine :: V1 ) ) => {
363+ self . state = HandshakeState :: WaitingProtocol ;
364+ return Ok ( HandshakeResult :: NotReady ) ;
365+ }
366+ Ok ( Message :: Protocol ( protocol) ) => vec ! [ protocol] ,
367+ Ok ( Message :: Protocols ( protocols) ) => protocols,
368+ Ok ( Message :: NotAvailable ) => {
369+ return match & self . state {
370+ HandshakeState :: WaitingProtocol =>
371+ Err ( error:: NegotiationError :: MultistreamSelectError (
372+ NegotiationError :: Failed
373+ ) ) ,
374+ _ => Err ( error:: NegotiationError :: StateMismatch ) ,
375+ }
376+ }
377+ Ok ( Message :: ListProtocols ) => return Err ( error:: NegotiationError :: StateMismatch ) ,
378+ Err ( _) => return Err ( error:: NegotiationError :: ParseError ( ParseError :: InvalidData ) ) ,
366379 } ;
367380
368381 let mut protocol_iter = protocols. into_iter ( ) ;
@@ -410,7 +423,9 @@ mod tests {
410423 use super :: * ;
411424 use crate :: multistream_select:: listener_select_proto;
412425 use std:: time:: Duration ;
426+ use bytes:: BufMut ;
413427 use tokio:: net:: { TcpListener , TcpStream } ;
428+ use crate :: multistream_select:: protocol:: MSG_MULTISTREAM_1_0 ;
414429
415430 #[ tokio:: test]
416431 async fn select_proto_basic ( ) {
@@ -805,17 +820,19 @@ mod tests {
805820 }
806821
807822 #[ test]
808- fn register_response_invalid_message ( ) {
809- // send only header line
823+ fn register_response_header_only ( ) {
810824 let mut bytes = BytesMut :: with_capacity ( 32 ) ;
825+ bytes. put_u8 ( MSG_MULTISTREAM_1_0 . len ( ) as u8 ) ;
826+
811827 let message = Message :: Header ( HeaderLine :: V1 ) ;
812828 message. encode ( & mut bytes) . map_err ( |_| Error :: InvalidData ) . unwrap ( ) ;
813829
814830 let ( mut dialer_state, _message) =
815831 WebRtcDialerState :: propose ( ProtocolName :: from ( "/13371338/proto/1" ) , vec ! [ ] ) . unwrap ( ) ;
816832
817833 match dialer_state. register_response ( bytes. freeze ( ) . to_vec ( ) ) {
818- Err ( error:: NegotiationError :: MultistreamSelectError ( NegotiationError :: Failed ) ) => { }
834+ Ok ( HandshakeResult :: NotReady ) => { } ,
835+ Err ( err) => panic ! ( "unexpected error: {:?}" , err) ,
819836 event => panic ! ( "invalid event: {event:?}" ) ,
820837 }
821838 }
0 commit comments