Skip to content

Commit 0be0f26

Browse files
committed
Merge pull request #527 from hyperium/server-keep-alive
feat(server): check Response headers for Connection: close in keep_alive loop
2 parents 871f37a + 49b5b8f commit 0be0f26

File tree

2 files changed

+94
-56
lines changed

2 files changed

+94
-56
lines changed

src/server/mod.rs

Lines changed: 78 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ pub use net::{Fresh, Streaming};
3333

3434
use Error;
3535
use buffer::BufReader;
36-
use header::{Headers, Expect};
36+
use header::{Headers, Expect, Connection};
3737
use http;
3838
use method::Method;
3939
use net::{NetworkListener, NetworkStream, HttpListener};
@@ -142,7 +142,7 @@ L: NetworkListener + Send + 'static {
142142

143143
debug!("threads = {:?}", threads);
144144
let pool = ListenerPool::new(listener.clone());
145-
let work = move |mut stream| handle_connection(&mut stream, &handler);
145+
let work = move |mut stream| Worker(&handler).handle_connection(&mut stream);
146146

147147
let guard = thread::spawn(move || pool.accept(work, threads));
148148

@@ -152,62 +152,95 @@ L: NetworkListener + Send + 'static {
152152
})
153153
}
154154

155-
fn handle_connection<'h, S, H>(mut stream: &mut S, handler: &'h H)
156-
where S: NetworkStream + Clone, H: Handler {
157-
debug!("Incoming stream");
158-
let addr = match stream.peer_addr() {
159-
Ok(addr) => addr,
160-
Err(e) => {
161-
error!("Peer Name error: {:?}", e);
162-
return;
163-
}
164-
};
165-
166-
// FIXME: Use Type ascription
167-
let stream_clone: &mut NetworkStream = &mut stream.clone();
168-
let mut rdr = BufReader::new(stream_clone);
169-
let mut wrt = BufWriter::new(stream);
170-
171-
let mut keep_alive = true;
172-
while keep_alive {
173-
let req = match Request::new(&mut rdr, addr) {
174-
Ok(req) => req,
175-
Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
176-
trace!("tcp closed, cancelling keep-alive loop");
177-
break;
155+
struct Worker<'a, H: Handler + 'static>(&'a H);
156+
157+
impl<'a, H: Handler + 'static> Worker<'a, H> {
158+
159+
fn handle_connection<S>(&self, mut stream: &mut S) where S: NetworkStream + Clone {
160+
debug!("Incoming stream");
161+
let addr = match stream.peer_addr() {
162+
Ok(addr) => addr,
163+
Err(e) => {
164+
error!("Peer Name error: {:?}", e);
165+
return;
178166
}
179-
Err(Error::Io(e)) => {
180-
debug!("ioerror in keepalive loop = {:?}", e);
167+
};
168+
169+
// FIXME: Use Type ascription
170+
let stream_clone: &mut NetworkStream = &mut stream.clone();
171+
let rdr = BufReader::new(stream_clone);
172+
let wrt = BufWriter::new(stream);
173+
174+
self.keep_alive_loop(rdr, wrt, addr);
175+
debug!("keep_alive loop ending for {}", addr);
176+
}
177+
178+
fn keep_alive_loop<W: Write>(&self, mut rdr: BufReader<&mut NetworkStream>, mut wrt: W, addr: SocketAddr) {
179+
let mut keep_alive = true;
180+
while keep_alive {
181+
let req = match Request::new(&mut rdr, addr) {
182+
Ok(req) => req,
183+
Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
184+
trace!("tcp closed, cancelling keep-alive loop");
185+
break;
186+
}
187+
Err(Error::Io(e)) => {
188+
debug!("ioerror in keepalive loop = {:?}", e);
189+
break;
190+
}
191+
Err(e) => {
192+
//TODO: send a 400 response
193+
error!("request error = {:?}", e);
194+
break;
195+
}
196+
};
197+
198+
199+
if !self.handle_expect(&req, &mut wrt) {
181200
break;
182201
}
183-
Err(e) => {
184-
//TODO: send a 400 response
185-
error!("request error = {:?}", e);
186-
break;
202+
203+
keep_alive = http::should_keep_alive(req.version, &req.headers);
204+
let version = req.version;
205+
let mut res_headers = Headers::new();
206+
if !keep_alive {
207+
res_headers.set(Connection::close());
208+
}
209+
{
210+
let mut res = Response::new(&mut wrt, &mut res_headers);
211+
res.version = version;
212+
self.0.handle(req, res);
187213
}
188-
};
189214

190-
if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) {
191-
let status = handler.check_continue((&req.method, &req.uri, &req.headers));
192-
match write!(&mut wrt, "{} {}\r\n\r\n", Http11, status) {
215+
// if the request was keep-alive, we need to check that the server agrees
216+
// if it wasn't, then the server cannot force it to be true anyways
217+
if keep_alive {
218+
keep_alive = http::should_keep_alive(version, &res_headers);
219+
}
220+
221+
debug!("keep_alive = {:?} for {}", keep_alive, addr);
222+
}
223+
224+
}
225+
226+
fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool {
227+
if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) {
228+
let status = self.0.check_continue((&req.method, &req.uri, &req.headers));
229+
match write!(wrt, "{} {}\r\n\r\n", Http11, status) {
193230
Ok(..) => (),
194231
Err(e) => {
195232
error!("error writing 100-continue: {:?}", e);
196-
break;
233+
return false;
197234
}
198235
}
199236

200237
if status != StatusCode::Continue {
201238
debug!("non-100 status ({}) for Expect 100 request", status);
202-
break;
239+
return false;
203240
}
204241
}
205242

206-
keep_alive = http::should_keep_alive(req.version, &req.headers);
207-
let mut res = Response::new(&mut wrt);
208-
res.version = req.version;
209-
handler.handle(req, res);
210-
debug!("keep_alive = {:?}", keep_alive);
243+
true
211244
}
212245
}
213246

@@ -270,7 +303,7 @@ mod tests {
270303
use status::StatusCode;
271304
use uri::RequestUri;
272305

273-
use super::{Request, Response, Fresh, Handler, handle_connection};
306+
use super::{Request, Response, Fresh, Handler, Worker};
274307

275308
#[test]
276309
fn test_check_continue_default() {
@@ -287,7 +320,7 @@ mod tests {
287320
res.start().unwrap().end().unwrap();
288321
}
289322

290-
handle_connection(&mut mock, &handle);
323+
Worker(&handle).handle_connection(&mut mock);
291324
let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
292325
assert_eq!(&mock.write[..cont.len()], cont);
293326
let res = b"HTTP/1.1 200 OK\r\n";
@@ -316,7 +349,7 @@ mod tests {
316349
1234567890\
317350
");
318351

319-
handle_connection(&mut mock, &Reject);
352+
Worker(&Reject).handle_connection(&mut mock);
320353
assert_eq!(mock.write, &b"HTTP/1.1 417 Expectation Failed\r\n\r\n"[..]);
321354
}
322355
}

src/server/response.rs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ pub struct Response<'a, W: Any = Fresh> {
2828
// The status code for the request.
2929
status: status::StatusCode,
3030
// The outgoing headers on this response.
31-
headers: header::Headers,
31+
headers: &'a mut header::Headers,
3232

3333
_writing: PhantomData<W>
3434
}
@@ -39,13 +39,13 @@ impl<'a, W: Any> Response<'a, W> {
3939
pub fn status(&self) -> status::StatusCode { self.status }
4040

4141
/// The headers of this response.
42-
pub fn headers(&self) -> &header::Headers { &self.headers }
42+
pub fn headers(&self) -> &header::Headers { &*self.headers }
4343

4444
/// Construct a Response from its constituent parts.
4545
pub fn construct(version: version::HttpVersion,
4646
body: HttpWriter<&'a mut (Write + 'a)>,
4747
status: status::StatusCode,
48-
headers: header::Headers) -> Response<'a, Fresh> {
48+
headers: &'a mut header::Headers) -> Response<'a, Fresh> {
4949
Response {
5050
status: status,
5151
version: version,
@@ -57,7 +57,7 @@ impl<'a, W: Any> Response<'a, W> {
5757

5858
/// Deconstruct this Response into its constituent parts.
5959
pub fn deconstruct(self) -> (version::HttpVersion, HttpWriter<&'a mut (Write + 'a)>,
60-
status::StatusCode, header::Headers) {
60+
status::StatusCode, &'a mut header::Headers) {
6161
unsafe {
6262
let parts = (
6363
self.version,
@@ -114,11 +114,11 @@ impl<'a, W: Any> Response<'a, W> {
114114
impl<'a> Response<'a, Fresh> {
115115
/// Creates a new Response that can be used to write to a network stream.
116116
#[inline]
117-
pub fn new(stream: &'a mut (Write + 'a)) -> Response<'a, Fresh> {
117+
pub fn new(stream: &'a mut (Write + 'a), headers: &'a mut header::Headers) -> Response<'a, Fresh> {
118118
Response {
119119
status: status::StatusCode::Ok,
120120
version: version::HttpVersion::Http11,
121-
headers: header::Headers::new(),
121+
headers: headers,
122122
body: ThroughWriter(stream),
123123
_writing: PhantomData,
124124
}
@@ -165,7 +165,7 @@ impl<'a> Response<'a, Fresh> {
165165

166166
/// Get a mutable reference to the Headers.
167167
#[inline]
168-
pub fn headers_mut(&mut self) -> &mut header::Headers { &mut self.headers }
168+
pub fn headers_mut(&mut self) -> &mut header::Headers { self.headers }
169169
}
170170

171171

@@ -231,6 +231,7 @@ impl<'a, T: Any> Drop for Response<'a, T> {
231231

232232
#[cfg(test)]
233233
mod tests {
234+
use header::Headers;
234235
use mock::MockStream;
235236
use super::Response;
236237

@@ -252,9 +253,10 @@ mod tests {
252253

253254
#[test]
254255
fn test_fresh_start() {
256+
let mut headers = Headers::new();
255257
let mut stream = MockStream::new();
256258
{
257-
let res = Response::new(&mut stream);
259+
let res = Response::new(&mut stream, &mut headers);
258260
res.start().unwrap().deconstruct();
259261
}
260262

@@ -268,9 +270,10 @@ mod tests {
268270

269271
#[test]
270272
fn test_streaming_end() {
273+
let mut headers = Headers::new();
271274
let mut stream = MockStream::new();
272275
{
273-
let res = Response::new(&mut stream);
276+
let res = Response::new(&mut stream, &mut headers);
274277
res.start().unwrap().end().unwrap();
275278
}
276279

@@ -287,9 +290,10 @@ mod tests {
287290
#[test]
288291
fn test_fresh_drop() {
289292
use status::StatusCode;
293+
let mut headers = Headers::new();
290294
let mut stream = MockStream::new();
291295
{
292-
let mut res = Response::new(&mut stream);
296+
let mut res = Response::new(&mut stream, &mut headers);
293297
*res.status_mut() = StatusCode::NotFound;
294298
}
295299

@@ -307,9 +311,10 @@ mod tests {
307311
fn test_streaming_drop() {
308312
use std::io::Write;
309313
use status::StatusCode;
314+
let mut headers = Headers::new();
310315
let mut stream = MockStream::new();
311316
{
312-
let mut res = Response::new(&mut stream);
317+
let mut res = Response::new(&mut stream, &mut headers);
313318
*res.status_mut() = StatusCode::NotFound;
314319
let mut stream = res.start().unwrap();
315320
stream.write_all(b"foo").unwrap();

0 commit comments

Comments
 (0)