@@ -11,7 +11,10 @@ use std::{
1111} ;
1212use tokio_stream:: { Stream , StreamExt } ;
1313
14+ use fuse:: Fuse ;
15+
1416pub ( super ) const BUFFER_SIZE : usize = 8 * 1024 ;
17+ const YIELD_THRESHOLD : usize = 32 * 1024 ;
1518
1619pub ( crate ) fn encode_server < T , U > (
1720 encoder : T ,
2427 T : Encoder < Error = Status > ,
2528 U : Stream < Item = Result < T :: Item , Status > > ,
2629{
27- let stream = encode (
30+ let stream = EncodedBytes :: new (
2831 encoder,
2932 source,
3033 compression_encoding,
4548 T : Encoder < Error = Status > ,
4649 U : Stream < Item = T :: Item > ,
4750{
48- let stream = encode (
51+ let stream = EncodedBytes :: new (
4952 encoder,
5053 source. map ( Ok ) ,
5154 compression_encoding,
@@ -55,44 +58,115 @@ where
5558 EncodeBody :: new_client ( stream)
5659}
5760
58- fn encode < T , U > (
59- mut encoder : T ,
60- source : U ,
61+ /// Combinator for efficient encoding of messages into reasonably sized buffers.
62+ /// EncodedBytes encodes ready messages from its delegate stream into a BytesMut,
63+ /// splitting off and yielding a buffer when either:
64+ /// * The delegate stream polls as not ready, or
65+ /// * The encoded buffer surpasses YIELD_THRESHOLD.
66+ #[ pin_project( project = EncodedBytesProj ) ]
67+ #[ derive( Debug ) ]
68+ pub ( crate ) struct EncodedBytes < T , U >
69+ where
70+ T : Encoder < Error = Status > ,
71+ U : Stream < Item = Result < T :: Item , Status > > ,
72+ {
73+ #[ pin]
74+ source : Fuse < U > ,
75+ encoder : T ,
6176 compression_encoding : Option < CompressionEncoding > ,
62- compression_override : SingleMessageCompressionOverride ,
6377 max_message_size : Option < usize > ,
64- ) -> impl Stream < Item = Result < Bytes , Status > >
78+ buf : BytesMut ,
79+ uncompression_buf : BytesMut ,
80+ }
81+
82+ impl < T , U > EncodedBytes < T , U >
6583where
6684 T : Encoder < Error = Status > ,
6785 U : Stream < Item = Result < T :: Item , Status > > ,
6886{
69- let mut buf = BytesMut :: with_capacity ( BUFFER_SIZE ) ;
87+ fn new (
88+ encoder : T ,
89+ source : U ,
90+ compression_encoding : Option < CompressionEncoding > ,
91+ compression_override : SingleMessageCompressionOverride ,
92+ max_message_size : Option < usize > ,
93+ ) -> Self {
94+ let buf = BytesMut :: with_capacity ( BUFFER_SIZE ) ;
7095
71- let compression_encoding = if compression_override == SingleMessageCompressionOverride :: Disable
72- {
73- None
74- } else {
75- compression_encoding
76- } ;
96+ let compression_encoding =
97+ if compression_override == SingleMessageCompressionOverride :: Disable {
98+ None
99+ } else {
100+ compression_encoding
101+ } ;
77102
78- let mut uncompression_buf = if compression_encoding. is_some ( ) {
79- BytesMut :: with_capacity ( BUFFER_SIZE )
80- } else {
81- BytesMut :: new ( )
82- } ;
103+ let uncompression_buf = if compression_encoding. is_some ( ) {
104+ BytesMut :: with_capacity ( BUFFER_SIZE )
105+ } else {
106+ BytesMut :: new ( )
107+ } ;
83108
84- source. map ( move |result| {
85- let item = result?;
109+ return EncodedBytes {
110+ source : Fuse :: new ( source) ,
111+ encoder,
112+ compression_encoding,
113+ max_message_size,
114+ buf,
115+ uncompression_buf,
116+ } ;
117+ }
118+ }
86119
87- encode_item (
88- & mut encoder,
89- & mut buf,
90- & mut uncompression_buf,
120+ impl < T , U > Stream for EncodedBytes < T , U >
121+ where
122+ T : Encoder < Error = Status > ,
123+ U : Stream < Item = Result < T :: Item , Status > > ,
124+ {
125+ type Item = Result < Bytes , Status > ;
126+
127+ fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
128+ let EncodedBytesProj {
129+ mut source,
130+ encoder,
91131 compression_encoding,
92132 max_message_size,
93- item,
94- )
95- } )
133+ buf,
134+ uncompression_buf,
135+ } = self . project ( ) ;
136+
137+ loop {
138+ match source. as_mut ( ) . poll_next ( cx) {
139+ Poll :: Pending if buf. is_empty ( ) => {
140+ return Poll :: Pending ;
141+ }
142+ Poll :: Ready ( None ) if buf. is_empty ( ) => {
143+ return Poll :: Ready ( None ) ;
144+ }
145+ Poll :: Pending | Poll :: Ready ( None ) => {
146+ return Poll :: Ready ( Some ( Ok ( buf. split_to ( buf. len ( ) ) . freeze ( ) ) ) ) ;
147+ }
148+ Poll :: Ready ( Some ( Ok ( item) ) ) => {
149+ if let Err ( status) = encode_item (
150+ encoder,
151+ buf,
152+ uncompression_buf,
153+ * compression_encoding,
154+ * max_message_size,
155+ item,
156+ ) {
157+ return Poll :: Ready ( Some ( Err ( status) ) ) ;
158+ }
159+
160+ if buf. len ( ) >= YIELD_THRESHOLD {
161+ return Poll :: Ready ( Some ( Ok ( buf. split_to ( buf. len ( ) ) . freeze ( ) ) ) ) ;
162+ }
163+ }
164+ Poll :: Ready ( Some ( Err ( status) ) ) => {
165+ return Poll :: Ready ( Some ( Err ( status) ) ) ;
166+ }
167+ }
168+ }
169+ }
96170}
97171
98172fn encode_item < T > (
@@ -102,10 +176,12 @@ fn encode_item<T>(
102176 compression_encoding : Option < CompressionEncoding > ,
103177 max_message_size : Option < usize > ,
104178 item : T :: Item ,
105- ) -> Result < Bytes , Status >
179+ ) -> Result < ( ) , Status >
106180where
107181 T : Encoder < Error = Status > ,
108182{
183+ let offset = buf. len ( ) ;
184+
109185 buf. reserve ( HEADER_SIZE ) ;
110186 unsafe {
111187 buf. advance_mut ( HEADER_SIZE ) ;
@@ -129,14 +205,14 @@ where
129205 }
130206
131207 // now that we know length, we can write the header
132- finish_encoding ( compression_encoding, max_message_size, buf)
208+ finish_encoding ( compression_encoding, max_message_size, & mut buf[ offset.. ] )
133209}
134210
135211fn finish_encoding (
136212 compression_encoding : Option < CompressionEncoding > ,
137213 max_message_size : Option < usize > ,
138- buf : & mut BytesMut ,
139- ) -> Result < Bytes , Status > {
214+ buf : & mut [ u8 ] ,
215+ ) -> Result < ( ) , Status > {
140216 let len = buf. len ( ) - HEADER_SIZE ;
141217 let limit = max_message_size. unwrap_or ( DEFAULT_MAX_SEND_MESSAGE_SIZE ) ;
142218 if len > limit {
@@ -160,7 +236,7 @@ fn finish_encoding(
160236 buf. put_u32 ( len as u32 ) ;
161237 }
162238
163- Ok ( buf . split_to ( len + HEADER_SIZE ) . freeze ( ) )
239+ Ok ( ( ) )
164240}
165241
166242#[ derive( Debug ) ]
@@ -269,3 +345,57 @@ where
269345 Poll :: Ready ( self . project ( ) . state . trailers ( ) )
270346 }
271347}
348+
349+ mod fuse {
350+ use std:: {
351+ pin:: Pin ,
352+ task:: { ready, Context , Poll } ,
353+ } ;
354+
355+ use tokio_stream:: Stream ;
356+
357+ /// Stream for the [`fuse`](super::StreamExt::fuse) method.
358+ #[ derive( Debug ) ]
359+ #[ pin_project:: pin_project]
360+ #[ must_use = "streams do nothing unless polled" ]
361+ pub ( crate ) struct Fuse < St > {
362+ #[ pin]
363+ stream : St ,
364+ done : bool ,
365+ }
366+
367+ impl < St > Fuse < St > {
368+ pub ( crate ) fn new ( stream : St ) -> Self {
369+ Self {
370+ stream,
371+ done : false ,
372+ }
373+ }
374+ }
375+
376+ impl < S : Stream > Stream for Fuse < S > {
377+ type Item = S :: Item ;
378+
379+ fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < S :: Item > > {
380+ let this = self . project ( ) ;
381+
382+ if * this. done {
383+ return Poll :: Ready ( None ) ;
384+ }
385+
386+ let item = ready ! ( this. stream. poll_next( cx) ) ;
387+ if item. is_none ( ) {
388+ * this. done = true ;
389+ }
390+ Poll :: Ready ( item)
391+ }
392+
393+ fn size_hint ( & self ) -> ( usize , Option < usize > ) {
394+ if self . done {
395+ ( 0 , Some ( 0 ) )
396+ } else {
397+ self . stream . size_hint ( )
398+ }
399+ }
400+ }
401+ }
0 commit comments