Skip to content

Commit

Permalink
EncoderWriter: no longer adhere to ‘at most one write’
Browse files Browse the repository at this point in the history
The documentation of the Write::write method is changing: the requirement
that a call maps to ‘at most one write’ is being removed¹.  With that,
change EncoderWriter::write so that it flushes the entire output buffer at
the beginning and only then proceeds to process new input.  This eliminates
returning Ok(0) for non-empty input, which is effectively an error.

Also, change the accounting for the occupied portion of the output buffer.
Rather than tracking just the occupied length, track the occupied range,
which means moving data to the front of the buffer is no longer necessary.

¹ rust-lang/rust#107200
  • Loading branch information
mina86 committed Feb 1, 2023
1 parent 92e94d2 commit cbb0cfc
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 124 deletions.
186 changes: 76 additions & 110 deletions src/write/encoder.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use crate::engine::Engine;
use std::{
cmp, fmt, io,
io::{ErrorKind, Result},
};
use std::{cmp, fmt, io};

pub(crate) const BUF_SIZE: usize = 1024;
/// The most bytes whose encoding will fit in `BUF_SIZE`
Expand Down Expand Up @@ -74,21 +71,24 @@ pub struct EncoderWriter<'e, E: Engine, W: io::Write> {
/// Buffer to encode into. May hold leftover encoded bytes from a previous write call that the underlying writer
/// did not write last time.
output: [u8; BUF_SIZE],
/// How much of `output` is occupied with encoded data that couldn't be written last time
output_occupied_len: usize,
/// Occupied portion of output.
output_range: std::ops::Range<usize>,
/// panic safety: don't write again in destructor if writer panicked while we were writing to it
panicked: bool,
}

impl<'e, E: Engine, W: io::Write> fmt::Debug for EncoderWriter<'e, E, W> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let range = self.output_range.clone();
let truncated_len = range.len().min(5);
let truncated_range = range.start..range.start + truncated_len;
write!(
f,
"extra_input: {:?} extra_input_occupied_len:{:?} output[..5]: {:?} output_occupied_len: {:?}",
self.extra_input,
self.extra_input_occupied_len,
&self.output[0..5],
self.output_occupied_len
"extra_input: {:?} occupied output[..{}]: {:?} output_range: {:?}",
&self.extra_input[..self.extra_input_occupied_len],
truncated_len,
&self.output[truncated_range],
range,
)
}
}
Expand All @@ -102,7 +102,7 @@ impl<'e, E: Engine, W: io::Write> EncoderWriter<'e, E, W> {
extra_input: [0u8; MIN_ENCODE_CHUNK_SIZE],
extra_input_occupied_len: 0,
output: [0u8; BUF_SIZE],
output_occupied_len: 0,
output_range: 0..0,
panicked: false,
}
}
Expand All @@ -123,7 +123,7 @@ impl<'e, E: Engine, W: io::Write> EncoderWriter<'e, E, W> {
/// # Errors
///
/// Any error from the delegate writer is returned, including `io::ErrorKind::Interrupted`;
/// `finish()` may be called again to retry after such an error.
pub fn finish(&mut self) -> Result<W> {
pub fn finish(&mut self) -> io::Result<W> {
// If we could consume self in finish(), we wouldn't have to worry about this case, but
// finish() is retryable in the face of I/O errors, so we can't consume here.
if self.delegate.is_none() {
Expand All @@ -138,91 +138,72 @@ impl<'e, E: Engine, W: io::Write> EncoderWriter<'e, E, W> {
}

/// Write any remaining buffered data to the delegate writer.
fn write_final_leftovers(&mut self) -> Result<()> {
fn write_final_leftovers(&mut self) -> io::Result<()> {
if self.delegate.is_none() {
// finish() has already successfully called this, and we are now in drop() with a None
// writer, so just no-op
return Ok(());
}

self.write_all_encoded_output()?;

if self.extra_input_occupied_len > 0 {
// Make sure output isn’t full so we can append to it.
if self.output_range.end == self.output.len() {
self.flush_output()?;
}

let encoded_len = self
.engine
.encode_slice(
&self.extra_input[..self.extra_input_occupied_len],
&mut self.output[..],
&mut self.output[self.output_range.end..],
)
.expect("buffer is large enough");

self.output_occupied_len = encoded_len;

self.write_all_encoded_output()?;

// write succeeded, do not write the encoding of extra again if finish() is retried
self.output_range.end += encoded_len;
self.extra_input_occupied_len = 0;
}

Ok(())
self.flush_output()
}

/// Write as much of the encoded output to the delegate writer as it will accept, and store the
/// leftovers to be attempted at the next write() call. Updates `self.output_occupied_len`.
/// Flushes output buffer to the delegate.
///
/// # Errors
///
/// Errors from the delegate writer are returned. In the case of an error,
/// `self.output_occupied_len` will not be updated, as errors from `write` are specified to mean
/// that no write took place.
fn write_to_delegate(&mut self, current_output_len: usize) -> Result<()> {
/// Loops writing data to the delegate until the output buffer is empty or
/// the delegate returns an error. An `Ok(0)` return from the delegate is
/// treated as an error. Updates `output_range` accordingly.
fn flush_output(&mut self) -> io::Result<()> {
    // Each pass hands the still-unwritten part of `output` to the delegate
    // and advances `output_range.start` by however much was accepted. Any
    // delegate error (including `Interrupted`) is propagated unchanged and
    // leaves the unwritten bytes buffered for a later retry.
    while !self.output_range.is_empty() {
        let remaining = self.output_range.len();
        match self.write_to_delegate(self.output_range.clone())? {
            // A delegate that accepts nothing would make this loop spin
            // forever, so report it the same way `Write::write_all` does.
            0 => {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ))
            }
            // `>=` rather than `==`: a delegate that over-reports the number
            // of bytes written must not push `start` past `end`, which would
            // make the next slice of `output` panic.
            written if written >= remaining => break,
            written => self.output_range.start += written,
        }
    }
    // Rewind so that later encoding starts at the front of `output` again.
    self.output_range = 0..0;
    Ok(())
}

/// Writes the given range of the output buffer to the delegate. Performs
/// exactly one write. `panicked` is set to `true` for the duration of the
/// call, so it stays `true` if the delegate panics.
fn write_to_delegate(&mut self, range: std::ops::Range<usize>) -> io::Result<usize> {
self.panicked = true;
let res = self
.delegate
.as_mut()
.expect("Writer must be present")
.write(&self.output[..current_output_len]);
.expect("Encoder has already had finish() called")
.write(&self.output[range]);
self.panicked = false;

res.map(|consumed| {
debug_assert!(consumed <= current_output_len);

if consumed < current_output_len {
self.output_occupied_len = current_output_len.checked_sub(consumed).unwrap();
// If we're blocking on I/O, the minor inefficiency of copying bytes to the
// start of the buffer is the least of our concerns...
// TODO Rotate moves more than we need to; copy_within now stable.
self.output.rotate_left(consumed);
} else {
self.output_occupied_len = 0;
}
})
}

/// Write all buffered encoded output. If this returns `Ok`, `self.output_occupied_len` is `0`.
///
/// This is basically write_all for the remaining buffered data but without the undesirable
/// abort-on-`Ok(0)` behavior.
///
/// # Errors
///
/// Any error emitted by the delegate writer abort the write loop and is returned, unless it's
/// `Interrupted`, in which case the error is ignored and writes will continue.
fn write_all_encoded_output(&mut self) -> Result<()> {
while self.output_occupied_len > 0 {
let remaining_len = self.output_occupied_len;
match self.write_to_delegate(remaining_len) {
// try again on interrupts ala write_all
Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
// other errors return
Err(e) => return Err(e),
// success no-ops because remaining length is already updated
Ok(_) => {}
};
}

debug_assert_eq!(0, self.output_occupied_len);
Ok(())
res
}

/// Unwraps this `EncoderWriter`, returning the base writer it writes base64 encoded output
Expand Down Expand Up @@ -262,38 +243,22 @@ impl<'e, E: Engine, W: io::Write> io::Write for EncoderWriter<'e, E, W> {
/// # Errors
///
/// Any errors emitted by the delegate writer are returned.
fn write(&mut self, input: &[u8]) -> Result<usize> {
fn write(&mut self, input: &[u8]) -> io::Result<usize> {
if self.delegate.is_none() {
panic!("Cannot write more after calling finish()");
}

self.flush_output()?;
debug_assert_eq!(0, self.output_range.len());

if input.is_empty() {
return Ok(0);
}

// The contract of `Write::write` places some constraints on this implementation:
// - a call to `write()` represents at most one call to a wrapped `Write`, so we can't
// iterate over the input and encode multiple chunks.
// - Errors mean that "no bytes were written to this writer", so we need to reset the
// internal state to what it was before the error occurred

// before reading any input, write any leftover encoded output from last time
if self.output_occupied_len > 0 {
let current_len = self.output_occupied_len;
return self
.write_to_delegate(current_len)
// did not read any input
.map(|_| 0);
}

debug_assert_eq!(0, self.output_occupied_len);

// how many bytes, if any, were read into `extra` to create a triple to encode
let mut extra_input_read_len = 0;
let mut input = input;

let orig_extra_len = self.extra_input_occupied_len;

let mut encoded_size = 0;
// always a multiple of MIN_ENCODE_CHUNK_SIZE
let mut max_input_len = MAX_INPUT_LEN;
Expand Down Expand Up @@ -322,8 +287,10 @@ impl<'e, E: Engine, W: io::Write> io::Write for EncoderWriter<'e, E, W> {

input = &input[extra_input_read_len..];

// consider extra to be used up, since we encoded it
self.extra_input_occupied_len = 0;
// Note: Not updating self.extra_input_occupied_len yet. It’s
// going to be zeroed at the end of the function if we
// successfully write some data to delegate.

// don't clobber where we just encoded to
encoded_size = 4;
// and don't read more than can be encoded
Expand Down Expand Up @@ -367,29 +334,28 @@ impl<'e, E: Engine, W: io::Write> io::Write for EncoderWriter<'e, E, W> {
&mut self.output[encoded_size..],
);

// not updating `self.output_occupied_len` here because if the below write fails, it should
// "never take place" -- the buffer contents we encoded are ignored and perhaps retried
// later, if the consumer chooses.
// Not updating `self.output_range` here because if the write fails, it
// should "never take place" -- the buffer contents we encoded are
// ignored and perhaps retried later, if the consumer chooses.

self.write_to_delegate(encoded_size)
// no matter whether we wrote the full encoded buffer or not, we consumed the same
// input
.map(|_| extra_input_read_len + input_chunks_to_encode_len)
.map_err(|e| {
// in case we filled and encoded `extra`, reset extra_len
self.extra_input_occupied_len = orig_extra_len;

e
})
self.write_to_delegate(0..encoded_size).map(|written| {
if written < encoded_size {
self.output_range = written..encoded_size;
} else {
debug_assert_eq!(0, self.output_range.len());
}
self.extra_input_occupied_len = 0;
extra_input_read_len + input_chunks_to_encode_len
})
}

/// Because this is usually treated as OK to call multiple times, it will *not* flush any
/// incomplete chunks of input or write padding.
/// # Errors
///
/// Any error from the delegate writer is returned as is, including
/// [`io::ErrorKind::Interrupted`], which is not retried internally.
fn flush(&mut self) -> Result<()> {
self.write_all_encoded_output()?;
fn flush(&mut self) -> io::Result<()> {
self.flush_output()?;
self.delegate
.as_mut()
.expect("Writer must be present")
Expand Down
36 changes: 22 additions & 14 deletions src/write/encoder_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,17 +412,21 @@ fn writes_that_only_write_part_of_input_and_sometimes_interrupt_produce_correct_

// retry on interrupt
match res {
Ok(len) => bytes_consumed += len,
Err(e) => match e.kind() {
io::ErrorKind::Interrupted => continue,
_ => {
panic!("should not see other errors");
}
},
Ok(0) => assert_eq!(0, input_len),
Ok(len) => {
assert!(len <= input_len);
bytes_consumed += len;
}
Err(e) => assert_eq!(io::ErrorKind::Interrupted, e.kind()),
}
}

let _ = stream_encoder.finish().unwrap();
loop {
match stream_encoder.finish() {
Ok(_) => break,
Err(e) => assert_eq!(io::ErrorKind::Interrupted, e.kind()),
}
}

assert_eq!(orig_len, bytes_consumed);
}
Expand Down Expand Up @@ -506,15 +510,15 @@ struct InterruptingWriter<'a, W: 'a + Write, R: 'a + Rng> {

impl<'a, W: Write, R: Rng> Write for InterruptingWriter<'a, W, R> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.rng.gen_range(0.0..1.0) <= self.fraction {
if self.rng.gen_bool(self.fraction) {
return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted"));
}

self.w.write(buf)
}

fn flush(&mut self) -> io::Result<()> {
if self.rng.gen_range(0.0..1.0) <= self.fraction {
if self.rng.gen_bool(self.fraction) {
return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted"));
}

Expand All @@ -534,17 +538,21 @@ struct PartialInterruptingWriter<'a, W: 'a + Write, R: 'a + Rng> {

impl<'a, W: Write, R: Rng> Write for PartialInterruptingWriter<'a, W, R> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.rng.gen_range(0.0..1.0) > self.no_interrupt_fraction {
if !self.rng.gen_bool(self.no_interrupt_fraction) {
return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted"));
}

if self.rng.gen_range(0.0..1.0) <= self.full_input_fraction || buf.is_empty() {
if buf.len() <= 1 || self.rng.gen_bool(self.full_input_fraction) {
// pass through the buf untouched
self.w.write(buf)
} else {
// only use a prefix of it
self.w
.write(&buf[0..(self.rng.gen_range(0..(buf.len() - 1)))])
let end = if buf.len() == 2 {
1
} else {
self.rng.gen_range(1..(buf.len() - 1))
};
self.w.write(&buf[..end])
}
}

Expand Down

0 comments on commit cbb0cfc

Please sign in to comment.