Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid reclaiming frames for dead streams. #262

Merged
merged 2 commits into from
Apr 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions src/proto/streams/prioritize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use codec::UserError::*;

use bytes::buf::Take;

use std::{cmp, fmt};
use std::{cmp, fmt, mem};
use std::io;

/// # Warning
Expand Down Expand Up @@ -48,6 +48,19 @@ pub(super) struct Prioritize {

/// Stream ID of the last stream opened.
last_opened_id: StreamId,

/// What `DATA` frame is currently being sent in the codec.
in_flight_data_frame: InFlightData,
}

#[derive(Debug, Eq, PartialEq)]
enum InFlightData {
/// There is no `DATA` frame in flight.
Nothing,
/// There is a `DATA` frame in flight belonging to the given stream.
DataFrame(store::Key),
/// There was a `DATA` frame, but the stream's queue was since cleared.
Drop,
}

pub(crate) struct Prioritized<B> {
Expand Down Expand Up @@ -79,7 +92,8 @@ impl Prioritize {
pending_capacity: store::Queue::new(),
pending_open: store::Queue::new(),
flow: flow,
last_opened_id: StreamId::ZERO
last_opened_id: StreamId::ZERO,
in_flight_data_frame: InFlightData::Nothing,
}
}

Expand Down Expand Up @@ -456,6 +470,10 @@ impl Prioritize {
Some(frame) => {
trace!("writing frame={:?}", frame);

debug_assert_eq!(self.in_flight_data_frame, InFlightData::Nothing);
if let Frame::Data(ref frame) = frame {
self.in_flight_data_frame = InFlightData::DataFrame(frame.payload().stream);
}
dst.buffer(frame).ok().expect("invalid frame");

// Ensure the codec is ready to try the loop again.
Expand Down Expand Up @@ -503,12 +521,23 @@ impl Prioritize {
trace!(
" -> reclaimed; frame={:?}; sz={}",
frame,
frame.payload().remaining()
frame.payload().inner.get_ref().remaining()
);

let mut eos = false;
let key = frame.payload().stream;

match mem::replace(&mut self.in_flight_data_frame, InFlightData::Nothing) {
InFlightData::Nothing => panic!("wasn't expecting a frame to reclaim"),
InFlightData::Drop => {
trace!("not reclaiming frame for cancelled stream");
return false;
}
InFlightData::DataFrame(k) => {
debug_assert_eq!(k, key);
}
}

let mut frame = frame.map(|prioritized| {
// TODO: Ensure fully written
eos = prioritized.end_of_stream;
Expand Down Expand Up @@ -558,6 +587,12 @@ impl Prioritize {

stream.buffered_send_data = 0;
stream.requested_send_capacity = 0;
if let InFlightData::DataFrame(key) = self.in_flight_data_frame {
if stream.key() == key {
// This stream could get cleaned up now - don't allow the buffered frame to get reclaimed.
self.in_flight_data_frame = InFlightData::Drop;
}
}
}

fn pop_frame<B>(
Expand Down
67 changes: 67 additions & 0 deletions tests/stream_states.rs
Original file line number Diff line number Diff line change
Expand Up @@ -878,3 +878,70 @@ fn rst_while_closing() {

client.join(srv).wait().expect("wait");
}

#[test]
fn rst_with_buffered_data() {
use futures::future::lazy;

// Data is buffered in `FramedWrite` and the stream is reset locally before
// the data is fully flushed. Given that resetting a stream requires
// clearing all associated state for that stream, this test ensures that the
// buffered up frame is correctly handled.
let _ = ::env_logger::try_init();

// This allows the settings + headers frame through
let (io, srv) = mock::new_with_write_capacity(73);

// Synchronize the client / server on response
let (tx, rx) = ::futures::sync::oneshot::channel();

let srv = srv.assert_client_handshake()
.unwrap()
.recv_settings()
.recv_frame(
frames::headers(1)
.request("POST", "https://example.com/")
)
.buffer_bytes(128)
.send_frame(frames::headers(1).response(204).eos())
.send_frame(frames::reset(1).cancel())
.wait_for(rx)
.unbounded_bytes()
.recv_frame(
frames::data(1, vec![0; 16_384]))
.close()
;

// A large body
let body = vec![0; 2 * frame::DEFAULT_INITIAL_WINDOW_SIZE as usize];

let client = client::handshake(io)
.expect("handshake")
.and_then(|(mut client, conn)| {
let request = Request::builder()
.method(Method::POST)
.uri("https://example.com/")
.body(())
.unwrap();

// Send the request
let (resp, mut stream) = client.send_request(request, false)
.expect("send_request");

// Send the data
stream.send_data(body.into(), true).unwrap();

conn.drive({
resp.then(|res| {
Ok::<_, ()>(())
})
})
})
.and_then(move |(conn, _)| {
tx.send(()).unwrap();
conn.unwrap()
});


client.join(srv).wait().expect("wait");
}
100 changes: 98 additions & 2 deletions tests/support/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use futures::task::{self, Task};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_io::io::read_exact;

use std::{cmp, fmt, io};
use std::{cmp, fmt, io, usize};
use std::io::ErrorKind::WouldBlock;
use std::sync::{Arc, Mutex};

Expand All @@ -32,22 +32,44 @@ pub struct Pipe {

#[derive(Debug)]
struct Inner {
/// Data written by the test case to the h2 lib.
rx: Vec<u8>,

/// Notify when data is ready to be received.
rx_task: Option<Task>,

/// Data written by the `h2` library to be read by the test case.
tx: Vec<u8>,

/// Notify when data is written. This notifies the test case waiters.
tx_task: Option<Task>,

/// Number of bytes that can be written before `write` returns `NotReady`.
tx_rem: usize,

/// Task to notify when write capacity becomes available.
tx_rem_task: Option<Task>,

/// True when the pipe is closed.
closed: bool,
}

const PREFACE: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";

/// Create a new mock and handle
pub fn new() -> (Mock, Handle) {
new_with_write_capacity(usize::MAX)
}

/// Create a new mock and handle allowing up to `cap` bytes to be written.
pub fn new_with_write_capacity(cap: usize) -> (Mock, Handle) {
let inner = Arc::new(Mutex::new(Inner {
rx: vec![],
rx_task: None,
tx: vec![],
tx_task: None,
tx_rem: cap,
tx_rem_task: None,
closed: false,
}));

Expand Down Expand Up @@ -303,14 +325,24 @@ impl io::Read for Mock {
impl AsyncRead for Mock {}

impl io::Write for Mock {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
let mut me = self.pipe.inner.lock().unwrap();

if me.closed {
return Err(io::Error::new(io::ErrorKind::BrokenPipe, "mock closed"));
}

if me.tx_rem == 0 {
me.tx_rem_task = Some(task::current());
return Err(io::ErrorKind::WouldBlock.into());
}

if buf.len() > me.tx_rem {
buf = &buf[..me.tx_rem];
}

me.tx.extend(buf);
me.tx_rem -= buf.len();

if let Some(task) = me.tx_task.take() {
task.notify();
Expand Down Expand Up @@ -477,6 +509,70 @@ pub trait HandleFutureExt {
}))
}

fn buffer_bytes(self, num: usize) -> Box<Future<Item = Handle, Error = Self::Error>>
where Self: Sized + 'static,
Self: Future<Item = Handle>,
Self::Error: fmt::Debug,
{
use futures::future::poll_fn;

Box::new(self.and_then(move |mut handle| {
// Set tx_rem to num
{
let mut i = handle.codec.get_mut().inner.lock().unwrap();
i.tx_rem = num;
}

let mut handle = Some(handle);

poll_fn(move || {
{
let mut inner = handle.as_mut().unwrap()
.codec.get_mut().inner.lock().unwrap();

if inner.tx_rem == 0 {
inner.tx_rem = usize::MAX;
} else {
inner.tx_task = Some(task::current());
return Ok(Async::NotReady);
}
}

Ok(handle.take().unwrap().into())
})
}))
}

fn unbounded_bytes(self) -> Box<Future<Item = Handle, Error = Self::Error>>
where Self: Sized + 'static,
Self: Future<Item = Handle>,
Self::Error: fmt::Debug,
{
Box::new(self.and_then(|mut handle| {
{
let mut i = handle.codec.get_mut().inner.lock().unwrap();
i.tx_rem = usize::MAX;

if let Some(task) = i.tx_rem_task.take() {
task.notify();
}
}

Ok(handle.into())
}))
}

fn then_notify(self, tx: oneshot::Sender<()>) -> Box<Future<Item = Handle, Error = Self::Error>>
where Self: Sized + 'static,
Self: Future<Item = Handle>,
Self::Error: fmt::Debug,
{
Box::new(self.map(move |handle| {
tx.send(()).unwrap();
handle
}))
}

fn wait_for<F>(self, other: F) -> Box<Future<Item = Self::Item, Error = Self::Error>>
where
F: Future + 'static,
Expand Down