diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs
index ef66ed63e..712a991a4 100644
--- a/src/proto/streams/prioritize.rs
+++ b/src/proto/streams/prioritize.rs
@@ -8,7 +8,7 @@ use codec::UserError::*;
 
 use bytes::buf::Take;
 
-use std::{cmp, fmt};
+use std::{cmp, fmt, mem};
 use std::io;
 
 /// # Warning
@@ -48,6 +48,19 @@ pub(super) struct Prioritize {
 
     /// Stream ID of the last stream opened.
     last_opened_id: StreamId,
+
+    /// What `DATA` frame is currently being sent in the codec.
+    in_flight_data_frame: InFlightData,
+}
+
+#[derive(Debug, Eq, PartialEq)]
+enum InFlightData {
+    /// There is no `DATA` frame in flight.
+    Nothing,
+    /// There is a `DATA` frame in flight belonging to the given stream.
+    DataFrame(store::Key),
+    /// There was a `DATA` frame, but the stream's queue was since cleared.
+    Drop,
 }
 
 pub(crate) struct Prioritized<B> {
@@ -79,7 +92,8 @@ impl Prioritize {
             pending_capacity: store::Queue::new(),
             pending_open: store::Queue::new(),
             flow: flow,
-            last_opened_id: StreamId::ZERO
+            last_opened_id: StreamId::ZERO,
+            in_flight_data_frame: InFlightData::Nothing,
         }
     }
 
@@ -456,6 +470,10 @@ impl Prioritize {
                 Some(frame) => {
                     trace!("writing frame={:?}", frame);
 
+                    debug_assert_eq!(self.in_flight_data_frame, InFlightData::Nothing);
+                    if let Frame::Data(ref frame) = frame {
+                        self.in_flight_data_frame = InFlightData::DataFrame(frame.payload().stream);
+                    }
                     dst.buffer(frame).ok().expect("invalid frame");
 
                     // Ensure the codec is ready to try the loop again.
@@ -503,12 +521,23 @@ impl Prioritize {
             trace!(
                 " -> reclaimed; frame={:?}; sz={}",
                 frame,
-                frame.payload().remaining()
+                frame.payload().inner.get_ref().remaining()
            );
 
             let mut eos = false;
             let key = frame.payload().stream;
 
+            match mem::replace(&mut self.in_flight_data_frame, InFlightData::Nothing) {
+                InFlightData::Nothing => panic!("wasn't expecting a frame to reclaim"),
+                InFlightData::Drop => {
+                    trace!("not reclaiming frame for cancelled stream");
+                    return false;
+                }
+                InFlightData::DataFrame(k) => {
+                    debug_assert_eq!(k, key);
+                }
+            }
+
             let mut frame = frame.map(|prioritized| {
                 // TODO: Ensure fully written
                 eos = prioritized.end_of_stream;
@@ -558,6 +587,12 @@ impl Prioritize {
 
         stream.buffered_send_data = 0;
         stream.requested_send_capacity = 0;
+        if let InFlightData::DataFrame(key) = self.in_flight_data_frame {
+            if stream.key() == key {
+                // This stream could get cleaned up now - don't allow the buffered frame to get reclaimed.
+                self.in_flight_data_frame = InFlightData::Drop;
+            }
+        }
     }
 
     fn pop_frame(
diff --git a/tests/stream_states.rs b/tests/stream_states.rs
index c328a2ac9..a1f2754fb 100644
--- a/tests/stream_states.rs
+++ b/tests/stream_states.rs
@@ -878,3 +878,70 @@ fn rst_while_closing() {
 
     client.join(srv).wait().expect("wait");
 }
+
+#[test]
+fn rst_with_buffered_data() {
+    use futures::future::lazy;
+
+    // Data is buffered in `FramedWrite` and the stream is reset locally before
+    // the data is fully flushed. Given that resetting a stream requires
+    // clearing all associated state for that stream, this test ensures that the
+    // buffered up frame is correctly handled.
+    let _ = ::env_logger::try_init();
+
+    // This allows the settings + headers frame through
+    let (io, srv) = mock::new_with_write_capacity(73);
+
+    // Synchronize the client / server on response
+    let (tx, rx) = ::futures::sync::oneshot::channel();
+
+    let srv = srv.assert_client_handshake()
+        .unwrap()
+        .recv_settings()
+        .recv_frame(
+            frames::headers(1)
+                .request("POST", "https://example.com/")
+        )
+        .buffer_bytes(128)
+        .send_frame(frames::headers(1).response(204).eos())
+        .send_frame(frames::reset(1).cancel())
+        .wait_for(rx)
+        .unbounded_bytes()
+        .recv_frame(
+            frames::data(1, vec![0; 16_384]))
+        .close()
+        ;
+
+    // A large body
+    let body = vec![0; 2 * frame::DEFAULT_INITIAL_WINDOW_SIZE as usize];
+
+    let client = client::handshake(io)
+        .expect("handshake")
+        .and_then(|(mut client, conn)| {
+            let request = Request::builder()
+                .method(Method::POST)
+                .uri("https://example.com/")
+                .body(())
+                .unwrap();
+
+            // Send the request
+            let (resp, mut stream) = client.send_request(request, false)
+                .expect("send_request");
+
+            // Send the data
+            stream.send_data(body.into(), true).unwrap();
+
+            conn.drive({
+                resp.then(|res| {
+                    Ok::<_, ()>(())
+                })
+            })
+        })
+        .and_then(move |(conn, _)| {
+            tx.send(()).unwrap();
+            conn.unwrap()
+        });
+
+
+    client.join(srv).wait().expect("wait");
+}
diff --git a/tests/support/mock.rs b/tests/support/mock.rs
index b510e89bf..761bd756a 100644
--- a/tests/support/mock.rs
+++ b/tests/support/mock.rs
@@ -10,7 +10,7 @@ use futures::task::{self, Task};
 use tokio_io::{AsyncRead, AsyncWrite};
 use tokio_io::io::read_exact;
 
-use std::{cmp, fmt, io};
+use std::{cmp, fmt, io, usize};
 use std::io::ErrorKind::WouldBlock;
 use std::sync::{Arc, Mutex};
 
@@ -32,10 +32,25 @@ pub struct Pipe {
 
 #[derive(Debug)]
 struct Inner {
+    /// Data written by the test case to the h2 lib.
     rx: Vec<u8>,
+
+    /// Notify when data is ready to be received.
     rx_task: Option<Task>,
+
+    /// Data written by the `h2` library to be read by the test case.
     tx: Vec<u8>,
+
+    /// Notify when data is written. This notifies the test case waiters.
     tx_task: Option<Task>,
+
+    /// Number of bytes that can be written before `write` returns `NotReady`.
+    tx_rem: usize,
+
+    /// Task to notify when write capacity becomes available.
+    tx_rem_task: Option<Task>,
+
+    /// True when the pipe is closed.
     closed: bool,
 }
 
@@ -43,11 +58,18 @@ const PREFACE: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
 
 /// Create a new mock and handle
 pub fn new() -> (Mock, Handle) {
+    new_with_write_capacity(usize::MAX)
+}
+
+/// Create a new mock and handle allowing up to `cap` bytes to be written.
+pub fn new_with_write_capacity(cap: usize) -> (Mock, Handle) {
     let inner = Arc::new(Mutex::new(Inner {
         rx: vec![],
         rx_task: None,
         tx: vec![],
         tx_task: None,
+        tx_rem: cap,
+        tx_rem_task: None,
         closed: false,
     }));
 
@@ -303,14 +325,24 @@ impl io::Read for Mock {
 impl AsyncRead for Mock {}
 
 impl io::Write for Mock {
-    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+    fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
         let mut me = self.pipe.inner.lock().unwrap();
 
         if me.closed {
             return Err(io::Error::new(io::ErrorKind::BrokenPipe, "mock closed"));
         }
 
+        if me.tx_rem == 0 {
+            me.tx_rem_task = Some(task::current());
+            return Err(io::ErrorKind::WouldBlock.into());
+        }
+
+        if buf.len() > me.tx_rem {
+            buf = &buf[..me.tx_rem];
+        }
+
         me.tx.extend(buf);
+        me.tx_rem -= buf.len();
 
         if let Some(task) = me.tx_task.take() {
             task.notify();
@@ -477,6 +509,70 @@ pub trait HandleFutureExt {
         }))
     }
 
+    fn buffer_bytes(self, num: usize) -> Box<Future<Item = Handle, Error = Self::Error>>
+        where Self: Sized + 'static,
+              Self: Future<Item = Handle>,
+              Self::Error: fmt::Debug,
+    {
+        use futures::future::poll_fn;
+
+        Box::new(self.and_then(move |mut handle| {
+            // Set tx_rem to num
+            {
+                let mut i = handle.codec.get_mut().inner.lock().unwrap();
+                i.tx_rem = num;
+            }
+
+            let mut handle = Some(handle);
+
+            poll_fn(move || {
+                {
+                    let mut inner = handle.as_mut().unwrap()
+                        .codec.get_mut().inner.lock().unwrap();
+
+                    if inner.tx_rem == 0 {
+                        inner.tx_rem = usize::MAX;
+                    } else {
+                        inner.tx_task = Some(task::current());
+                        return Ok(Async::NotReady);
+                    }
+                }
+
+                Ok(handle.take().unwrap().into())
+            })
+        }))
+    }
+
+    fn unbounded_bytes(self) -> Box<Future<Item = Handle, Error = Self::Error>>
+        where Self: Sized + 'static,
+              Self: Future<Item = Handle>,
+              Self::Error: fmt::Debug,
+    {
+        Box::new(self.and_then(|mut handle| {
+            {
+                let mut i = handle.codec.get_mut().inner.lock().unwrap();
+                i.tx_rem = usize::MAX;
+
+                if let Some(task) = i.tx_rem_task.take() {
+                    task.notify();
+                }
+            }
+
+            Ok(handle.into())
+        }))
+    }
+
+    fn then_notify(self, tx: oneshot::Sender<()>) -> Box<Future<Item = Handle, Error = Self::Error>>
+        where Self: Sized + 'static,
+              Self: Future<Item = Handle>,
+              Self::Error: fmt::Debug,
+    {
+        Box::new(self.map(move |handle| {
+            tx.send(()).unwrap();
+            handle
+        }))
+    }
+
     fn wait_for<F>(self, other: F) -> Box<Future<Item = Self::Item, Error = Self::Error>>
         where F: Future + 'static,