Skip to content

Commit

Permalink
Avoid reclaiming frames for dead streams. (#262)
Browse files Browse the repository at this point in the history
In `clear_queue` we drop all the queued frames for a stream, but this doesn't
take into account a buffered frame inside of the `FramedWrite`. This can lead
to a panic when `reclaim_frame` tries to recover a frame onto a stream that has
already been destroyed, or in general cause wrong behaviour.

Instead, let's keep track of what frame is currently in-flight; then, when we
`clear_queue` a stream with an in-flight data frame, mark the frame to be
dropped instead of reclaimed.
  • Loading branch information
goffrie authored and carllerche committed Apr 24, 2018
1 parent 11f9141 commit 558e6b6
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 5 deletions.
41 changes: 38 additions & 3 deletions src/proto/streams/prioritize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use codec::UserError::*;

use bytes::buf::Take;

use std::{cmp, fmt};
use std::{cmp, fmt, mem};
use std::io;

/// # Warning
Expand Down Expand Up @@ -48,6 +48,19 @@ pub(super) struct Prioritize {

/// Stream ID of the last stream opened.
last_opened_id: StreamId,

/// What `DATA` frame is currently being sent in the codec.
in_flight_data_frame: InFlightData,
}

#[derive(Debug, Eq, PartialEq)]
enum InFlightData {
/// There is no `DATA` frame in flight.
Nothing,
/// There is a `DATA` frame in flight belonging to the given stream.
DataFrame(store::Key),
/// There was a `DATA` frame, but the stream's queue was since cleared.
Drop,
}

pub(crate) struct Prioritized<B> {
Expand Down Expand Up @@ -79,7 +92,8 @@ impl Prioritize {
pending_capacity: store::Queue::new(),
pending_open: store::Queue::new(),
flow: flow,
last_opened_id: StreamId::ZERO
last_opened_id: StreamId::ZERO,
in_flight_data_frame: InFlightData::Nothing,
}
}

Expand Down Expand Up @@ -456,6 +470,10 @@ impl Prioritize {
Some(frame) => {
trace!("writing frame={:?}", frame);

debug_assert_eq!(self.in_flight_data_frame, InFlightData::Nothing);
if let Frame::Data(ref frame) = frame {
self.in_flight_data_frame = InFlightData::DataFrame(frame.payload().stream);
}
dst.buffer(frame).ok().expect("invalid frame");

// Ensure the codec is ready to try the loop again.
Expand Down Expand Up @@ -503,12 +521,23 @@ impl Prioritize {
trace!(
" -> reclaimed; frame={:?}; sz={}",
frame,
frame.payload().remaining()
frame.payload().inner.get_ref().remaining()
);

let mut eos = false;
let key = frame.payload().stream;

match mem::replace(&mut self.in_flight_data_frame, InFlightData::Nothing) {
InFlightData::Nothing => panic!("wasn't expecting a frame to reclaim"),
InFlightData::Drop => {
trace!("not reclaiming frame for cancelled stream");
return false;
}
InFlightData::DataFrame(k) => {
debug_assert_eq!(k, key);
}
}

let mut frame = frame.map(|prioritized| {
// TODO: Ensure fully written
eos = prioritized.end_of_stream;
Expand Down Expand Up @@ -558,6 +587,12 @@ impl Prioritize {

stream.buffered_send_data = 0;
stream.requested_send_capacity = 0;
if let InFlightData::DataFrame(key) = self.in_flight_data_frame {
if stream.key() == key {
// This stream could get cleaned up now - don't allow the buffered frame to get reclaimed.
self.in_flight_data_frame = InFlightData::Drop;
}
}
}

fn pop_frame<B>(
Expand Down
67 changes: 67 additions & 0 deletions tests/stream_states.rs
Original file line number Diff line number Diff line change
Expand Up @@ -878,3 +878,70 @@ fn rst_while_closing() {

client.join(srv).wait().expect("wait");
}

#[test]
fn rst_with_buffered_data() {
use futures::future::lazy;

// Data is buffered in `FramedWrite` and the stream is reset locally before
// the data is fully flushed. Given that resetting a stream requires
// clearing all associated state for that stream, this test ensures that the
// buffered up frame is correctly handled.
let _ = ::env_logger::try_init();

// This allows the settings + headers frame through
let (io, srv) = mock::new_with_write_capacity(73);

// Synchronize the client / server on response
let (tx, rx) = ::futures::sync::oneshot::channel();

let srv = srv.assert_client_handshake()
.unwrap()
.recv_settings()
.recv_frame(
frames::headers(1)
.request("POST", "https://example.com/")
)
.buffer_bytes(128)
.send_frame(frames::headers(1).response(204).eos())
.send_frame(frames::reset(1).cancel())
.wait_for(rx)
.unbounded_bytes()
.recv_frame(
frames::data(1, vec![0; 16_384]))
.close()
;

// A large body
let body = vec![0; 2 * frame::DEFAULT_INITIAL_WINDOW_SIZE as usize];

let client = client::handshake(io)
.expect("handshake")
.and_then(|(mut client, conn)| {
let request = Request::builder()
.method(Method::POST)
.uri("https://example.com/")
.body(())
.unwrap();

// Send the request
let (resp, mut stream) = client.send_request(request, false)
.expect("send_request");

// Send the data
stream.send_data(body.into(), true).unwrap();

conn.drive({
resp.then(|res| {
Ok::<_, ()>(())
})
})
})
.and_then(move |(conn, _)| {
tx.send(()).unwrap();
conn.unwrap()
});


client.join(srv).wait().expect("wait");
}
100 changes: 98 additions & 2 deletions tests/support/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use futures::task::{self, Task};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_io::io::read_exact;

use std::{cmp, fmt, io};
use std::{cmp, fmt, io, usize};
use std::io::ErrorKind::WouldBlock;
use std::sync::{Arc, Mutex};

Expand All @@ -32,22 +32,44 @@ pub struct Pipe {

#[derive(Debug)]
struct Inner {
/// Data written by the test case to the h2 lib.
rx: Vec<u8>,

/// Notify when data is ready to be received.
rx_task: Option<Task>,

/// Data written by the `h2` library to be read by the test case.
tx: Vec<u8>,

/// Notify when data is written. This notifies the test case waiters.
tx_task: Option<Task>,

/// Number of bytes that can be written before `write` returns `NotReady`.
tx_rem: usize,

/// Task to notify when write capacity becomes available.
tx_rem_task: Option<Task>,

/// True when the pipe is closed.
closed: bool,
}

const PREFACE: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";

/// Create a new mock and handle
pub fn new() -> (Mock, Handle) {
new_with_write_capacity(usize::MAX)
}

/// Create a new mock and handle allowing up to `cap` bytes to be written.
pub fn new_with_write_capacity(cap: usize) -> (Mock, Handle) {
let inner = Arc::new(Mutex::new(Inner {
rx: vec![],
rx_task: None,
tx: vec![],
tx_task: None,
tx_rem: cap,
tx_rem_task: None,
closed: false,
}));

Expand Down Expand Up @@ -303,14 +325,24 @@ impl io::Read for Mock {
impl AsyncRead for Mock {}

impl io::Write for Mock {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
let mut me = self.pipe.inner.lock().unwrap();

if me.closed {
return Err(io::Error::new(io::ErrorKind::BrokenPipe, "mock closed"));
}

if me.tx_rem == 0 {
me.tx_rem_task = Some(task::current());
return Err(io::ErrorKind::WouldBlock.into());
}

if buf.len() > me.tx_rem {
buf = &buf[..me.tx_rem];
}

me.tx.extend(buf);
me.tx_rem -= buf.len();

if let Some(task) = me.tx_task.take() {
task.notify();
Expand Down Expand Up @@ -477,6 +509,70 @@ pub trait HandleFutureExt {
}))
}

fn buffer_bytes(self, num: usize) -> Box<Future<Item = Handle, Error = Self::Error>>
where Self: Sized + 'static,
Self: Future<Item = Handle>,
Self::Error: fmt::Debug,
{
use futures::future::poll_fn;

Box::new(self.and_then(move |mut handle| {
// Set tx_rem to num
{
let mut i = handle.codec.get_mut().inner.lock().unwrap();
i.tx_rem = num;
}

let mut handle = Some(handle);

poll_fn(move || {
{
let mut inner = handle.as_mut().unwrap()
.codec.get_mut().inner.lock().unwrap();

if inner.tx_rem == 0 {
inner.tx_rem = usize::MAX;
} else {
inner.tx_task = Some(task::current());
return Ok(Async::NotReady);
}
}

Ok(handle.take().unwrap().into())
})
}))
}

fn unbounded_bytes(self) -> Box<Future<Item = Handle, Error = Self::Error>>
where Self: Sized + 'static,
Self: Future<Item = Handle>,
Self::Error: fmt::Debug,
{
Box::new(self.and_then(|mut handle| {
{
let mut i = handle.codec.get_mut().inner.lock().unwrap();
i.tx_rem = usize::MAX;

if let Some(task) = i.tx_rem_task.take() {
task.notify();
}
}

Ok(handle.into())
}))
}

fn then_notify(self, tx: oneshot::Sender<()>) -> Box<Future<Item = Handle, Error = Self::Error>>
where Self: Sized + 'static,
Self: Future<Item = Handle>,
Self::Error: fmt::Debug,
{
Box::new(self.map(move |handle| {
tx.send(()).unwrap();
handle
}))
}

fn wait_for<F>(self, other: F) -> Box<Future<Item = Self::Item, Error = Self::Error>>
where
F: Future + 'static,
Expand Down

0 comments on commit 558e6b6

Please sign in to comment.